diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py
index eb0c415094e7..256168e54762 100644
--- a/python/ray/rllib/agents/agent.py
+++ b/python/ray/rllib/agents/agent.py
@@ -70,8 +70,13 @@
     # === Environment ===
     # Discount factor of the MDP
     "gamma": 0.99,
-    # Number of steps after which the episode is forced to terminate
+    # Number of steps after which the episode is forced to terminate. Defaults
+    # to `env.spec.max_episode_steps` (if present) for Gym envs.
     "horizon": None,
+    # Calculate rewards but don't reset the environment when the horizon is
+    # hit. This allows value estimation and RNN state to span across logical
+    # episodes denoted by horizon. This only has an effect if horizon != inf.
+    "soft_horizon": False,
     # Arguments to pass to the env creator
     "env_config": {},
     # Environment name can also be passed via config
@@ -746,6 +751,7 @@ def session_creator():
             output_creator=output_creator,
             remote_worker_envs=config["remote_worker_envs"],
             remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
+            soft_horizon=config["soft_horizon"],
             _fake_sampler=config.get("_fake_sampler", False))
 
     @override(Trainable)
diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py
index acf7d85cb9cc..b7afa222b149 100644
--- a/python/ray/rllib/evaluation/episode.py
+++ b/python/ray/rllib/evaluation/episode.py
@@ -65,6 +65,19 @@ def __init__(self, policies, policy_mapping_fn, batch_builder_factory,
         self._agent_to_prev_action = {}
         self._agent_reward_history = defaultdict(list)
 
+    @DeveloperAPI
+    def soft_reset(self):
+        """Clears rewards and metrics, but retains RNN and other state.
+
+        This is used to carry state across multiple logical episodes in the
+        same env (i.e., if `soft_horizon` is set).
+        """
+        self.length = 0
+        self.episode_id = random.randrange(2e9)
+        self.total_reward = 0.0
+        self.agent_rewards = defaultdict(float)
+        self._agent_reward_history = defaultdict(list)
+
     @DeveloperAPI
     def policy_for(self, agent_id=_DUMMY_AGENT_ID):
         """Returns the policy graph for the specified agent.
diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py
index c181df0846da..f6d9f70cdc87 100644
--- a/python/ray/rllib/evaluation/policy_evaluator.py
+++ b/python/ray/rllib/evaluation/policy_evaluator.py
@@ -125,6 +125,7 @@ def __init__(self,
                  output_creator=lambda ioctx: NoopOutput(),
                  remote_worker_envs=False,
                  remote_env_batch_wait_ms=0,
+                 soft_horizon=False,
                  _fake_sampler=False):
         """Initialize a policy evaluator.
 
@@ -208,6 +209,8 @@ def __init__(self,
                 least one env is ready) is a reasonable default, but optimal
                 value could be obtained by measuring your environment step /
                 reset and model inference perf.
+            soft_horizon (bool): Calculate rewards but don't reset the
+                environment when the horizon is hit.
             _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
         """
 
@@ -372,7 +375,8 @@ def make_env(vector_index):
                 pack=pack_episodes,
                 tf_sess=self.tf_sess,
                 clip_actions=clip_actions,
-                blackhole_outputs="simulation" in input_evaluation)
+                blackhole_outputs="simulation" in input_evaluation,
+                soft_horizon=soft_horizon)
             self.sampler.start()
         else:
             self.sampler = SyncSampler(
@@ -387,7 +391,8 @@ def make_env(vector_index):
                 horizon=episode_horizon,
                 pack=pack_episodes,
                 tf_sess=self.tf_sess,
-                clip_actions=clip_actions)
+                clip_actions=clip_actions,
+                soft_horizon=soft_horizon)
 
         self.input_reader = input_creator(self.io_context)
         assert isinstance(self.input_reader, InputReader), self.input_reader
diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py
index 8ca23d9f1459..08f6f2003c0b 100644
--- a/python/ray/rllib/evaluation/sampler.py
+++ b/python/ray/rllib/evaluation/sampler.py
@@ -78,7 +78,8 @@ def __init__(self,
                  horizon=None,
                  pack=False,
                  tf_sess=None,
-                 clip_actions=True):
+                 clip_actions=True,
+                 soft_horizon=False):
         self.base_env = BaseEnv.to_base_env(env)
         self.unroll_length = unroll_length
         self.horizon = horizon
@@ -92,7 +93,7 @@ def __init__(self,
             self.base_env, self.extra_batches.put, self.policies,
             self.policy_mapping_fn, self.unroll_length, self.horizon,
             self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
-            pack, callbacks, tf_sess, self.perf_stats)
+            pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
         self.metrics_queue = queue.Queue()
 
     def get_data(self):
@@ -137,7 +138,8 @@ def __init__(self,
                  pack=False,
                  tf_sess=None,
                  clip_actions=True,
-                 blackhole_outputs=False):
+                 blackhole_outputs=False,
+                 soft_horizon=False):
         for _, f in obs_filters.items():
             assert getattr(f, "is_concurrent", False), \
                 "Observation Filter must support concurrent updates."
@@ -159,6 +161,7 @@ def __init__(self,
         self.callbacks = callbacks
         self.clip_actions = clip_actions
         self.blackhole_outputs = blackhole_outputs
+        self.soft_horizon = soft_horizon
         self.perf_stats = PerfStats()
         self.shutdown = False
 
@@ -182,7 +185,7 @@ def _run(self):
             self.policy_mapping_fn, self.unroll_length, self.horizon,
             self.preprocessors, self.obs_filters, self.clip_rewards,
             self.clip_actions, self.pack, self.callbacks, self.tf_sess,
-            self.perf_stats)
+            self.perf_stats, self.soft_horizon)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -227,7 +230,7 @@ def get_extra_batches(self):
 def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 unroll_length, horizon, preprocessors, obs_filters,
                 clip_rewards, clip_actions, pack, callbacks, tf_sess,
-                perf_stats):
+                perf_stats, soft_horizon):
     """This implements the common experience collection logic.
 
     Args:
@@ -252,6 +255,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         tf_sess (Session|None): Optional tensorflow session to use for batching
             TF policy evaluations.
         perf_stats (PerfStats): Record perf stats into this object.
+        soft_horizon (bool): Calculate rewards but don't reset the
+            environment when the horizon is hit.
 
     Yields:
         rollout (SampleBatch): Object containing state, action, reward,
@@ -307,7 +312,8 @@ def new_episode():
         active_envs, to_eval, outputs = _process_observations(
             base_env, policies, batch_builder_pool, active_episodes,
             unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
-            preprocessors, obs_filters, unroll_length, pack, callbacks)
+            preprocessors, obs_filters, unroll_length, pack, callbacks,
+            soft_horizon)
         perf_stats.processing_time += time.time() - t1
         for o in outputs:
             yield o
@@ -335,7 +341,8 @@ def new_episode():
 def _process_observations(base_env, policies, batch_builder_pool,
                           active_episodes, unfiltered_obs, rewards, dones,
                           infos, off_policy_actions, horizon, preprocessors,
-                          obs_filters, unroll_length, pack, callbacks):
+                          obs_filters, unroll_length, pack, callbacks,
+                          soft_horizon):
     """Record new data from the environment and prepare for policy evaluation.
 
     Returns:
@@ -372,6 +379,8 @@
 
         # Check episode termination conditions
         if dones[env_id]["__all__"] or episode.length >= horizon:
+            hit_horizon = (episode.length >= horizon
+                           and not dones[env_id]["__all__"])
             all_done = True
             atari_metrics = _fetch_atari_metrics(base_env)
             if atari_metrics is not None:
@@ -384,6 +393,7 @@
                                    dict(episode.agent_rewards),
                                    episode.custom_metrics, {}))
         else:
+            hit_horizon = False
             all_done = False
             active_envs.add(env_id)
 
@@ -427,7 +437,8 @@
                     rewards=rewards[env_id][agent_id],
                     prev_actions=episode.prev_action_for(agent_id),
                     prev_rewards=episode.prev_reward_for(agent_id),
-                    dones=agent_done,
+                    dones=(False
+                           if (hit_horizon and soft_horizon) else agent_done),
                     infos=infos[env_id].get(agent_id, {}),
                     new_obs=filtered_obs,
                     **episode.last_pi_info_for(agent_id))
@@ -457,8 +468,12 @@
                     "policy": policies,
                     "episode": episode
                 })
-            del active_episodes[env_id]
-            resetted_obs = base_env.try_reset(env_id)
+            if hit_horizon and soft_horizon:
+                episode.soft_reset()
+                resetted_obs = agent_obs
+            else:
+                del active_episodes[env_id]
+                resetted_obs = base_env.try_reset(env_id)
             if resetted_obs is None:
                 # Reset not supported, drop this env from the ready list
                 if horizon != float("inf"):
diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py
index 827b737ef6e1..093a2b00d226 100644
--- a/python/ray/rllib/tests/test_policy_evaluator.py
+++ b/python/ray/rllib/tests/test_policy_evaluator.py
@@ -249,6 +249,34 @@ def testRewardClipping(self):
         result2 = collect_metrics(ev2, [])
         self.assertEqual(result2["episode_reward_mean"], 1000)
 
+    def testHardHorizon(self):
+        ev = PolicyEvaluator(
+            env_creator=lambda _: MockEnv(episode_length=10),
+            policy_graph=MockPolicyGraph,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=4,
+            soft_horizon=False)
+        samples = ev.sample()
+        # three logical episodes
+        self.assertEqual(len(set(samples["eps_id"])), 3)
+        # 3 done values
+        self.assertEqual(sum(samples["dones"]), 3)
+
+    def testSoftHorizon(self):
+        ev = PolicyEvaluator(
+            env_creator=lambda _: MockEnv(episode_length=10),
+            policy_graph=MockPolicyGraph,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=4,
+            soft_horizon=True)
+        samples = ev.sample()
+        # three logical episodes
+        self.assertEqual(len(set(samples["eps_id"])), 3)
+        # only 1 hard done value
+        self.assertEqual(sum(samples["dones"]), 1)
+
     def testMetrics(self):
         ev = PolicyEvaluator(
             env_creator=lambda _: MockEnv(episode_length=10),
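
Usage note (not part of the patch): the sketch below shows how the new "soft_horizon" key added to COMMON_CONFIG above could be enabled at the agent level. The PPOAgent class, the "CartPole-v0" Gym env, and the horizon value of 200 are illustrative assumptions only; any agent, env, and horizon should behave the same way.

    # Minimal sketch, assuming PPOAgent and CartPole-v0 (not part of this patch).
    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()

    agent = PPOAgent(
        env="CartPole-v0",
        config={
            # Force a logical episode boundary every 200 env steps ...
            "horizon": 200,
            # ... but keep stepping the same env across that boundary, so
            # value estimation and RNN state can span the logical episodes
            # instead of being cut off by a hard reset.
            "soft_horizon": True,
        })

    for _ in range(3):
        print(agent.train()["episode_reward_mean"])

With "soft_horizon": False (the default), each horizon boundary emits a hard done and triggers try_reset(), as exercised by testHardHorizon above; with True, only the env's own terminal steps produce hard dones, as exercised by testSoftHorizon.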