From 7589bd56803d9d591809fba83b9c782bc06e9333 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 8 May 2024 13:39:58 +0200 Subject: [PATCH] [RLlib] Change (prioritized) SA episode buffer to return episode lists (instead of batch) from `sample()`. (#45123) --- rllib/BUILD | 2 +- rllib/algorithms/dqn/dqn.py | 22 +- rllib/algorithms/dqn/dqn_rainbow_learner.py | 23 +- .../dqn/torch/dqn_rainbow_torch_learner.py | 4 +- rllib/algorithms/ppo/ppo_learner.py | 1 + .../algorithms/sac/torch/sac_torch_learner.py | 2 +- rllib/connectors/learner/__init__.py | 4 + ...servations_from_episodes_to_train_batch.py | 116 +++++ rllib/core/learner/learner.py | 2 +- .../multi_agent_episode_replay_buffer.py | 4 +- .../prioritized_episode_replay_buffer.py | 449 +++--------------- .../test_multi_agent_episode_replay_buffer.py | 6 +- .../test_prioritized_episode_replay_buffer.py | 10 +- rllib/utils/replay_buffers/utils.py | 36 +- 14 files changed, 256 insertions(+), 425 deletions(-) create mode 100644 rllib/connectors/learner/add_next_observations_from_episodes_to_train_batch.py diff --git a/rllib/BUILD b/rllib/BUILD index e956d5ea116a..406175ff7538 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -283,7 +283,7 @@ py_test( # ) py_test( - name = "learning_tests_carpole_dqn_envrunner", + name = "learning_tests_cartpole_dqn_envrunner", main = "tests/run_regression_tests.py", tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete"], size = "large", diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index de66606d94ad..58e0a0f29ecc 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -24,7 +24,7 @@ from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, ) -from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.execution.train_ops import ( train_one_step, multi_gpu_train_one_step, @@ -656,24 +656,20 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: self.learner_group.foreach_learner(lambda lrnr: lrnr._reset_noise()) # Run multiple sample-from-buffer and update iterations. for _ in range(sample_and_train_weight): - # Sample training batch from replay_buffer. - # TODO (simon): Use sample_with_keys() here. + # Sample a list of episodes used for learning from the replay buffer. with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)): - train_dict = self.local_replay_buffer.sample( + episodes = self.local_replay_buffer.sample( num_items=self.config.train_batch_size, n_step=self.config.n_step, gamma=self.config.gamma, beta=self.config.replay_buffer_config["beta"], ) - train_batch = SampleBatch(train_dict) - # Convert to multi-agent batch as `LearnerGroup` depends on it. - # TODO (sven, simon): Remove this conversion once the `LearnerGroup` - # supports dict. - train_batch = train_batch.as_multi_agent() # Perform an update on the buffer-sampled train batch. with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): - learner_results = self.learner_group.update_from_batch(train_batch) + learner_results = self.learner_group.update_from_episodes( + episodes=episodes, + ) # Isolate TD-errors from result dicts (we should not log these to # disk or WandB, they might be very large). td_errors = defaultdict(list) @@ -713,10 +709,8 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: # Update replay buffer priorities. 
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)): update_priorities_in_episode_replay_buffer( - self.local_replay_buffer, - self.config, - train_batch, - td_errors, + replay_buffer=self.local_replay_buffer, + td_errors=td_errors, ) # Update the target networks, if necessary. diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py index b58f544a536b..1aba7f757008 100644 --- a/rllib/algorithms/dqn/dqn_rainbow_learner.py +++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py @@ -3,7 +3,16 @@ from typing import TYPE_CHECKING from ray.rllib.core.learner.learner import Learner -from ray.rllib.utils.annotations import override +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES from ray.rllib.utils.typing import ModuleID @@ -28,6 +37,18 @@ class DQNRainbowLearner(Learner): + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Learner) + def build(self) -> None: + super().build() + # Prepend a NEXT_OBS from episodes to train batch connector piece (right + # after the observation default piece). + if self.config.add_default_connectors_to_learner_pipeline: + self._learner_connector.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + @override(Learner) def additional_update_for_module( self, *, module_id: ModuleID, config: "DQNConfig", timestep: int, **kwargs diff --git a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py index 112f5d8baee0..0a887908b114 100644 --- a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py +++ b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py @@ -121,7 +121,7 @@ def compute_loss_for_module( r_tau = torch.clamp( batch[Columns.REWARDS].unsqueeze(dim=-1) + ( - config.gamma ** batch["n_steps"] + config.gamma ** batch["n_step"] * (1.0 - batch[Columns.TERMINATEDS].float()) ).unsqueeze(dim=-1) * z, @@ -171,7 +171,7 @@ def compute_loss_for_module( # backpropagate through the target network when optimizing the Q loss. q_selected_target = ( batch[Columns.REWARDS] - + (config.gamma ** batch["n_steps"]) * q_next_best_masked + + (config.gamma ** batch["n_step"]) * q_next_best_masked ).detach() # Choose the requested loss function. Note, in case of the Huber loss diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py index b4ca1d786f7e..5070859aa0e6 100644 --- a/rllib/algorithms/ppo/ppo_learner.py +++ b/rllib/algorithms/ppo/ppo_learner.py @@ -62,6 +62,7 @@ def _update_from_batch_or_episodes( # episodes). if self.config.enable_env_runner_and_connector_v2: batch, episodes = self._compute_gae_from_episodes(episodes=episodes) + # Now that GAE (advantages and value targets) have been added to the train # batch, we can proceed normally (calling super method) with the update step. 
return super()._update_from_batch_or_episodes( diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py index 9f452c627e6f..0565e950136b 100644 --- a/rllib/algorithms/sac/torch/sac_torch_learner.py +++ b/rllib/algorithms/sac/torch/sac_torch_learner.py @@ -204,7 +204,7 @@ def compute_loss_for_module( # Detach this node from the computation graph as we do not want to # backpropagate through the target network when optimizing the Q loss. q_selected_target = ( - batch[Columns.REWARDS] + (config.gamma ** batch["n_steps"]) * q_next_masked + batch[Columns.REWARDS] + (config.gamma ** batch["n_step"]) * q_next_masked ).detach() # Calculate the TD-error. Note, this is needed for the priority weights in diff --git a/rllib/connectors/learner/__init__.py b/rllib/connectors/learner/__init__.py index 8f117a6261e2..33ea3c80c4f6 100644 --- a/rllib/connectors/learner/__init__.py +++ b/rllib/connectors/learner/__init__.py @@ -10,12 +10,16 @@ from ray.rllib.connectors.learner.add_columns_from_episodes_to_train_batch import ( AddColumnsFromEpisodesToTrainBatch, ) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) from ray.rllib.connectors.learner.learner_connector_pipeline import ( LearnerConnectorPipeline, ) __all__ = [ "AddColumnsFromEpisodesToTrainBatch", + "AddNextObservationsFromEpisodesToTrainBatch", "AddObservationsFromEpisodesToBatch", "AddStatesFromEpisodesToBatch", "AgentToModuleMapping", diff --git a/rllib/connectors/learner/add_next_observations_from_episodes_to_train_batch.py b/rllib/connectors/learner/add_next_observations_from_episodes_to_train_batch.py new file mode 100644 index 000000000000..4812ca43c524 --- /dev/null +++ b/rllib/connectors/learner/add_next_observations_from_episodes_to_train_batch.py @@ -0,0 +1,116 @@ +from typing import Any, List, Optional + +import gymnasium as gym + +from ray.rllib.core.columns import Columns +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType + + +class AddNextObservationsFromEpisodesToTrainBatch(ConnectorV2): + """Adds the NEXT_OBS column with the correct episode observations to train batch. + + - Operates on a list of Episode objects. + - Gets all observation(s) from all the given episodes (except the very first ones) + and adds them to the batch under construction in the NEXT_OBS column (as a list of + individual observations). + - Does NOT alter any observations (or other data) in the given episodes. + - Can be used in Learner connector pipelines. + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from ray.rllib.connectors.learner import ( + AddNextObservationsFromEpisodesToTrainBatch + ) + from ray.rllib.core.columns import Columns + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + # Create two dummy SingleAgentEpisodes, each containing 3 observations, + # 2 actions and 2 rewards (both episodes are length=2). 
+ obs_space = gym.spaces.Box(-1.0, 1.0, (2,), np.float32) + act_space = gym.spaces.Discrete(2) + + episodes = [SingleAgentEpisode( + observations=[obs_space.sample(), obs_space.sample(), obs_space.sample()], + actions=[act_space.sample(), act_space.sample()], + rewards=[1.0, 2.0], + len_lookback_buffer=0, + ) for _ in range(2)] + eps_1_next_obses = episodes[0].get_observations([1, 2]) + eps_2_next_obses = episodes[1].get_observations([1, 2]) + print(f"1st Episode's next obses are {eps_1_next_obses}") + print(f"2nd Episode's next obses are {eps_2_next_obses}") + + # Create an instance of this class. + connector = AddNextObservationsFromEpisodesToTrainBatch() + + # Call the connector with the two created episodes. + # Note that this particular connector works without an RLModule, so we + # simplify here for the sake of this example. + output_data = connector( + rl_module=None, + data={}, + episodes=episodes, + explore=True, + shared_data={}, + ) + # The output data should now contain the last observations of both episodes, + # in a "per-episode organized" fashion. + check( + output_data, + { + Columns.NEXT_OBS: { + (episodes[0].id_,): eps_1_next_obses, + (episodes[1].id_,): eps_2_next_obses, + }, + }, + ) + """ + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + **kwargs, + ): + """Initializes a AddNextObservationsFromEpisodesToTrainBatch instance.""" + super().__init__( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + **kwargs, + ) + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # If "obs" already in data, early out. + if Columns.NEXT_OBS in data: + return data + + for sa_episode in self.single_agent_episode_iterator( + # This is a Learner-only connector -> Get all episodes (for train batch). + episodes, + agents_that_stepped_only=False, + ): + self.add_n_batch_items( + data, + Columns.NEXT_OBS, + items_to_add=sa_episode.get_observations(slice(1, len(sa_episode) + 1)), + num_items=len(sa_episode), + single_agent_episode=sa_episode, + ) + return data diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index b4062944b945..5b742a5d1b60 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -1287,7 +1287,7 @@ def _update_from_batch_or_episodes( # Call the learner connector pipeline. batch = self._learner_connector( rl_module=self.module, - data=batch, + data=batch if batch is not None else {}, episodes=episodes, shared_data={}, ) diff --git a/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py index 26435e8ceddc..c32415ba01a9 100644 --- a/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py @@ -591,7 +591,7 @@ def _sample_independent( Columns.TERMINATEDS: np.array(is_terminated), Columns.TRUNCATEDS: np.array(is_truncated), "weights": np.array(weights), - "n_steps": np.array(n_steps), + "n_step": np.array(n_steps), } # Include infos if necessary. 
if include_infos: @@ -804,7 +804,7 @@ def _sample_synchonized( Columns.TERMINATEDS: np.array(is_terminated[module_id]), Columns.TRUNCATEDS: np.array(is_truncated[module_id]), "weights": np.array(weights[module_id]), - "n_steps": np.array(n_steps[module_id]), + "n_step": np.array(n_steps[module_id]), } for module_id in observations.keys() } diff --git a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py index ce9f7ec52f77..f7e818a659cf 100644 --- a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py @@ -6,13 +6,11 @@ from numpy.typing import NDArray from typing import Any, Dict, List, Optional, Tuple, Union -from ray.rllib.core.columns import Columns from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree from ray.rllib.utils import force_list from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer from ray.rllib.utils.annotations import override -from ray.rllib.utils.spaces.space_utils import batch from ray.rllib.utils.typing import SampleBatchType @@ -267,9 +265,7 @@ def add( [ ( eps_idx, - # Note, we add 1 b/c the first timestep is never - # sampled. - old_len + i + 1, + old_len + i, # Get the index in the segment trees. self._get_free_node_and_assign(j + i, weight), ) @@ -286,9 +282,7 @@ def add( [ ( eps_idx, - # Note, we add 1 b/c the first timestep is never - # sampled. - i + 1, + i, self._get_free_node_and_assign(j + i, weight), ) for i in range(len(eps)) @@ -361,9 +355,8 @@ def sample( timestep at which the action is computed). Returns: - A sample batch (observations, actions, rewards, new observations, - terminateds, truncateds, weights) and if requested infos of dimension - [B, 1]. + A list of 1-step long episodes containing all basic episode data and if + requested infos and extra model outputs. """ assert beta >= 0.0 @@ -380,33 +373,18 @@ def sample( batch_length_T = batch_length_T or self.batch_length_T # Sample the n-step if necessary. + actual_n_step = n_step or 1 + random_n_step = False if isinstance(n_step, tuple): - # Use random n-step sampling. random_n_step = True - else: - actual_n_step = n_step or 1 - random_n_step = False - - # Rows to return. - observations = [[] for _ in range(batch_size_B)] - next_observations = [[] for _ in range(batch_size_B)] - actions = [[] for _ in range(batch_size_B)] - rewards = [[] for _ in range(batch_size_B)] - is_terminated = [False for _ in range(batch_size_B)] - is_truncated = [False for _ in range(batch_size_B)] - weights = [[] for _ in range(batch_size_B)] - n_steps = [[] for _ in range(batch_size_B)] - # If `info` should be included, construct also a container for them. - if include_infos: - infos = [[] for _ in range(batch_size_B)] - # If `extra_model_outputs` should be included, construct a container for them. - if include_extra_model_outputs: - extra_model_outputs = [[] for _ in range(batch_size_B)] + # Keep track of the indices that were sampled last for updating the # weights later (see `ray.rllib.utils.replay_buffer.utils. # update_priorities_in_episode_replay_buffer`). self._last_sampled_indices = [] + sampled_episodes = [] + # Sample proportionally from replay buffer's segments using the weights. 
total_segment_sum = self._sum_segment.sum() p_min = self._min_segment.min() / total_segment_sum @@ -441,337 +419,74 @@ def sample( # If we use random n-step sampling, draw the n-step for this item. if random_n_step: actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) - # If we are at the end of an episode, continue. - # Note, priority sampling got us `o_(t+n)` and we need for the loss - # calculation in addition `o_t`. - # TODO (simon): Maybe introduce a variable `num_retries` until the - # while loop should break when not enough samples have been collected - # to make n-step possible. - if episode_ts - actual_n_step < 0: - continue - else: - n_steps[B] = actual_n_step - - # Starting a new chunk. - # Ensure that each row contains a tuple of the form: - # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) - # TODO (simon): Implement version for sequence sampling when using RNNs. - eps_observations = episode.get_observations( - slice(episode_ts - actual_n_step, episode_ts + 1) - ) - # Note, the reward that is collected by transitioning from `o_t` to - # `o_(t+1)` is stored in the next transition in `SingleAgentEpisode`. - eps_rewards = episode.get_rewards( - slice(episode_ts - actual_n_step, episode_ts) - ) - observations[B] = eps_observations[0] - next_observations[B] = eps_observations[-1] - # Note, this will be the reward after executing action - # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the sum of - # all rewards that were collected over the last n steps. - rewards[B] = scipy.signal.lfilter( - [1], [1, -gamma], eps_rewards[::-1], axis=0 - )[-1] - # Note, `SingleAgentEpisode` stores the action that followed - # `o_t` with `o_(t+1)`, therefore, we need the next one. - actions[B] = episode.get_actions(episode_ts - actual_n_step) - if include_infos: - # If infos are included we include the ones from the last timestep - # as usually the info contains additional values about the last state. - infos[B] = episode.get_infos(episode_ts) - if include_extra_model_outputs: - # If `extra_model_outputs` are included we include the ones from the - # first timestep as usually the `extra_model_outputs` contain additional - # values from the forward pass that produced the action at the first - # timestep. - # Note, we extract them into single row dictionaries similar to the - # infos, in a connector we can then extract these into single batch - # rows. - extra_model_outputs[B] = { - k: episode.get_extra_model_outputs(k, episode_ts - actual_n_step) - for k in episode.extra_model_outputs.keys() - } - - # If the sampled time step is the episode's last time step check, if - # the episode is terminated or truncated. - if episode_ts == episode.t: - is_terminated[B] = episode.is_terminated - is_truncated[B] = episode.is_truncated - - # TODO (simon): Check, if we have to correct here for sequences - # later. - actual_size = 1 - weights[B] = weight / max_weight * actual_size - - # Increment counter. - B += 1 - - # Keep track of sampled indices for updating priorities later. - self._last_sampled_indices.append(idx) - - self.sampled_timesteps += batch_size_B - - # TODO Return SampleBatch instead of this simpler dict. - # TODO (simon): Check, if for stateful modules we want to sample - # here the sequences. If not remove the double list for obs. - ret = { - # Note, observation and action spaces could be complex. `batch` - # takes care of these. 
- Columns.OBS: batch(observations), - Columns.ACTIONS: batch(actions), - Columns.REWARDS: np.array(rewards), - Columns.NEXT_OBS: batch(next_observations), - Columns.TERMINATEDS: np.array(is_terminated), - Columns.TRUNCATEDS: np.array(is_truncated), - "weights": np.array(weights), - "n_steps": np.array(n_steps), - } - # Include infos if necessary. - if include_infos: - ret.update( - { - Columns.INFOS: infos, - } - ) - # Include extra model outputs, if necessary. - if include_extra_model_outputs: - ret.update( - # These could be complex, too. - batch(extra_model_outputs) - ) - - return ret - - # TODO (simon): Adjust docstring. - def sample_with_keys( - self, - num_items: Optional[int] = None, - *, - batch_size_B: Optional[int] = None, - batch_length_T: Optional[int] = None, - n_step: Optional[Union[int, Tuple]] = None, - beta: float = 0.0, - gamma: float = 0.99, - include_infos: bool = False, - include_extra_model_outputs: bool = False, - ) -> SampleBatchType: - """Samples from a buffer in a prioritized way. - - This sampling method also adds (importance sampling) weights to - the returned batch. See for prioritized sampling Schaul et al. - (2016). - - Each sampled item defines a transition of the form: - - `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))` - - where `o_(t+n)` is drawn by prioritized sampling, i.e. the priority - of `o_(t+n)` led to the sample and defines the importance weight that - is returned in the sample batch. `n` is defined by the `n_step` applied. - If requested, `info`s of a transitions last timestep `t+n` are added to - the batch. - - Args: - num_items: Number of items (transitions) to sample from this - buffer. - batch_size_B: The number of rows (transitions) to return in the - batch - n_step: The n-step to apply. For the default the batch contains in - `"new_obs"` the observation and in `"obs"` the observation `n` - time steps before. The reward will be the sum of rewards - collected in between these two observations and the action will - be the one executed n steps before such that we always have the - state-action pair that triggered the rewards. - If `n_step` is a tuple, it is considered as a range to sample - from. If `None`, we use `n_step=1`. - beta: The exponent of the importance sampling weight (see Schaul et - al. (2016)). A `beta=0.0` does not correct for the bias introduced - by prioritized replay and `beta=1.0` fully corrects for it. - gamma: The discount factor to be used when applying n-step caluclations. - The default of `0.99` should be replaced by the `Algorithm`s - discount factor. - include_infos: A boolean indicating, if `info`s should be included in - the batch. This could be of advantage, if the `info` contains - values from the environment important for loss computation. If - `True`, the info at the `"new_obs"` in the batch is included. - include_extra_model_outputs: A boolean indicating, if - `extra_model_outputs` should be included in the batch. This could be - of advantage, if the `extra_mdoel_outputs` contain outputs from the - model important for loss computation and only able to compute with the - actual state of model e.g. action log-probabilities, etc.). If `True`, - the extra model outputs at the `"obs"` in the batch is included (the - timestep at which the action is computed). - - Returns: - A sample batch (observations, actions, rewards, new observations, - terminateds, truncateds, weights) and if requested infos and extra model - outputs. 
Extra model outputs are extracted to single columns in the batch - and infos are kept as a list of dictionaries. The batch keys are the episode - ids. - """ - assert beta >= 0.0 - - if num_items is not None: - assert batch_size_B is None, ( - "Cannot call `sample()` with both `num_items` and `batch_size_B` " - "provided! Use either one." - ) - batch_size_B = num_items - - # Use our default values if no sizes/lengths provided. - batch_size_B = batch_size_B or self.batch_size_B - batch_length_T = batch_length_T or self.batch_length_T - - # Sample the n-step if necessary. - if isinstance(n_step, tuple): - # Use random n-step sampling. - random_n_step = True - else: - actual_n_step = n_step or 1 - random_n_step = False - - # Columns to return. - observations = {} - next_observations = {} - actions = {} - rewards = {} - is_terminated = {} - is_truncated = {} - weights = {} - n_steps = {} - # If `info` should be included, construct also a container for them. - if include_infos: - infos = {} - # If `extra_model_outputs` should be included, construct a container for them. - if include_extra_model_outputs: - # Get the keys from an episode in the buffer. - # TODO (simon, sven): What happens, if different episodes have different - # extra model outputs or some are missing? - extra_model_outputs = { - k: {} for k in self.episodes[0].extra_model_outputs.keys() - } - # Keep track of the indices that were sampled last for updating the - # weights later (see `ray.rllib.utils.replay_buffer.utils. - # update_priorities_in_episode_replay_buffer`). - self._last_sampled_indices = [] - - # Sample proportionally from replay buffer's segments using the weights. - total_segment_sum = self._sum_segment.sum() - p_min = self._min_segment.min() / total_segment_sum - max_weight = (p_min * self.get_num_timesteps()) ** (-beta) - B = 0 - while B < batch_size_B: - # First, draw a random sample from Uniform(0, sum over all weights). - # Note, transitions with higher weight get sampled more often (as - # more random draws fall into larger intervals). - random_sum = self.rng.random() * self._sum_segment.sum(0, self._max_idx + 1) - # Get the highest index in the sum-tree for which the sum is - # smaller or equal the random sum sample. - # Note, we sample `o_(t + n_step)` as this is the state that - # brought the information contained in the TD-error (see Schaul - # et al. (2018), Algorithm 1). - idx = self._sum_segment.find_prefixsum_idx(random_sum) - # Get the theoretical probability mass for drawing this sample. - p_sample = self._sum_segment[idx] / total_segment_sum - # Compute the importance sampling weight. - weight = (p_sample * self.get_num_timesteps()) ** (-beta) - # Now, get the transition stored at this index. - index_triple = self._indices[self._tree_idx_to_sample_idx[idx]] - - # Compute the actual episode index (offset by the number of - # already evicted episodes) - episode_idx, episode_ts = ( - index_triple[0] - self._num_episodes_evicted, - index_triple[1], - ) - episode = self.episodes[episode_idx] - - # If we use random n-step sampling, draw the n-step for this item. - if random_n_step: - actual_n_step = int(self.rng.integers(n_step[0], n_step[1])) - # If we are at the end of an episode, continue. - # Note, priority sampling got us `o_(t+n)` and we need for the loss - # calculation in addition `o_t`. - # TODO (simon): Maybe introduce a variable `num_retries` until the - # while loop should break when not enough samples have been collected - # to make n-step possible. 
- if episode_ts - actual_n_step < 0: + # Skip, if we are too far to the end and `episode_ts` + n_step would go + # beyond the episode's end. + if episode_ts + actual_n_step > len(episode): continue - # Starting a new chunk. - # Ensure that each row contains a tuple of the form: - # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) - # TODO (simon): Implement version for sequence sampling when using RNNs. - eps_observations = episode.get_observations( - slice(episode_ts - actual_n_step, episode_ts + 1) - ) - # Note, the reward that is collected by transitioning from `o_t` to - # `o_(t+1)` is stored in the next transition in `SingleAgentEpisode`. - eps_rewards = episode.get_rewards( - slice(episode_ts - actual_n_step, episode_ts) - ) - if (episode.id_,) not in observations: - # Add the key to all containers. - observations[(episode.id_,)] = [] - next_observations[(episode.id_,)] = [] - actions[(episode.id_,)] = [] - rewards[(episode.id_,)] = [] - is_terminated[(episode.id_,)] = [] - is_truncated[(episode.id_,)] = [] - weights[(episode.id_,)] = [] - n_steps[(episode.id_,)] = [] - if include_infos: - infos[(episode.id_,)] = [] - if include_extra_model_outputs: - # 'extra_model_outputs` has a structure - # `{"output_1": {(eps_id0,): [0.4, 2.3], ...}, ...}`` - for k in extra_model_outputs: - extra_model_outputs[k][(episode.id_,)] = [] - - # Add the `n_step` used for this item. - n_steps[(episode.id_,)].append(actual_n_step) - - observations[(episode.id_,)].append(eps_observations[0]) - next_observations[(episode.id_,)].append(eps_observations[-1]) # Note, this will be the reward after executing action # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the sum of - # all rewards that were collected over the last n steps. - rewards[(episode.id_,)].append( - scipy.signal.lfilter([1], [1, -gamma], eps_rewards[::-1], axis=0)[-1] + # all discounted rewards that were collected over the last n steps. + raw_rewards = episode.get_rewards( + slice(episode_ts, episode_ts + actual_n_step) ) - # Note, `SingleAgentEpisode` stores the action that followed - # `o_t` with `o_(t+1)`, therefore, we need the next one. - actions[(episode.id_,)].append( - episode.get_actions(episode_ts - actual_n_step) + rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[ + -1 + ] + + # Generate the episode to be returned. + sampled_episode = SingleAgentEpisode( + # Ensure that each episode contains a tuple of the form: + # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step)) + # Two observations (t and t+n). + observations=[ + episode.get_observations(episode_ts), + episode.get_observations(episode_ts + actual_n_step), + ], + observation_space=episode.observation_space, + infos=( + [ + episode.get_infos(episode_ts), + episode.get_infos(episode_ts + actual_n_step), + ] + if include_infos + else None + ), + actions=[episode.get_actions(episode_ts)], + action_space=episode.action_space, + rewards=[rewards], + # If the sampled time step is the episode's last time step check, if + # the episode is terminated or truncated. + terminated=( + False + if episode_ts + actual_n_step < len(episode) + else episode.is_terminated + ), + truncated=( + False + if episode_ts + actual_n_step < len(episode) + else episode.is_truncated + ), + extra_model_outputs={ + # TODO (simon): Check, if we have to correct here for sequences + # later. 
+ "weights": [weight / max_weight * 1], # actual_size=1 + "n_step": [actual_n_step], + **( + { + k: [episode.get_extra_model_outputs(k, episode_ts)] + for k in episode.extra_model_outputs.keys() + } + if include_extra_model_outputs + else {} + ), + }, + # TODO (sven): Support lookback buffers. + len_lookback_buffer=0, + t_started=episode_ts, ) - if include_infos: - # If infos are included we include the ones from the last timestep - # as usually the info contains additional values about the last state. - infos[(episode.id_,)].append(episode.get_infos(episode_ts)) - if include_extra_model_outputs: - # If `extra_model_outputs` are included we include the ones from the - # first timestep as usually the `extra_model_outputs` contain additional - # values from the forward pass that produced the action at the first - # timestep. - for k in extra_model_outputs: - extra_model_outputs[k][(episode.id_,)].append( - episode.get_extra_model_outputs(k, episode_ts - actual_n_step) - ) - - # If the sampled time step is the episode's last time step check, if - # the episode is terminated or truncated. - if episode_ts == episode.t: - is_terminated[(episode.id_,)].append(episode.is_terminated) - is_truncated[(episode.id_,)].append(episode.is_truncated) - else: - is_terminated[(episode.id_,)].append(False) - is_truncated[(episode.id_,)].append(False) - - # TODO (simon): Check, if we have to correct here for sequences - # later. - actual_size = 1 - weights[(episode.id_,)].append(weight / max_weight * actual_size) + sampled_episodes.append(sampled_episode) # Increment counter. B += 1 @@ -781,29 +496,7 @@ def sample_with_keys( self.sampled_timesteps += batch_size_B - # TODO Return SampleBatch instead of this simpler dict. - ret = { - Columns.OBS: observations, - Columns.ACTIONS: actions, - Columns.REWARDS: rewards, - Columns.NEXT_OBS: next_observations, - Columns.TERMINATEDS: is_terminated, - Columns.TRUNCATEDS: is_truncated, - "weights": weights, - "n_steps": n_steps, - } - # Include infos if necessary. - if include_infos: - ret.update( - { - Columns.INFOS: infos, - } - ) - # Include extra model outputs, if necessary. - if include_extra_model_outputs: - ret.update(extra_model_outputs) - - return ret + return sampled_episodes @override(EpisodeReplayBuffer) def get_state(self) -> Dict[str, Any]: diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py index 1108440a86dd..14a3860c5e6c 100644 --- a/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py @@ -171,7 +171,7 @@ def test_buffer_independent_sample_logic(self): sample[module_id]["terminateds"], sample[module_id]["truncateds"], sample[module_id]["weights"], - sample[module_id]["n_steps"], + sample[module_id]["n_step"], ) # Make sure terminated and truncated are never both True. @@ -238,7 +238,7 @@ def test_buffer_synchronized_sample_logic(self): sample[module_id]["terminateds"], sample[module_id]["truncateds"], sample[module_id]["weights"], - sample[module_id]["n_steps"], + sample[module_id]["n_step"], ) # Make sure terminated and truncated are never both True. @@ -309,7 +309,7 @@ def test_sample_with_modules_to_sample(self): sample[module_id]["terminateds"], sample[module_id]["truncateds"], sample[module_id]["weights"], - sample[module_id]["n_steps"], + sample[module_id]["n_step"], ) # Make sure terminated and truncated are never both True. 
diff --git a/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py index 2f316ecd879a..6d0d8b7c3645 100644 --- a/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py @@ -112,7 +112,7 @@ def test_prioritized_buffer_sample_logic(self): sample["terminateds"], sample["truncateds"], sample["weights"], - sample["n_steps"], + sample["n_step"], ) # Make sure terminated and truncated are never both True. @@ -165,7 +165,7 @@ def test_prioritized_buffer_sample_logic(self): sample["terminateds"], sample["truncateds"], sample["weights"], - sample["n_steps"], + sample["n_step"], ) # Make sure terminated and truncated are never both True. @@ -218,7 +218,7 @@ def test_prioritized_buffer_sample_logic(self): sample["terminateds"], sample["truncateds"], sample["weights"], - sample["n_steps"], + sample["n_step"], ) # Make sure terminated and truncated are never both True. @@ -284,7 +284,7 @@ def test_infos_and_extra_model_outputs(self): sample["terminateds"], sample["truncateds"], sample["weights"], - sample["n_steps"], + sample["n_step"], sample["infos"], sample[0], sample[1], @@ -350,7 +350,7 @@ def test_sample_with_keys(self): sample["terminateds"], sample["truncateds"], sample["weights"], - sample["n_steps"], + sample["n_step"], sample["infos"], sample[0], sample[1], diff --git a/rllib/utils/replay_buffers/utils.py b/rllib/utils/replay_buffers/utils.py index 21b88dcd8801..3b1bb6b6924f 100644 --- a/rllib/utils/replay_buffers/utils.py +++ b/rllib/utils/replay_buffers/utils.py @@ -1,6 +1,6 @@ import logging import psutil -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional import numpy as np @@ -18,39 +18,39 @@ MultiAgentReplayBuffer, ) from ray.rllib.policy.sample_batch import concat_samples, MultiAgentBatch, SampleBatch -from ray.rllib.utils.typing import ResultDict, SampleBatchType, AlgorithmConfigDict +from ray.rllib.utils.typing import ( + AlgorithmConfigDict, + ModuleID, + ResultDict, + SampleBatchType, + TensorType, +) from ray.util import log_once from ray.util.annotations import DeveloperAPI -if TYPE_CHECKING: - from ray.rllib.algorithms.algorithm_config import AlgorithmConfig - logger = logging.getLogger(__name__) +# TODO (simon): Move all regular keys to the metric constants file. +TD_ERROR_KEY = "td_error" + @DeveloperAPI def update_priorities_in_episode_replay_buffer( + *, replay_buffer: EpisodeReplayBuffer, - config: "AlgorithmConfig", - train_batch: SampleBatchType, - train_results: ResultDict, + td_errors: Dict[ModuleID, TensorType], ) -> None: # Only update priorities, if the buffer supports them. if isinstance(replay_buffer, PrioritizedEpisodeReplayBuffer): # The `ResultDict` will be multi-agent. - for module_id, result_dict in train_results.items(): + for module_id, td_error in td_errors.items(): # Skip the `"__all__"` keys. if module_id in ["__all__", ALL_MODULES]: continue - from ray.rllib.algorithms.dqn.dqn_rainbow_learner import TD_ERROR_KEY - - # Get the TD-error from the results. - td_error = result_dict.get(TD_ERROR_KEY, None) - # Warn once, if we have no TD-errors to update priorities. 
- if td_error is None: + if TD_ERROR_KEY not in td_error or td_error[TD_ERROR_KEY] is None: if log_once( "no_td_error_in_train_results_from_module_{}".format(module_id) ): @@ -62,10 +62,12 @@ def update_priorities_in_episode_replay_buffer( ) continue # TODO (simon): Implement multi-agent version. Remove, happens in buffer. - assert len(td_error) == len(replay_buffer._last_sampled_indices) + assert len(td_error[TD_ERROR_KEY]) == len( + replay_buffer._last_sampled_indices + ) # TODO (simon): Implement for stateful modules. - replay_buffer.update_priorities(td_error) + replay_buffer.update_priorities(td_error[TD_ERROR_KEY]) @OldAPIStack
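
For reference, a hedged end-to-end sketch (not part of the patch) of how the refactored pieces fit together outside of `DQN.training_step()`: add an episode to the prioritized buffer, sample a list of 1-step `SingleAgentEpisode` objects, and feed per-module TD-errors, keyed by module ID and the new `TD_ERROR_KEY` (`"td_error"`), back into `update_priorities_in_episode_replay_buffer()`. The episode contents, the random TD-error values, and the `"default_policy"` module ID are illustrative assumptions; the explicit `finalize()` call mirrors what EnvRunners hand to the buffer and may be optional.

    import gymnasium as gym
    import numpy as np

    from ray.rllib.env.single_agent_episode import SingleAgentEpisode
    from ray.rllib.utils.replay_buffers.prioritized_episode_replay_buffer import (
        PrioritizedEpisodeReplayBuffer,
    )
    from ray.rllib.utils.replay_buffers.utils import (
        update_priorities_in_episode_replay_buffer,
    )

    obs_space = gym.spaces.Box(-1.0, 1.0, (2,), np.float32)
    act_space = gym.spaces.Discrete(2)

    # One terminated dummy episode with 10 timesteps.
    episode = SingleAgentEpisode(
        observations=[obs_space.sample() for _ in range(11)],
        actions=[act_space.sample() for _ in range(10)],
        rewards=[1.0] * 10,
        terminated=True,
        len_lookback_buffer=0,
    )
    # Numpy-ize the per-timestep data (assumed to match EnvRunner output).
    episode.finalize()

    buffer = PrioritizedEpisodeReplayBuffer(capacity=100)
    buffer.add(episode)

    # `sample()` now returns a list of 1-step SingleAgentEpisodes (see above),
    # each carrying "weights" and "n_step" in its `extra_model_outputs`.
    episodes = buffer.sample(num_items=8, n_step=1, gamma=0.99, beta=0.5)

    # Dummy TD-errors: one entry per sampled item, keyed by module ID and
    # TD_ERROR_KEY ("td_error").
    td_errors = {"default_policy": {"td_error": np.abs(np.random.randn(len(episodes)))}}
    update_priorities_in_episode_replay_buffer(
        replay_buffer=buffer,
        td_errors=td_errors,
    )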