[RLlib] Uniform random n-step sampling in `PrioritizedEpisodeReplayBuffer`. (ray-project#43258)
simonsays1980 authored Feb 21, 2024
1 parent de1eb62 commit 73293e2
Showing 4 changed files with 115 additions and 21 deletions.
21 changes: 16 additions & 5 deletions rllib/algorithms/sac/sac.py
@@ -187,7 +187,10 @@ def training(
This is the inverse of reward scale, and will be optimized
automatically.
n_step: N-step target updates. If >1, sars' tuples in trajectories will be
postprocessed to become sa[discounted sum of R][s t+n] tuples.
postprocessed to become sa[discounted sum of R][s t+n] tuples. An
integer will be interpreted as a fixed n-step value. In case of a tuple
the n-step value will be drawn for each sample in the train batch from
a uniform distribution over the interval defined by the 'n-step'-tuple.
store_buffer_in_checkpoints: Set this to True, if you want the contents of
your buffer(s) to be stored in any saved checkpoints as well.
Warnings will be created if:
@@ -327,15 +330,23 @@ def validate(self) -> None:
super().validate()

# Check rollout_fragment_length to be compatible with n_step.
if isinstance(self.n_step, tuple):
min_rollout_fragment_length = self.n_step[1]
else:
min_rollout_fragment_length = self.n_step

if (
not self.in_evaluation
and self.rollout_fragment_length != "auto"
and self.rollout_fragment_length < self.n_step
and self.rollout_fragment_length
< min_rollout_fragment_length # (self.n_step or 1)
):
raise ValueError(
f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
f"smaller than `n_step` ({self.n_step})! "
f"Try setting config.rollouts(rollout_fragment_length={self.n_step})."
f"smaller than needed for `n_step` ({self.n_step})! If `n_step` is "
f"an integer try setting `rollout_fragment_length={self.n_step}`. If "
"`n_step` is a tuple, try setting "
f"`rollout_fragment_length={self.n_step[1]}`."
)

if self.use_state_preprocessor != DEPRECATED_VALUE:
@@ -371,7 +382,7 @@ def validate(self) -> None:
@override(AlgorithmConfig)
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
if self.rollout_fragment_length == "auto":
return self.n_step
return self.n_step[1] if isinstance(self.n_step, tuple) else self.n_step
else:
return self.rollout_fragment_length

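The config change above allows `n_step` to be either a fixed integer or a `(min, max)` tuple. A minimal configuration sketch, assuming the standard `SACConfig` builder methods referenced in the diff (environment name and fragment length are illustrative choices, not part of the commit):

from ray.rllib.algorithms.sac import SACConfig

# Sketch only: a tuple n_step makes the replay buffer draw a fresh n-step per
# sampled transition from the given range. rollout_fragment_length must be at
# least n_step[1], otherwise the ValueError added in validate() above is raised.
config = (
    SACConfig()
    .environment("Pendulum-v1")           # illustrative environment
    .training(n_step=(1, 5))              # tuple => uniform random n-step
    .rollouts(rollout_fragment_length=5)  # >= n_step[1]
)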
2 changes: 1 addition & 1 deletion rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -231,7 +231,7 @@ def compute_loss_for_module(
# backpropagate through the target network when optimizing the Q loss.
q_selected_target = (
batch[SampleBatch.REWARDS]
+ (self.config.gamma**self.config.n_step) * q_next_masked
+ (self.config.gamma ** batch["n_steps"]) * q_next_masked
).detach()

# Calculate the TD-error. Note, this is needed for the priority weights in
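With this change, the learner no longer raises `gamma` to a single config-wide `n_step` but to a per-sample exponent taken from the batch's new "n_steps" column. A standalone sketch of that target computation with made-up tensors (all values hypothetical, not taken from the diff):

import torch

gamma = 0.99  # hypothetical discount factor
# Per-sample quantities as they would appear in a sampled train batch.
rewards = torch.tensor([0.5, 1.0, 0.2])        # discounted n-step reward sums
n_steps = torch.tensor([1.0, 3.0, 2.0])        # n-step drawn per sample by the buffer
q_next_masked = torch.tensor([2.0, 1.5, 0.0])  # 0.0 where the episode terminated

# n-step TD target per sample: r_(t:t+n) + gamma^n * Q(s_(t+n), a_(t+n)).
q_selected_target = (rewards + gamma ** n_steps * q_next_masked).detach()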
35 changes: 23 additions & 12 deletions rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
@@ -28,7 +28,8 @@ def __init__(
capacity=capacity, batch_size_B=batch_size_B, batch_length_T=batch_length_T
)

assert alpha > 0
# `alpha` should be non-negative.
assert alpha >= 0
self._alpha = alpha

# Initialize segment trees for the priority weights. Note, b/c the trees
@@ -215,7 +216,7 @@ def sample(
be the one executed n steps before such that we always have the
state-action pair that triggered the rewards.
If `n_step` is a tuple, it is considered as a range to sample
from. If `None`, we sample from the range (1, 5).
from. If `None`, we use `n_step=1`.
beta: The exponent of the importance sampling weight (see Schaul et
al. (2016)). A `beta=0.0` does not correct for the bias introduced
by prioritized replay and `beta=1.0` fully corrects for it.
@@ -246,12 +247,12 @@ def sample(
batch_length_T = batch_length_T or self.batch_length_T

# Sample the n-step if necessary.
if n_step is None:
# If ' None' we sample from the closed interval (1, 5).
n_step = self.rng.integers(1, 5)
elif isinstance(n_step, tuple):
# Otherwise sample from the user-defined interval.
n_step = self.rng.integers(n_step[0], n_step[1])
if isinstance(n_step, tuple):
# Use random n-step sampling.
random_n_step = True
else:
actual_n_step = n_step
random_n_step = False

# Rows to return.
observations = [[] for _ in range(batch_size_B)]
@@ -261,6 +262,7 @@ def sample(
is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)]
is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)]
weights = [[] for _ in range(batch_size_B)]
n_steps = [[] for _ in range(batch_size_B)]
# If `info` should be included, construct also a container for them.
# TODO (simon): Add also `extra_model_outs`.
if include_infos:
@@ -300,25 +302,33 @@ def sample(
index_triple[1],
)
episode = self.episodes[episode_idx]

# If we use random n-step sampling, draw the n-step for this item.
if random_n_step:
actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
# If we are at the end of an episode, continue.
# Note, priority sampling got us `o_(t+n)` and we need for the loss
# calculation in addition `o_t`.
# TODO (simon): Maybe introduce a variable `num_retries` until the
# while loop should break when not enough samples have been collected
# to make n-step possible.
if episode_ts - n_step < 0:
if episode_ts - actual_n_step < 0:
continue
else:
n_steps[B].append(actual_n_step)

# Starting a new chunk.
# Ensure that each row contains a tuple of the form:
# (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
# TODO (simon): Implement version for sequence sampling when using RNNs.
eps_observations = episode.get_observations(
slice(episode_ts - n_step, episode_ts + 1)
slice(episode_ts - actual_n_step, episode_ts + 1)
)
# Note, the reward that is collected by transitioning from `o_t` to
# `o_(t+1)` is stored in the next transition in `SingleAgentEpisode`.
eps_rewards = episode.get_rewards(slice(episode_ts - n_step, episode_ts))
eps_rewards = episode.get_rewards(
slice(episode_ts - actual_n_step, episode_ts)
)
observations[B].append(eps_observations[0])
next_observations[B].append(eps_observations[-1])
# Note, this will be the reward after executing action
@@ -364,6 +374,7 @@ def sample(
SampleBatch.TERMINATEDS: np.array(is_terminated),
SampleBatch.TRUNCATEDS: np.array(is_truncated),
"weights": np.array(weights),
"n_steps": np.array(n_steps),
}
if include_infos:
ret.update(
@@ -430,7 +441,7 @@ def update_priorities(self, priorities: NDArray) -> None:
# Note, TD-errors come in as absolute values or results from
# cross-entropy loss calculations.
assert priority > 0
assert 0 <= idx < self.get_num_timesteps()
assert 0 <= idx < self._sum_segment.capacity
# TODO (simon): Create metrics.
# delta = priority**self._alpha - self._sum_segment[idx]
# Update the priorities in the segment trees.
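On the buffer side, a tuple `n_step` passed to `sample()` now triggers a uniform draw per sampled item, and the drawn values are returned with the batch under the "n_steps" key. A hedged usage sketch (the buffer must first be filled with `SingleAgentEpisode` data, which is omitted here):

from ray.rllib.utils.replay_buffers.prioritized_episode_replay_buffer import (
    PrioritizedEpisodeReplayBuffer,
)

buffer = PrioritizedEpisodeReplayBuffer(capacity=100)
# ... add SingleAgentEpisode objects to the buffer here (omitted) ...

# Tuple n_step: each sampled transition gets its own n drawn uniformly from
# the given range; the returned batch reports the drawn values under "n_steps".
sample = buffer.sample(batch_size_B=16, n_step=(1, 5), beta=1.0)
print(sample["n_steps"])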
78 changes: 75 additions & 3 deletions rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py
@@ -91,14 +91,24 @@ def test_prioritized_buffer_sample_logic(self):

for _ in range(1000):
sample = buffer.sample(batch_size_B=16, n_step=1)
obs, actions, rewards, next_obs, is_terminated, is_truncated, weights = (
(
obs,
actions,
rewards,
next_obs,
is_terminated,
is_truncated,
weights,
n_steps,
) = (
sample["obs"],
sample["actions"],
sample["rewards"],
sample["new_obs"],
sample["terminateds"],
sample["truncateds"],
sample["weights"],
sample["n_steps"],
)

# Make sure terminated and truncated are never both True.
@@ -124,21 +134,34 @@ def test_prioritized_buffer_sample_logic(self):
# Assert that the reward comes from the next observation.
self.assertTrue(np.all(rewards * 10 - next_obs < tolerance))

# Furthermore, assert that the improtance sampling weights are
# Furthermore, assert that the importance sampling weights are
# one for `beta=0.0`.
self.assertTrue(np.all(weights - 1.0 < tolerance))

# Assert that all n-steps are 1.0 as passed into `sample`.
self.assertTrue(np.all(n_steps - 1.0 < tolerance))

# Now test a 3-step sampling.
for _ in range(1000):
sample = buffer.sample(batch_size_B=16, n_step=3, beta=1.0)
obs, actions, rewards, next_obs, is_terminated, is_truncated, weights = (
(
obs,
actions,
rewards,
next_obs,
is_terminated,
is_truncated,
weights,
n_steps,
) = (
sample["obs"],
sample["actions"],
sample["rewards"],
sample["new_obs"],
sample["terminateds"],
sample["truncateds"],
sample["weights"],
sample["n_steps"],
)

# Make sure terminated and truncated are never both True.
@@ -168,6 +191,55 @@ def test_prioritized_buffer_sample_logic(self):
) * 0.1
self.assertTrue(np.all(rewards - reward_sum < tolerance))

# Furthermore, ensure that all n-steps are 3 as passed into `sample`.
self.assertTrue(np.all(n_steps - 3.0 < tolerance))

# Now test a random n-step sampling.
for _ in range(1000):
sample = buffer.sample(batch_size_B=16, n_step=None, beta=1.0)
(
obs,
actions,
rewards,
next_obs,
is_terminated,
is_truncated,
weights,
n_steps,
) = (
sample["obs"],
sample["actions"],
sample["rewards"],
sample["new_obs"],
sample["terminateds"],
sample["truncateds"],
sample["weights"],
sample["n_steps"],
)

# Make sure terminated and truncated are never both True.
assert not np.any(np.logical_and(is_truncated, is_terminated))

# All fields have same shape.
assert (
obs.shape[:2]
== rewards.shape
== actions.shape
== next_obs.shape
== is_truncated.shape
== is_terminated.shape
)

# Note, floating point numbers cannot be compared directly.
tolerance = 1e-8

# Furthermore, ensure that n-steps are between 1 and 5.
self.assertTrue(np.all(n_steps - 5.0 < tolerance))
self.assertTrue(np.all(n_steps - 1.0 > -tolerance))

# Ensure that there is variation in the n-steps.
self.assertTrue(np.var(n_steps) > 0.0)

def test_update_priorities(self):
# Define replay buffer (alpha=1.0).
buffer = PrioritizedEpisodeReplayBuffer(capacity=100)
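For reference, the "discounted sum of R" that each sampled transition carries in its reward field (see the `n_step` docstring in sac.py above) is the plain n-step return; a standalone sketch with hypothetical reward and gamma values:

import numpy as np

gamma = 0.99  # hypothetical discount factor
rewards = np.array([0.1, 0.2, 0.3])  # r_t, r_(t+1), r_(t+2) of one 3-step sample
# n-step return: sum over k of gamma^k * r_(t+k), for k = 0..n-1
n_step_return = float(np.sum(gamma ** np.arange(len(rewards)) * rewards))
print(n_step_return)  # 0.1 + 0.99 * 0.2 + 0.99**2 * 0.3

The tests above compare the sampled rewards against this kind of accumulated quantity, scaled by the test environment's reward scheme.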
