[rllib] Add option for RNN state and value estimates to span episodes (ray-project#4429)

* wip soft horizon

* tests
ericl authored Apr 2, 2019
1 parent c2c548b commit 55a2d39
Showing 5 changed files with 80 additions and 13 deletions.
8 changes: 7 additions & 1 deletion python/ray/rllib/agents/agent.py
@@ -70,8 +70,13 @@
# === Environment ===
# Discount factor of the MDP
"gamma": 0.99,
# Number of steps after which the episode is forced to terminate
# Number of steps after which the episode is forced to terminate. Defaults
# to `env.spec.max_episode_steps` (if present) for Gym envs.
"horizon": None,
# Calculate rewards but don't reset the environment when the horizon is
# hit. This allows value estimation and RNN state to span across logical
# episodes denoted by horizon. This only has an effect if horizon != inf.
"soft_horizon": False,
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
@@ -746,6 +751,7 @@ def session_creator():
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
_fake_sampler=config.get("_fake_sampler", False))

@override(Trainable)
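A usage sketch (not part of the diff): the new flag sits next to `horizon` in the top-level trainer config. The algorithm name, env, and stopping criterion below are illustrative assumptions; only `horizon` and `soft_horizon` come from this commit.

```python
import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        # Cut a logical episode every 200 steps ...
        "horizon": 200,
        # ... but keep the env running, so RNN state and value
        # bootstrapping can span the cut.
        "soft_horizon": True,
    },
)
```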
13 changes: 13 additions & 0 deletions python/ray/rllib/evaluation/episode.py
@@ -65,6 +65,19 @@ def __init__(self, policies, policy_mapping_fn, batch_builder_factory,
self._agent_to_prev_action = {}
self._agent_reward_history = defaultdict(list)

@DeveloperAPI
def soft_reset(self):
"""Clears rewards and metrics, but retains RNN and other state.
This is used to carry state across multiple logical episodes in the
same env (i.e., if `soft_horizon` is set).
"""
self.length = 0
self.episode_id = random.randrange(2e9)
self.total_reward = 0.0
self.agent_rewards = defaultdict(float)
self._agent_reward_history = defaultdict(list)

@DeveloperAPI
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the policy graph for the specified agent.
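To make the semantics concrete, here is a distilled, standalone sketch of the soft-reset pattern (assumed class and attribute names; the real method lives on the episode class above): per-episode counters are cleared, while anything that must survive the horizon cut, such as RNN state, is left untouched.

```python
import random
from collections import defaultdict


class EpisodeSketch:
    """Toy stand-in for an episode object that supports soft resets."""

    def __init__(self):
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
        self.rnn_state = {}  # deliberately NOT cleared by soft_reset()

    def soft_reset(self):
        """Start a new logical episode without touching carried state."""
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
```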
9 changes: 7 additions & 2 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -125,6 +125,7 @@ def __init__(self,
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False,
remote_env_batch_wait_ms=0,
soft_horizon=False,
_fake_sampler=False):
"""Initialize a policy evaluator.
@@ -208,6 +209,8 @@ def __init__(self,
least one env is ready) is a reasonable default, but optimal
value could be obtained by measuring your environment
step / reset and model inference perf.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
"""

@@ -372,7 +375,8 @@ def make_env(vector_index):
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation)
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon)
self.sampler.start()
else:
self.sampler = SyncSampler(
@@ -387,7 +391,8 @@ def make_env(vector_index):
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions)
clip_actions=clip_actions,
soft_horizon=soft_horizon)

self.input_reader = input_creator(self.io_context)
assert isinstance(self.input_reader, InputReader), self.input_reader
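The flag can also be passed straight to `PolicyEvaluator`, exactly as the tests added below do. A minimal sketch with a real Gym env and policy graph (the import paths and the choice of `PGPolicyGraph` are assumptions tied to this RLlib version, not part of the diff):

```python
import gym
import ray
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator

ray.init()
ev = PolicyEvaluator(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy_graph=PGPolicyGraph,
    batch_mode="complete_episodes",
    batch_steps=200,
    episode_horizon=20,     # cut logical episodes every 20 steps
    soft_horizon=True)      # ...without resetting CartPole itself
batch = ev.sample()
print(len(set(batch["eps_id"])), sum(batch["dones"]))
```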
35 changes: 25 additions & 10 deletions python/ray/rllib/evaluation/sampler.py
@@ -78,7 +78,8 @@ def __init__(self,
horizon=None,
pack=False,
tf_sess=None,
clip_actions=True):
clip_actions=True,
soft_horizon=False):
self.base_env = BaseEnv.to_base_env(env)
self.unroll_length = unroll_length
self.horizon = horizon
@@ -92,7 +93,7 @@ def __init__(self,
self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess, self.perf_stats)
pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
self.metrics_queue = queue.Queue()

def get_data(self):
@@ -137,7 +138,8 @@ def __init__(self,
pack=False,
tf_sess=None,
clip_actions=True,
blackhole_outputs=False):
blackhole_outputs=False,
soft_horizon=False):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@@ -159,6 +161,7 @@ def __init__(self,
self.callbacks = callbacks
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.soft_horizon = soft_horizon
self.perf_stats = PerfStats()
self.shutdown = False

@@ -182,7 +185,7 @@ def _run(self):
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
self.perf_stats)
self.perf_stats, self.soft_horizon)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -227,7 +230,7 @@ def get_extra_batches(self):
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
unroll_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
perf_stats):
perf_stats, soft_horizon):
"""This implements the common experience collection logic.
Args:
@@ -252,6 +255,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
perf_stats (PerfStats): Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -307,7 +312,8 @@ def new_episode():
active_envs, to_eval, outputs = _process_observations(
base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, unroll_length, pack, callbacks)
preprocessors, obs_filters, unroll_length, pack, callbacks,
soft_horizon)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
@@ -335,7 +341,8 @@ def new_episode():
def _process_observations(base_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, preprocessors,
obs_filters, unroll_length, pack, callbacks):
obs_filters, unroll_length, pack, callbacks,
soft_horizon):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
@@ -372,6 +379,8 @@ def _process_observations(base_env, policies, batch_builder_pool,

# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
hit_horizon = (episode.length >= horizon
and not dones[env_id]["__all__"])
all_done = True
atari_metrics = _fetch_atari_metrics(base_env)
if atari_metrics is not None:
@@ -384,6 +393,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
dict(episode.agent_rewards),
episode.custom_metrics, {}))
else:
hit_horizon = False
all_done = False
active_envs.add(env_id)

@@ -427,7 +437,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
rewards=rewards[env_id][agent_id],
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=agent_done,
dones=(False
if (hit_horizon and soft_horizon) else agent_done),
infos=infos[env_id].get(agent_id, {}),
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
@@ -457,8 +468,12 @@ def _process_observations(base_env, policies, batch_builder_pool,
"policy": policies,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = base_env.try_reset(env_id)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs = agent_obs
else:
del active_episodes[env_id]
resetted_obs = base_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
if horizon != float("inf"):
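Pulling the sampler changes together, the boundary handling reduces to the following control flow (a distilled sketch with assumed helper names, not code from the diff): the recorded `done` is masked when the cut is only a soft horizon cut, and the env is then left alone while the episode object is soft-reset.

```python
def handle_boundary_sketch(episode, env_done, hit_horizon, soft_horizon,
                           last_obs, try_reset_env):
    """Return (done_to_record, next_obs) for one env at an episode boundary."""
    # Mask the done flag so value targets can bootstrap across the cut.
    done_to_record = False if (hit_horizon and soft_horizon) else env_done

    if hit_horizon and soft_horizon:
        # Keep the env and any RNN state; only start a new logical episode.
        episode.soft_reset()
        next_obs = last_obs
    else:
        # Real termination (or hard horizon): reset the underlying env.
        next_obs = try_reset_env()
    return done_to_record, next_obs
```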
28 changes: 28 additions & 0 deletions python/ray/rllib/tests/test_policy_evaluator.py
@@ -249,6 +249,34 @@ def testRewardClipping(self):
result2 = collect_metrics(ev2, [])
self.assertEqual(result2["episode_reward_mean"], 1000)

def testHardHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
soft_horizon=False)
samples = ev.sample()
# three logical episodes
self.assertEqual(len(set(samples["eps_id"])), 3)
# 3 done values
self.assertEqual(sum(samples["dones"]), 3)

def testSoftHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
soft_horizon=True)
samples = ev.sample()
# three logical episodes
self.assertEqual(len(set(samples["eps_id"])), 3)
# only 1 hard done value
self.assertEqual(sum(samples["dones"]), 1)

def testMetrics(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
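The asserted counts follow from MockEnv terminating after `episode_length` steps: with a hard horizon the env is reset at every 4-step cut, so each of the three logical episodes records `done=True`, while with a soft horizon the env keeps running through the cuts at steps 4 and 8 and the only recorded done is its own termination at step 10. For context, a minimal environment with the behavior these tests rely on looks roughly like the sketch below (MockEnv itself is defined elsewhere in this test file; this is an assumed approximation, not its actual code):

```python
import gym


class MockEnvSketch(gym.Env):
    """Terminates after a fixed number of steps, like the MockEnv used above."""

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return 0

    def step(self, action):
        self.i += 1
        done = self.i >= self.episode_length
        return 0, 1.0, done, {}
```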
