Add FragmentWorker for Off-Policy RL (#1626)
* Add FragmentWorker for Off-Policy RL

  This is a sampler.Worker which collects fragments of trajectories, allowing off-policy RL algorithms to update with partial trajectories.

* Fix bug in AddGaussianNoise

* Fix frag_worker

  The frag_worker now converts env_infos and agent_infos into numpy arrays before constructing trajectories from fragments.

* Fix pre-commit

Co-authored-by: Avnish Narayan <[email protected]>
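For context, here is a minimal sketch of how a FragmentWorker would be wired into sampling. The LocalRunner setup call, the import locations, and the hyperparameters are assumptions based on garage's usual worker_class hook, not taken from this diff.

# Hypothetical wiring sketch (not part of this diff): pass FragmentWorker as
# the worker_class so an off-policy algorithm samples partial trajectories.
from garage.experiment import LocalRunner
from garage.sampler import FragmentWorker, LocalSampler


def setup_off_policy_sampling(ctxt, algo, env):
    """Assumed helper: hook FragmentWorker into a runner's sampler."""
    runner = LocalRunner(ctxt)
    runner.setup(algo=algo,
                 env=env,
                 sampler_cls=LocalSampler,
                 worker_class=FragmentWorker)  # the class this commit adds
    runner.train(n_epochs=100, batch_size=1000)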
Showing 15 changed files with 488 additions and 81 deletions.
@@ -0,0 +1,87 @@
"""Datatypes used by multiple Samplers or Workers.""" | ||
import collections | ||
|
||
import numpy as np | ||
|
||
from garage import TrajectoryBatch | ||
|
||
|
||
class InProgressTrajectory: | ||
"""An in-progress trajectory. | ||
Compared to TrajectoryBatch, this datatype does less checking, only | ||
contains one trajectory, and uses lists instead of numpy arrays to make | ||
stepping faster. | ||
Args: | ||
env (gym.Env): The environment the trajectory is being collected in. | ||
initial_observation (np.ndarray): The first observation. If None, the | ||
environment will be reset to generate this observation. | ||
""" | ||
|
||
def __init__(self, env, initial_observation=None): | ||
self.env = env | ||
if initial_observation is None: | ||
initial_observation = env.reset() | ||
self.observations = [initial_observation] | ||
self.actions = [] | ||
self.rewards = [] | ||
self.terminals = [] | ||
self.agent_infos = collections.defaultdict(list) | ||
self.env_infos = collections.defaultdict(list) | ||
|
||
def step(self, action, agent_info): | ||
"""Step the trajectory using an action from an agent. | ||
Args: | ||
action (np.ndarray): The action taken by the agent. | ||
agent_info (dict[str, np.ndarray]): Extra agent information. | ||
Returns: | ||
np.ndarray: The new observation from the environment. | ||
""" | ||
next_o, r, d, env_info = self.env.step(action) | ||
self.observations.append(next_o) | ||
self.rewards.append(r) | ||
self.actions.append(action) | ||
for k, v in agent_info.items(): | ||
self.agent_infos[k].append(v) | ||
for k, v in env_info.items(): | ||
self.env_infos[k].append(v) | ||
self.terminals.append(d) | ||
return next_o | ||
|
||
def to_batch(self): | ||
"""Convert this in-progress trajectory into a TrajectoryBatch. | ||
Returns: | ||
TrajectoryBatch: This trajectory as a batch. | ||
Raises: | ||
AssertionError: If this trajectory contains no time steps. | ||
""" | ||
assert len(self.rewards) > 0 | ||
env_infos = dict(self.env_infos) | ||
agent_infos = dict(self.agent_infos) | ||
for k, v in env_infos.items(): | ||
env_infos[k] = np.asarray(v) | ||
for k, v in agent_infos.items(): | ||
agent_infos[k] = np.asarray(v) | ||
return TrajectoryBatch(env_spec=self.env.spec, | ||
observations=np.asarray(self.observations[:-1]), | ||
last_observations=np.asarray([self.last_obs]), | ||
actions=np.asarray(self.actions), | ||
rewards=np.asarray(self.rewards), | ||
terminals=np.asarray(self.terminals), | ||
env_infos=env_infos, | ||
agent_infos=agent_infos, | ||
lengths=np.asarray([len(self.rewards)], | ||
dtype='l')) | ||
|
||
@property | ||
def last_obs(self): | ||
"""np.ndarray: The last observation in the trajectory.""" | ||
return self.observations[-1] |
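To make the intended call pattern concrete, here is a minimal usage sketch. The GarageEnv wrapper and the CartPole environment are assumptions, since the diff shows only the datatype itself; GarageEnv is used so that env.spec is a garage EnvSpec, which TrajectoryBatch validates against.

# A minimal usage sketch, not from this diff: step a trajectory with random
# actions, then convert it to a TrajectoryBatch.
import gym

from garage.envs import GarageEnv

env = GarageEnv(gym.make('CartPole-v1'))
traj = InProgressTrajectory(env)  # resets the env for the first observation
for _ in range(10):
    action = env.action_space.sample()  # stand-in for a policy's action
    traj.step(action, agent_info={})
    if traj.terminals[-1]:
        break
batch = traj.to_batch()  # a single-trajectory TrajectoryBatch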
@@ -0,0 +1,41 @@
"""Functions used by multiple Samplers or Workers.""" | ||
import gym | ||
|
||
from garage.sampler.env_update import EnvUpdate | ||
|
||
|
||
def _apply_env_update(old_env, env_update): | ||
"""Use any non-None env_update as a new environment. | ||
A simple env update function. If env_update is not None, it should be | ||
the complete new environment. | ||
This allows changing environments by passing the new environment as | ||
`env_update` into `obtain_samples`. | ||
Args: | ||
old_env (gym.Env): Environment to updated. | ||
env_update (gym.Env or EnvUpdate or None): The environment to | ||
replace the existing env with. Note that other implementations | ||
of `Worker` may take different types for this parameter. | ||
Returns: | ||
gym.Env: The updated environment (may be a different object from | ||
`old_env`). | ||
bool: True if an update happened. | ||
Raises: | ||
TypeError: If env_update is not one of the documented types. | ||
""" | ||
if env_update is not None: | ||
if isinstance(env_update, EnvUpdate): | ||
return env_update(old_env), True | ||
elif isinstance(env_update, gym.Env): | ||
if old_env is not None: | ||
old_env.close() | ||
return env_update, True | ||
else: | ||
raise TypeError('Uknown environment update type.') | ||
else: | ||
return old_env, False |
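As a brief illustration, this sketch exercises the three cases the function handles. NoopEnvUpdate is a hypothetical subclass written only for this example; it is not part of garage or of this diff.

# A brief sketch, not from this diff, exercising the three cases above.
import gym

from garage.sampler.env_update import EnvUpdate


class NoopEnvUpdate(EnvUpdate):
    """Hypothetical example update: hand the old environment back unchanged."""

    def __call__(self, old_env=None):
        return old_env


env = gym.make('CartPole-v1')

# Case 1: env_update is None, so the old env is kept and no update happens.
same_env, updated = _apply_env_update(env, None)
assert same_env is env and not updated

# Case 2: env_update is a gym.Env; the old env is closed and replaced.
new_env, updated = _apply_env_update(env, gym.make('CartPole-v1'))
assert updated

# Case 3: env_update is an EnvUpdate, which is called on the old env.
final_env, updated = _apply_env_update(new_env, NoopEnvUpdate())
assert final_env is new_env and updated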