From 99a00882337e85589a8fbc193b8ec77846a4dd6a Mon Sep 17 00:00:00 2001
From: gjoliver
Date: Tue, 26 Oct 2021 11:56:02 -0700
Subject: [PATCH] [RLlib] Unify the way we create local replay buffer for all
 agents (#19627)

* [RLlib] Unify the way we create and use LocalReplayBuffer for all the
  agents.

  This change:
  1. Gets rid of the try...except clause when we call execution_plan(),
     and gets rid of the deprecation warning as a result.
  2. Fixes the execution_plan() call in Trainer._try_recover() too.
  3. Most importantly, makes it much easier to create and use different
     types of local replay buffers for all our agents, e.g., allowing us
     to easily create a reservoir sampling replay buffer for the APPO
     agent for Riot in the near future.

* Introduce explicit configuration for replay buffer types.

* Fix is_training key error.

* Actually deprecate the buffer_size field.
---
 rllib/agents/a3c/a2c.py                         |  7 +-
 rllib/agents/a3c/a3c.py                         |  7 +-
 rllib/agents/cql/cql.py                         | 39 ++++------
 rllib/agents/cql/tests/test_cql.py              |  4 +-
 rllib/agents/dqn/apex.py                        | 10 ++-
 rllib/agents/dqn/dqn.py                         | 29 ++------
 rllib/agents/dqn/simple_q.py                    | 28 ++++----
 rllib/agents/dreamer/dreamer.py                 |  5 +-
 rllib/agents/impala/impala.py                   |  5 +-
 rllib/agents/maml/maml.py                       |  5 +-
 rllib/agents/marwil/marwil.py                   |  7 +-
 rllib/agents/mbmpo/mbmpo.py                     |  7 +-
 rllib/agents/ppo/ddppo.py                       |  7 +-
 rllib/agents/ppo/ppo.py                         |  7 +-
 rllib/agents/qmix/qmix.py                       |  5 +-
 rllib/agents/sac/sac_tf_model.py                |  2 +-
 rllib/agents/sac/sac_torch_model.py             |  2 +-
 rllib/agents/slateq/slateq.py                   | 26 ++++---
 rllib/agents/trainer.py                         | 69 ++++++++++++++++++-
 rllib/agents/trainer_template.py                | 21 ++----
 .../alpha_zero/core/alpha_zero_trainer.py       |  5 +-
 rllib/examples/random_parametric_agent.py       |  4 +-
 22 files changed, 182 insertions(+), 119 deletions(-)

diff --git a/rllib/agents/a3c/a2c.py b/rllib/agents/a3c/a2c.py
index bc46620a3f62..44f31392f801 100644
--- a/rllib/agents/a3c/a2c.py
+++ b/rllib/agents/a3c/a2c.py
@@ -29,8 +29,8 @@
 )
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the A2C algorithm. Defines the distributed dataflow.
@@ -42,6 +42,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "A2C execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     if config["microbatch_size"]:
diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py
index cc51ddcb0773..3ce8866edd0b 100644
--- a/rllib/agents/a3c/a3c.py
+++ b/rllib/agents/a3c/a3c.py
@@ -78,8 +78,8 @@ def validate_config(config: TrainerConfigDict) -> None:
         raise ValueError("`num_workers` for A3C must be >= 1!")
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the MARWIL/BC algorithm. Defines the distributed
     dataflow.
@@ -91,6 +91,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "A3C execution_plan does NOT take any additional parameters")
+
     # For A3C, compute policy gradients remotely on the rollout workers.
     grads = AsyncGradients(workers)
diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py
index 3799592d845d..3d7e7e1ed441 100644
--- a/rllib/agents/cql/cql.py
+++ b/rllib/agents/cql/cql.py
@@ -9,7 +9,6 @@
 from ray.rllib.agents.sac.sac import SACTrainer, \
     DEFAULT_CONFIG as SAC_CONFIG
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay
 from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
     UpdateTargetNetwork
@@ -17,6 +16,7 @@
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import merge_dicts
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TrainerConfigDict
@@ -24,7 +24,6 @@
 tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
 logger = logging.getLogger(__name__)
-replay_buffer = None
 
 # yapf: disable
 # __sphinx_doc_begin__
@@ -48,7 +47,11 @@
     "min_q_weight": 5.0,
     # Replay buffer should be larger or equal the size of the offline
     # dataset.
-    "buffer_size": int(1e6),
+    "buffer_size": DEPRECATED_VALUE,
+    "replay_buffer_config": {
+        "type": "LocalReplayBuffer",
+        "capacity": int(1e6),
+    },
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -74,29 +77,11 @@ def validate_config(config: TrainerConfigDict):
         try_import_tfp(error=True)
 
 
-def execution_plan(workers, config):
-    if config.get("prioritized_replay"):
-        prio_args = {
-            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-            "prioritized_replay_beta": config["prioritized_replay_beta"],
-            "prioritized_replay_eps": config["prioritized_replay_eps"],
-        }
-    else:
-        prio_args = {}
-
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
-
-    global replay_buffer
-    replay_buffer = local_replay_buffer
+def execution_plan(workers, config, **kwargs):
+    assert "local_replay_buffer" in kwargs, (
+        "CQL execution plan requires a local replay buffer.")
+
+    local_replay_buffer = kwargs["local_replay_buffer"]
 
     def update_prio(item):
         samples, info_dict = item
@@ -150,8 +135,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
 
 def after_init(trainer):
     # Add the entire dataset to Replay Buffer (global variable)
-    global replay_buffer
     reader = trainer.workers.local_worker().input_reader
+    replay_buffer = trainer.local_replay_buffer
 
     # For d4rl, add the D4RLReaders' dataset to the buffer.
     if isinstance(trainer.config["input"], str) and \
diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py
index 7e3ef58896f6..f90dc3a3a679 100644
--- a/rllib/agents/cql/tests/test_cql.py
+++ b/rllib/agents/cql/tests/test_cql.py
@@ -88,8 +88,8 @@ def test_cql_compilation(self):
         # Example on how to do evaluation on the trained Trainer
         # using the data from CQL's global replay buffer.
         # Get a sample (MultiAgentBatch -> SampleBatch).
-        from ray.rllib.agents.cql.cql import replay_buffer
-        batch = replay_buffer.replay().policy_batches["default_policy"]
+        batch = trainer.local_replay_buffer.replay().policy_batches[
+            "default_policy"]
 
         if fw == "torch":
             obs = torch.from_numpy(batch["obs"])
diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index 29c71586977c..74124e7302de 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -55,6 +55,9 @@
         "num_gpus": 1,
         "num_workers": 32,
         "buffer_size": 2000000,
+        # TODO(jungong) : add proper replay_buffer_config after
+        # DistributedReplayBuffer type is supported.
+        "replay_buffer_config": None,
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -141,8 +144,11 @@ def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
             metrics.counters["num_weight_syncs"] += 1
 
 
-def apex_execution_plan(workers: WorkerSet,
-                        config: dict) -> LocalIterator[dict]:
+def apex_execution_plan(workers: WorkerSet, config: dict,
+                        **kwargs) -> LocalIterator[dict]:
+    assert len(kwargs) == 0, (
+        "Apex execution_plan does NOT take any additional parameters")
+
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
     replay_actors = create_colocated(ReplayActor, [
diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 1e4ac0a99e1b..8da70b9cb0ae 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -20,7 +20,6 @@
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
@@ -145,8 +144,8 @@ def validate_config(config: TrainerConfigDict) -> None:
             "simple_optimizer=True if this doesn't work for you.")
 
 
-def execution_plan(trainer: Trainer, workers: WorkerSet,
-                   config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the DQN algorithm. Defines the distributed dataflow.
 
     Args:
@@ -158,28 +157,12 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    if config.get("prioritized_replay"):
-        prio_args = {
-            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-            "prioritized_replay_beta": config["prioritized_replay_beta"],
-            "prioritized_replay_eps": config["prioritized_replay_eps"],
-        }
-    else:
-        prio_args = {}
-
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
+    assert "local_replay_buffer" in kwargs, (
+        "DQN execution plan requires a local replay buffer.")
+
     # Assign to Trainer, so we can store the LocalReplayBuffer's
     # data when we save checkpoints.
-    trainer.local_replay_buffer = local_replay_buffer
+    local_replay_buffer = kwargs["local_replay_buffer"]
 
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
diff --git a/rllib/agents/dqn/simple_q.py b/rllib/agents/dqn/simple_q.py
index 390093e4cacc..d3dda026db66 100644
--- a/rllib/agents/dqn/simple_q.py
+++ b/rllib/agents/dqn/simple_q.py
@@ -14,17 +14,17 @@
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
-from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
     UpdateTargetNetwork
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
@@ -62,7 +62,11 @@
     # === Replay buffer ===
     # Size of the replay buffer. Note that if async_updates is set, then
     # each worker will have a replay buffer of this size.
-    "buffer_size": 50000,
+    "buffer_size": DEPRECATED_VALUE,
+    "replay_buffer_config": {
+        "type": "LocalReplayBuffer",
+        "capacity": 50000,
+    },
     # Set this to True, if you want the contents of your buffer(s) to be
     # stored in any saved checkpoints as well.
     # Warnings will be created if:
@@ -122,8 +126,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
         return SimpleQTorchPolicy
 
 
-def execution_plan(trainer: Trainer, workers: WorkerSet,
-                   config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the Simple Q algorithm. Defines the distributed
     dataflow.
 
     Args:
@@ -135,16 +139,10 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config["replay_sequence_length"])
-    # Assign to Trainer, so we can store the LocalReplayBuffer's
-    # data when we save checkpoints.
-    trainer.local_replay_buffer = local_replay_buffer
+    assert "local_replay_buffer" in kwargs, (
+        "SimpleQ execution plan requires a local replay buffer.")
+
+    local_replay_buffer = kwargs["local_replay_buffer"]
 
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
diff --git a/rllib/agents/dreamer/dreamer.py b/rllib/agents/dreamer/dreamer.py
index b3433f62cd5a..b31c774a6260 100644
--- a/rllib/agents/dreamer/dreamer.py
+++ b/rllib/agents/dreamer/dreamer.py
@@ -185,7 +185,10 @@ def policy_stats(self, fetches):
         return fetches[DEFAULT_POLICY_ID]["learner_stats"]
 
 
-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "Dreamer execution_plan does NOT take any additional parameters")
+
     # Special replay buffer for Dreamer agent.
     episode_buffer = EpisodicBuffer(length=config["batch_length"])
 
diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py
index 514040e701d8..cd0613e83f8c 100644
--- a/rllib/agents/impala/impala.py
+++ b/rllib/agents/impala/impala.py
@@ -312,7 +312,10 @@ def gather_experiences_directly(workers, config):
     return train_batches
 
 
-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "IMPALA execution_plan does NOT take any additional parameters")
+
     if config["num_aggregation_workers"] > 0:
         train_batches = gather_experiences_tree_aggregation(workers, config)
     else:
diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py
index 42e62a19b87f..c41d97ed4f9b 100644
--- a/rllib/agents/maml/maml.py
+++ b/rllib/agents/maml/maml.py
@@ -156,7 +156,10 @@ def inner_adaptation(workers, samples):
         e.learn_on_batch.remote(samples[i])
 
 
-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "MAML execution_plan does NOT take any additional parameters")
+
     # Sync workers with meta policy
     workers.sync_weights()
 
diff --git a/rllib/agents/marwil/marwil.py b/rllib/agents/marwil/marwil.py
index 5df9bfb19014..c4a5fa4029be 100644
--- a/rllib/agents/marwil/marwil.py
+++ b/rllib/agents/marwil/marwil.py
@@ -90,8 +90,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
         return MARWILTorchPolicy
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the MARWIL/BC algorithm. Defines the distributed
     dataflow.
 
@@ -103,6 +103,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "MARWIL execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     replay_buffer = LocalReplayBuffer(
         learning_starts=config["learning_starts"],
diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py
index 3e087a8bded4..268013bdace4 100644
--- a/rllib/agents/mbmpo/mbmpo.py
+++ b/rllib/agents/mbmpo/mbmpo.py
@@ -336,8 +336,8 @@ def post_process_samples(samples, config: TrainerConfigDict):
     return samples, split_lst
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the PPO algorithm. Defines the distributed dataflow.
 
     Args:
@@ -349,6 +349,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PPOTrainer.
             If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "MBMPO execution_plan does NOT take any additional parameters")
+
     # Train TD Models on the driver.
     workers.local_worker().foreach_policy(fit_dynamics)
 
diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py
index 46ab1291c089..2138be48c502 100644
--- a/rllib/agents/ppo/ddppo.py
+++ b/rllib/agents/ppo/ddppo.py
@@ -146,8 +146,8 @@ def validate_config(config):
         raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the DD-PPO algorithm. Defines the distributed
     dataflow.
 
     Args:
@@ -159,6 +159,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PGTrainer.
             If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "DDPPO execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="raw")
 
     # Setup the distributed processes.
diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py
index e43d460087b8..5f168dc91e2e 100644
--- a/rllib/agents/ppo/ppo.py
+++ b/rllib/agents/ppo/ppo.py
@@ -253,8 +253,8 @@ def warn_about_bad_reward_scales(config, result):
     return result
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the PPO algorithm. Defines the distributed dataflow.
 
     Args:
@@ -266,6 +266,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PPOTrainer.
            If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "PPO execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     # Collect batches for the trainable policies.
diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py
index 48341f6ceec4..646ae032d96c 100644
--- a/rllib/agents/qmix/qmix.py
+++ b/rllib/agents/qmix/qmix.py
@@ -100,7 +100,10 @@
 # yapf: enable
 
 
-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "QMIX execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     replay_buffer = SimpleReplayBuffer(config["buffer_size"])
diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py
index 725fd3a5e663..0b78f65a526f 100644
--- a/rllib/agents/sac/sac_tf_model.py
+++ b/rllib/agents/sac/sac_tf_model.py
@@ -248,7 +248,7 @@ def _get_q_value(self, model_out, actions, net):
             input_dict = {"obs": model_out}
         # Switch on training mode (when getting Q-values, we are usually in
         # training).
-        input_dict.is_training = True
+        input_dict["is_training"] = True
         out, _ = net(input_dict, [], None)
         return out
 
diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py
index 808d891f0d34..1fdc09412da1 100644
--- a/rllib/agents/sac/sac_torch_model.py
+++ b/rllib/agents/sac/sac_torch_model.py
@@ -256,7 +256,7 @@ def _get_q_value(self, model_out, actions, net):
             input_dict = {"obs": model_out}
         # Switch on training mode (when getting Q-values, we are usually in
         # training).
-        input_dict.is_training = True
+        input_dict["is_training"] = True
         out, _ = net(input_dict, [], None)
         return out
 
diff --git a/rllib/agents/slateq/slateq.py b/rllib/agents/slateq/slateq.py
index 8a0cb2d6be9c..f16c4954d783 100644
--- a/rllib/agents/slateq/slateq.py
+++ b/rllib/agents/slateq/slateq.py
@@ -22,11 +22,11 @@
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainOneStep
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
@@ -82,7 +82,11 @@
     # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
     # each worker will have a replay buffer of this size.
-    "buffer_size": 50000,
+    "buffer_size": DEPRECATED_VALUE,
+    "replay_buffer_config": {
+        "type": "LocalReplayBuffer",
+        "capacity": 50000,
+    },
     # The number of contiguous environment steps to replay at once. This may
     # be set to greater than 1 to support recurrent models.
     "replay_sequence_length": 1,
@@ -152,8 +156,8 @@ def validate_config(config: TrainerConfigDict) -> None:
             "For SARSA strategy, batch_mode must be 'complete_episodes'")
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the SlateQ algorithm. Defines the distributed
     dataflow.
 
     Args:
@@ -164,14 +168,8 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config["replay_sequence_length"],
-    )
+    assert "local_replay_buffer" in kwargs, (
+        "SlateQ execution plan requires a local replay buffer.")
 
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
@@ -179,12 +177,12 @@ def execution_plan(workers: WorkerSet,
     # (1) Generate rollouts and store them in our local replay buffer. Calling
     # next() on store_op drives this.
     store_op = rollouts.for_each(
-        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+        StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))
 
     # (2) Read and train on experiences from the replay buffer. Every batch
     # returned from the LocalReplay() iterator is passed to TrainOneStep to
     # take a SGD step.
-    replay_op = Replay(local_buffer=local_replay_buffer) \
+    replay_op = Replay(local_buffer=kwargs["local_replay_buffer"]) \
         .for_each(TrainOneStep(workers))
 
     if config["slateq_strategy"] != "RANDOM":
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index a1f4b64ee242..44b23dbc7a68 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -23,6 +23,7 @@
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.models import MODEL_DEFAULTS
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
@@ -706,6 +707,67 @@ def log_result(self, result: ResultDict):
         # to mutate the result
         Trainable.log_result(self, result)
 
+    @DeveloperAPI
+    def _create_local_replay_buffer_if_necessary(self, config):
+        """Create a LocalReplayBuffer instance if necessary.
+
+        Args:
+            config (dict): Algorithm-specific configuration data.
+
+        Returns:
+            LocalReplayBuffer instance based on trainer config.
+            None, if local replay buffer is not needed.
+        """
+        # These are the agents that utilize a local replay buffer.
+        if ("replay_buffer_config" not in config
+                or not config["replay_buffer_config"]):
+            # Does not need a replay buffer.
+            return None
+
+        replay_buffer_config = config["replay_buffer_config"]
+        if ("type" not in replay_buffer_config
+                or replay_buffer_config["type"] != "LocalReplayBuffer"):
+            # DistributedReplayBuffer coming soon.
+            return None
+
+        capacity = config.get("buffer_size", DEPRECATED_VALUE)
+        if capacity != DEPRECATED_VALUE:
+            # Print a deprecation warning.
+            deprecation_warning(
+                old="config['buffer_size']",
+                new="config['replay_buffer_config']['capacity']",
+                error=False)
+        else:
+            # Get capacity out of replay_buffer_config.
+            capacity = replay_buffer_config["capacity"]
+
+        if config.get("prioritized_replay"):
+            prio_args = {
+                "prioritized_replay_alpha": config["prioritized_replay_alpha"],
+                "prioritized_replay_beta": config["prioritized_replay_beta"],
+                "prioritized_replay_eps": config["prioritized_replay_eps"],
+            }
+        else:
+            prio_args = {}
+
+        return LocalReplayBuffer(
+            num_shards=1,
+            learning_starts=config["learning_starts"],
+            capacity=capacity,
+            replay_batch_size=config["train_batch_size"],
+            replay_mode=config["multiagent"]["replay_mode"],
+            replay_sequence_length=config.get("replay_sequence_length", 1),
+            replay_burn_in=config.get("burn_in", 0),
+            replay_zero_init_states=config.get("zero_init_states", True),
+            **prio_args)
+
+    @DeveloperAPI
+    def _kwargs_for_execution_plan(self):
+        kwargs = {}
+        if self.local_replay_buffer:
+            kwargs["local_replay_buffer"] = self.local_replay_buffer
+        return kwargs
+
     @override(Trainable)
     def setup(self, config: PartialTrainerConfigDict):
         env = self._env_id
@@ -773,6 +835,10 @@ def env_creator_from_classpath(env_context):
         if self.config.get("log_level"):
             logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
 
+        # Create local replay buffer if necessary.
+        self.local_replay_buffer = (
+            self._create_local_replay_buffer_if_necessary(self.config))
+
         self._init(self.config, self.env_creator)
 
         # Evaluation setup.
@@ -1747,7 +1813,8 @@ def _try_recover(self):
             logger.warning("Recreating execution plan after failure")
             workers.reset(healthy_workers)
-            self.train_exec_impl = self.execution_plan(workers, self.config)
+            self.train_exec_impl = self.execution_plan(
+                workers, self.config, **self._kwargs_for_execution_plan())
 
     @override(Trainable)
     def _export_model(self, export_formats: List[str],
diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py
index ce92d4f4840a..b3a8ff71c29c 100644
--- a/rllib/agents/trainer_template.py
+++ b/rllib/agents/trainer_template.py
@@ -18,7 +18,11 @@
 logger = logging.getLogger(__name__)
 
 
-def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
+def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                           **kwargs):
+    assert len(kwargs) == 0, (
+        "Default execution_plan does NOT take any additional parameters")
+
     # Collects experiences in parallel from multiple RolloutWorker actors.
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
@@ -175,19 +179,8 @@ def _init(self, config: TrainerConfigDict,
                 config=config,
                 num_workers=self.config["num_workers"])
             self.execution_plan = execution_plan
-            try:
-                self.train_exec_impl = execution_plan(self, self.workers,
-                                                      config)
-            except TypeError as e:
-                # Keyword error: Try old way w/o kwargs.
-                if "() takes 2 positional arguments but 3" in e.args[0]:
-                    self.train_exec_impl = execution_plan(self.workers, config)
-                    logger.warning(
-                        "`execution_plan` functions should accept "
-                        "`trainer`, `workers`, and `config` as args!")
-                # Other error -> re-raise.
-                else:
-                    raise e
+            self.train_exec_impl = execution_plan(
+                self.workers, config, **self._kwargs_for_execution_plan())
 
             if after_init:
                 after_init(self)
diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py
index 4af65eba6413..82c46c85cd8c 100644
--- a/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py
+++ b/rllib/contrib/alpha_zero/core/alpha_zero_trainer.py
@@ -160,7 +160,10 @@ def mcts_creator():
                          _env_creator)
 
 
-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "Alpha zero execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     if config["simple_optimizer"]:
diff --git a/rllib/examples/random_parametric_agent.py b/rllib/examples/random_parametric_agent.py
index 535466376d08..62c9dfe1ca76 100644
--- a/rllib/examples/random_parametric_agent.py
+++ b/rllib/examples/random_parametric_agent.py
@@ -63,8 +63,8 @@ def set_weights(self, weights):
         pass
 
 
-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     rollouts = ParallelRollouts(workers, mode="async")
 
     # Collect batches for the trainable policies.
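
With this patch, agents that use a local replay buffer are configured through an explicit replay_buffer_config dict instead of the flat buffer_size field, and Trainer.setup() builds the buffer from that config. A minimal sketch of what that looks like for SimpleQ; the environment name and the surrounding driver code are illustrative assumptions, and the capacity simply mirrors the SimpleQ default from the simple_q.py hunk above.

# Sketch: configuring SimpleQ via the explicit replay_buffer_config
# introduced by this patch. "CartPole-v0" and the driver code are
# illustrative assumptions.
import ray
from ray.rllib.agents.dqn import SimpleQTrainer

config = {
    # Old style, deprecated by this patch:
    #   "buffer_size": 50000,
    # New style: explicit replay buffer configuration.
    "replay_buffer_config": {
        "type": "LocalReplayBuffer",
        "capacity": 50000,
    },
}

ray.init()
trainer = SimpleQTrainer(env="CartPole-v0", config=config)
print(trainer.train())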
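
Because the buffer now lives on the Trainer instead of in a module-level global (see the cql.py and test_cql.py hunks), its contents can be inspected directly after training. A small sketch that continues from the trainer above:

# Sketch: pull a replay sample straight off the Trainer's buffer, the same
# way the updated test_cql.py does. Assumes the trainer from the previous
# sketch and that training has filled the buffer past learning_starts
# (replay() returns None before that).
multi_agent_batch = trainer.local_replay_buffer.replay()
if multi_agent_batch is not None:
    sample_batch = multi_agent_batch.policy_batches["default_policy"]
    print(sample_batch["obs"].shape)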
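
For algorithm code, the new contract is that execution_plan() receives the buffer through **kwargs (built by _kwargs_for_execution_plan() in the trainer.py hunk above) rather than constructing it itself. A sketch of a custom execution plan written against that contract; the dataflow is a trimmed-down copy of the DQN/SimpleQ pattern in this patch, not a drop-in replacement for any agent.

from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_execution_plan(workers, config, **kwargs):
    # The Trainer passes the buffer in via kwargs whenever
    # replay_buffer_config is set; fail loudly if it is missing.
    assert "local_replay_buffer" in kwargs, (
        "This execution plan requires a local replay buffer.")
    local_replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Store rollouts in the shared local replay buffer.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Replay experiences from the buffer and take SGD steps.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers))

    # Alternate between storing and replaying, reporting metrics from the
    # replay (training) branch only.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)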