Skip to content

Commit

Permalink
[RLlib] Unify the way we create local replay buffer for all agents (r…
Browse files Browse the repository at this point in the history
…ay-project#19627)

* [RLlib] Unify the way we create and use LocalReplayBuffer for all the agents.

This change
1. Get rid of the try...except clause when we call execution_plan(),
   and get rid of the Deprecation warning as a result.
2. Fix the execution_plan() call in Trainer._try_recover() too.
3. Most importantly, makes it much easier to create and use different types
   of local replay buffers for all our agents.
   E.g., allow us to easily create a reservoir sampling replay buffer for
   APPO agent for Riot in the near future.
* Introduce explicit configuration for replay buffer types.
* Fix is_training key error.
* actually deprecate buffer_size field.
  • Loading branch information
gjoliver authored Oct 26, 2021
1 parent ab15dfd commit 99a0088
Show file tree
Hide file tree
Showing 22 changed files with 182 additions and 119 deletions.
7 changes: 5 additions & 2 deletions rllib/agents/a3c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
)


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the A2C algorithm. Defines the distributed
dataflow.
Expand All @@ -42,6 +42,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"A2C execution_plan does NOT take any additional parameters")

rollouts = ParallelRollouts(workers, mode="bulk_sync")

if config["microbatch_size"]:
Expand Down
7 changes: 5 additions & 2 deletions rllib/agents/a3c/a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def validate_config(config: TrainerConfigDict) -> None:
raise ValueError("`num_workers` for A3C must be >= 1!")


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
Expand All @@ -91,6 +91,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"A3C execution_plan does NOT take any additional parameters")

# For A3C, compute policy gradients remotely on the rollout workers.
grads = AsyncGradients(workers)

Expand Down
39 changes: 12 additions & 27 deletions rllib/agents/cql/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,21 @@
from ray.rllib.agents.sac.sac import SACTrainer, \
DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
replay_buffer = None

# yapf: disable
# __sphinx_doc_begin__
Expand All @@ -48,7 +47,11 @@
"min_q_weight": 5.0,
# Replay buffer should be larger or equal the size of the offline
# dataset.
"buffer_size": int(1e6),
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "LocalReplayBuffer",
"capacity": int(1e6),
},
})
# __sphinx_doc_end__
# yapf: enable
Expand All @@ -74,29 +77,11 @@ def validate_config(config: TrainerConfigDict):
try_import_tfp(error=True)


def execution_plan(workers, config):
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}

local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)

global replay_buffer
replay_buffer = local_replay_buffer
def execution_plan(workers, config, **kwargs):
assert "local_replay_buffer" in kwargs, (
"CQL execution plan requires a local replay buffer.")

local_replay_buffer = kwargs["local_replay_buffer"]

def update_prio(item):
samples, info_dict = item
Expand Down Expand Up @@ -150,8 +135,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:

def after_init(trainer):
# Add the entire dataset to Replay Buffer (global variable)
global replay_buffer
reader = trainer.workers.local_worker().input_reader
replay_buffer = trainer.local_replay_buffer

# For d4rl, add the D4RLReaders' dataset to the buffer.
if isinstance(trainer.config["input"], str) and \
Expand Down
4 changes: 2 additions & 2 deletions rllib/agents/cql/tests/test_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def test_cql_compilation(self):
# Example on how to do evaluation on the trained Trainer
# using the data from CQL's global replay buffer.
# Get a sample (MultiAgentBatch -> SampleBatch).
from ray.rllib.agents.cql.cql import replay_buffer
batch = replay_buffer.replay().policy_batches["default_policy"]
batch = trainer.local_replay_buffer.replay().policy_batches[
"default_policy"]

if fw == "torch":
obs = torch.from_numpy(batch["obs"])
Expand Down
10 changes: 8 additions & 2 deletions rllib/agents/dqn/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
"num_gpus": 1,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
Expand Down Expand Up @@ -141,8 +144,11 @@ def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
metrics.counters["num_weight_syncs"] += 1


def apex_execution_plan(workers: WorkerSet,
config: dict) -> LocalIterator[dict]:
def apex_execution_plan(workers: WorkerSet, config: dict,
**kwargs) -> LocalIterator[dict]:
assert len(kwargs) == 0, (
"Apex execution_plan does NOT take any additional parameters")

# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
replay_actors = create_colocated(ReplayActor, [
Expand Down
29 changes: 6 additions & 23 deletions rllib/agents/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
Expand Down Expand Up @@ -145,8 +144,8 @@ def validate_config(config: TrainerConfigDict) -> None:
"simple_optimizer=True if this doesn't work for you.")


def execution_plan(trainer: Trainer, workers: WorkerSet,
config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the DQN algorithm. Defines the distributed dataflow.
Args:
Expand All @@ -158,28 +157,12 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}

local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)
assert "local_replay_buffer" in kwargs, (
"DQN execution plan requires a local replay buffer.")

# Assign to Trainer, so we can store the LocalReplayBuffer's
# data when we save checkpoints.
trainer.local_replay_buffer = local_replay_buffer
local_replay_buffer = kwargs["local_replay_buffer"]

rollouts = ParallelRollouts(workers, mode="bulk_sync")

Expand Down
28 changes: 13 additions & 15 deletions rllib/agents/dqn/simple_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@

from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

Expand Down Expand Up @@ -62,7 +62,11 @@
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "LocalReplayBuffer",
"capacity": 50000,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
Expand Down Expand Up @@ -122,8 +126,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
return SimpleQTorchPolicy


def execution_plan(trainer: Trainer, workers: WorkerSet,
config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the Simple Q algorithm. Defines the distributed dataflow.
Args:
Expand All @@ -135,16 +139,10 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config["replay_sequence_length"])
# Assign to Trainer, so we can store the LocalReplayBuffer's
# data when we save checkpoints.
trainer.local_replay_buffer = local_replay_buffer
assert "local_replay_buffer" in kwargs, (
"SimpleQ execution plan requires a local replay buffer.")

local_replay_buffer = kwargs["local_replay_buffer"]

rollouts = ParallelRollouts(workers, mode="bulk_sync")

Expand Down
5 changes: 4 additions & 1 deletion rllib/agents/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def policy_stats(self, fetches):
return fetches[DEFAULT_POLICY_ID]["learner_stats"]


def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"Dreamer execution_plan does NOT take any additional parameters")

# Special replay buffer for Dreamer agent.
episode_buffer = EpisodicBuffer(length=config["batch_length"])

Expand Down
5 changes: 4 additions & 1 deletion rllib/agents/impala/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,10 @@ def gather_experiences_directly(workers, config):
return train_batches


def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"IMPALA execution_plan does NOT take any additional parameters")

if config["num_aggregation_workers"] > 0:
train_batches = gather_experiences_tree_aggregation(workers, config)
else:
Expand Down
5 changes: 4 additions & 1 deletion rllib/agents/maml/maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def inner_adaptation(workers, samples):
e.learn_on_batch.remote(samples[i])


def execution_plan(workers, config):
def execution_plan(workers, config, **kwargs):
assert len(kwargs) == 0, (
"MAML execution_plan does NOT take any additional parameters")

# Sync workers with meta policy
workers.sync_weights()

Expand Down
7 changes: 5 additions & 2 deletions rllib/agents/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
return MARWILTorchPolicy


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the MARWIL/BC algorithm. Defines the distributed
dataflow.
Expand All @@ -103,6 +103,9 @@ def execution_plan(workers: WorkerSet,
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
assert len(kwargs) == 0, (
"Marwill execution_plan does NOT take any additional parameters")

rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = LocalReplayBuffer(
learning_starts=config["learning_starts"],
Expand Down
7 changes: 5 additions & 2 deletions rllib/agents/mbmpo/mbmpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ def post_process_samples(samples, config: TrainerConfigDict):
return samples, split_lst


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
Args:
Expand All @@ -349,6 +349,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PPOTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"MBMPO execution_plan does NOT take any additional parameters")

# Train TD Models on the driver.
workers.local_worker().foreach_policy(fit_dynamics)

Expand Down
7 changes: 5 additions & 2 deletions rllib/agents/ppo/ddppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def validate_config(config):
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.
Args:
Expand All @@ -159,6 +159,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"DDPPO execution_plan does NOT take any additional parameters")

rollouts = ParallelRollouts(workers, mode="raw")

# Setup the distributed processes.
Expand Down
7 changes: 5 additions & 2 deletions rllib/agents/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def warn_about_bad_reward_scales(config, result):
return result


def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs) -> LocalIterator[dict]:
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
Args:
Expand All @@ -266,6 +266,9 @@ def execution_plan(workers: WorkerSet,
LocalIterator[dict]: The Policy class to use with PPOTrainer.
If None, use `default_policy` provided in build_trainer().
"""
assert len(kwargs) == 0, (
"PPO execution_plan does NOT take any additional parameters")

rollouts = ParallelRollouts(workers, mode="bulk_sync")

# Collect batches for the trainable policies.
Expand Down
Loading

0 comments on commit 99a0088

Please sign in to comment.