Add FragmentWorker for Off-Policy RL (#1626)
* Add FragmentWorker for Off-Policy RL

This is a sampler.Worker which collects fragments of trajectories,
allowing off-policy RL algorithms to update with partial trajectories.

* Fix bug in AddGaussianNoise

* Fixes frag_worker

The frag_worker now converts env_infos and agent_infos into numpy arrays
before constructing trajectories from fragments.

* Fix precommit

Co-authored-by: Avnish Narayan <[email protected]>
krzentner and Avnish Narayan authored Jul 30, 2020
1 parent 42d98f5 commit e8eb175
Showing 15 changed files with 488 additions and 81 deletions.
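For orientation, the sketch below shows how the new worker is wired into a runner. It is a minimal sketch, not code from this commit: the construction of runner, algo (an off-policy algorithm such as DDPG), and env is assumed to happen elsewhere, and setup_off_policy_sampling is a hypothetical helper. The relevant detail is passing worker_class=FragmentWorker to runner.setup (or relying on the algorithm's worker_cls attribute, per the local_runner.py change below).

from garage.sampler import FragmentWorker, LocalSampler


def setup_off_policy_sampling(runner, algo, env):
    """Wire an off-policy algorithm up to fragment-based sampling.

    `runner` is assumed to be a LocalRunner (or LocalTFRunner), `algo` an
    off-policy algorithm, and `env` a garage-compatible environment.
    """
    # FragmentWorker hands back partial trajectories, so the algorithm can
    # update before an episode finishes.
    runner.setup(algo=algo,
                 env=env,
                 sampler_cls=LocalSampler,
                 worker_class=FragmentWorker)
    runner.train(n_epochs=100, batch_size=100)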
4 changes: 3 additions & 1 deletion src/garage/experiment/local_runner.py
@@ -151,7 +151,7 @@ def make_sampler(self,
seed=None,
n_workers=psutil.cpu_count(logical=False),
max_path_length=None,
worker_class=DefaultWorker,
worker_class=None,
sampler_args=None,
worker_args=None):
"""Construct a Sampler from a Sampler class.
@@ -190,6 +190,8 @@ def make_sampler(self,
if max_path_length is None:
raise ValueError('If `sampler_cls` is specified in runner.setup, '
'the algorithm must specify `max_path_length`')
if worker_class is None:
worker_class = getattr(self._algo, 'worker_cls', DefaultWorker)
if seed is None:
seed = get_seed()
if sampler_args is None:
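The fallback above lets an algorithm advertise its preferred worker through a worker_cls class attribute, so callers of runner.setup no longer need to pass worker_class explicitly. A minimal sketch of that pattern; MyOffPolicyAlgo is hypothetical and not part of this commit:

from garage.sampler import FragmentWorker


class MyOffPolicyAlgo:
    """Hypothetical off-policy algorithm, shown only to illustrate the hook."""

    # With worker_class=None (the new default), make_sampler() falls back to
    # this attribute; algorithms that do not define it still get DefaultWorker.
    worker_cls = FragmentWorker

    def train(self, runner):
        """Training loop omitted."""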
6 changes: 4 additions & 2 deletions src/garage/experiment/local_tf_runner.py
@@ -110,7 +110,7 @@ def make_sampler(self,
seed=None,
n_workers=psutil.cpu_count(logical=False),
max_path_length=None,
worker_class=DefaultWorker,
worker_class=None,
sampler_args=None,
worker_args=None):
"""Construct a Sampler from a Sampler class.
@@ -131,6 +131,8 @@
sampler_cls: An instance of the sampler class.
"""
if worker_class is None:
worker_class = getattr(self._algo, 'worker_cls', DefaultWorker)
# pylint: disable=useless-super-delegation
return super().make_sampler(
sampler_cls,
@@ -147,7 +149,7 @@ def setup(self,
sampler_cls=None,
sampler_args=None,
n_workers=psutil.cpu_count(logical=False),
worker_class=DefaultWorker,
worker_class=None,
worker_args=None):
"""Set up runner and sessions for algorithm and environment.
4 changes: 2 additions & 2 deletions src/garage/np/exploration_policies/add_gaussian_noise.py
@@ -60,7 +60,7 @@ def get_action(self, observation):
action, agent_info = self.policy.get_action(observation)
sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(
1.0, self._iteration * 1.0 / self._decay_period)
return np.clip(action + np.random.normal(size=len(action)) * sigma,
return np.clip(action + np.random.normal(size=action.shape) * sigma,
self._action_space.low,
self._action_space.high), agent_info

@@ -78,6 +78,6 @@ def get_actions(self, observations):
actions, agent_infos = self.policy.get_actions(observations)
sigma = self._max_sigma - (self._max_sigma - self._min_sigma) * min(
1.0, self._iteration * 1.0 / self._decay_period)
return np.clip(actions + np.random.normal(size=len(actions)) * sigma,
return np.clip(actions + np.random.normal(size=actions.shape) * sigma,
self._action_space.low,
self._action_space.high), agent_infos
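The exploration-noise fix replaces len(action) with action.shape when sampling Gaussian noise. For a flat 1-D action the two are equivalent, but len() only reflects the first dimension, so for batched or multi-dimensional actions the sampled noise had the wrong shape. A small standalone illustration with plain numpy (not garage code):

import numpy as np

actions = np.zeros((4, 2))  # a batch of four 2-dimensional actions

noise_by_len = np.random.normal(size=len(actions))     # shape (4,), too small
noise_by_shape = np.random.normal(size=actions.shape)  # shape (4, 2), matches

# actions + noise_by_len raises a broadcasting error, while
# actions + noise_by_shape perturbs every action dimension as intended.
print(noise_by_len.shape, noise_by_shape.shape)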
6 changes: 6 additions & 0 deletions src/garage/sampler/__init__.py
@@ -1,10 +1,13 @@
"""Samplers which run agents in environments."""
# yapf: disable
from garage.sampler._dtypes import InProgressTrajectory
from garage.sampler._functions import _apply_env_update
from garage.sampler.default_worker import DefaultWorker
from garage.sampler.env_update import (EnvUpdate,
ExistingEnvUpdate,
NewEnvUpdate,
SetTaskUpdate)
from garage.sampler.fragment_worker import FragmentWorker
from garage.sampler.local_sampler import LocalSampler
from garage.sampler.multiprocessing_sampler import MultiprocessingSampler
from garage.sampler.ray_sampler import RaySampler
@@ -16,6 +19,9 @@
# yapf: enable

__all__ = [
'_apply_env_update',
'InProgressTrajectory',
'FragmentWorker',
'Sampler',
'LocalSampler',
'RaySampler',
87 changes: 87 additions & 0 deletions src/garage/sampler/_dtypes.py
@@ -0,0 +1,87 @@
"""Datatypes used by multiple Samplers or Workers."""
import collections

import numpy as np

from garage import TrajectoryBatch


class InProgressTrajectory:
"""An in-progress trajectory.
Compared to TrajectoryBatch, this datatype does less checking, only
contains one trajectory, and uses lists instead of numpy arrays to make
stepping faster.
Args:
env (gym.Env): The environment the trajectory is being collected in.
initial_observation (np.ndarray): The first observation. If None, the
environment will be reset to generate this observation.
"""

def __init__(self, env, initial_observation=None):
self.env = env
if initial_observation is None:
initial_observation = env.reset()
self.observations = [initial_observation]
self.actions = []
self.rewards = []
self.terminals = []
self.agent_infos = collections.defaultdict(list)
self.env_infos = collections.defaultdict(list)

def step(self, action, agent_info):
"""Step the trajectory using an action from an agent.
Args:
action (np.ndarray): The action taken by the agent.
agent_info (dict[str, np.ndarray]): Extra agent information.
Returns:
np.ndarray: The new observation from the environment.
"""
next_o, r, d, env_info = self.env.step(action)
self.observations.append(next_o)
self.rewards.append(r)
self.actions.append(action)
for k, v in agent_info.items():
self.agent_infos[k].append(v)
for k, v in env_info.items():
self.env_infos[k].append(v)
self.terminals.append(d)
return next_o

def to_batch(self):
"""Convert this in-progress trajectory into a TrajectoryBatch.
Returns:
TrajectoryBatch: This trajectory as a batch.
Raises:
AssertionError: If this trajectory contains no time steps.
"""
assert len(self.rewards) > 0
env_infos = dict(self.env_infos)
agent_infos = dict(self.agent_infos)
for k, v in env_infos.items():
env_infos[k] = np.asarray(v)
for k, v in agent_infos.items():
agent_infos[k] = np.asarray(v)
return TrajectoryBatch(env_spec=self.env.spec,
observations=np.asarray(self.observations[:-1]),
last_observations=np.asarray([self.last_obs]),
actions=np.asarray(self.actions),
rewards=np.asarray(self.rewards),
terminals=np.asarray(self.terminals),
env_infos=env_infos,
agent_infos=agent_infos,
lengths=np.asarray([len(self.rewards)],
dtype='l'))

@property
def last_obs(self):
"""np.ndarray: The last observation in the trajectory."""
return self.observations[-1]
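To make the intended lifecycle concrete, here is a sketch of how a worker might drive InProgressTrajectory: construct it (which resets the environment), step it with the policy's actions, and convert the result with to_batch(). The env and policy arguments are assumed to follow the usual gym and garage policy interfaces (policy.get_action(obs) returning an action and an agent-info dict); this is illustrative, not the actual FragmentWorker code.

from garage.sampler import InProgressTrajectory


def collect_one_trajectory(env, policy, max_path_length):
    """Roll out a single trajectory and return it as a TrajectoryBatch."""
    traj = InProgressTrajectory(env)  # resets env for the first observation
    obs = traj.last_obs
    for _ in range(max_path_length):
        action, agent_info = policy.get_action(obs)
        obs = traj.step(action, agent_info)
        if traj.terminals[-1]:
            break
    return traj.to_batch()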
41 changes: 41 additions & 0 deletions src/garage/sampler/_functions.py
@@ -0,0 +1,41 @@
"""Functions used by multiple Samplers or Workers."""
import gym

from garage.sampler.env_update import EnvUpdate


def _apply_env_update(old_env, env_update):
"""Use any non-None env_update as a new environment.
A simple env update function. If env_update is not None, it should be
the complete new environment.
This allows changing environments by passing the new environment as
`env_update` into `obtain_samples`.
Args:
old_env (gym.Env): Environment to updated.
env_update (gym.Env or EnvUpdate or None): The environment to
replace the existing env with. Note that other implementations
of `Worker` may take different types for this parameter.
Returns:
gym.Env: The updated environment (may be a different object from
`old_env`).
bool: True if an update happened.
Raises:
TypeError: If env_update is not one of the documented types.
"""
if env_update is not None:
if isinstance(env_update, EnvUpdate):
return env_update(old_env), True
elif isinstance(env_update, gym.Env):
if old_env is not None:
old_env.close()
return env_update, True
else:
raise TypeError('Unknown environment update type.')
else:
return old_env, False
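This helper centralizes logic that DefaultWorker.update_env previously implemented inline (see the diff below). A short sketch of the cases it handles; the CartPole environment is only an example:

import gym

from garage.sampler import _apply_env_update

env = gym.make('CartPole-v1')

# No update requested: the old environment comes back unchanged.
env, updated = _apply_env_update(env, None)
assert not updated

# A plain gym.Env as the update: the old env is closed and replaced.
env, updated = _apply_env_update(env, gym.make('CartPole-v1'))
assert updated

# An EnvUpdate instance (e.g. NewEnvUpdate or SetTaskUpdate from
# garage.sampler.env_update) is instead called with the old env to
# produce the replacement.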
13 changes: 2 additions & 11 deletions src/garage/sampler/default_worker.py
@@ -1,12 +1,11 @@
"""Default Worker class."""
from collections import defaultdict

import gym
import numpy as np

from garage import TrajectoryBatch
from garage.experiment import deterministic
from garage.sampler.env_update import EnvUpdate
from garage.sampler import _apply_env_update
from garage.sampler.worker import Worker


@@ -90,15 +89,7 @@ def update_env(self, env_update):
TypeError: If env_update is not one of the documented types.
"""
if env_update is not None:
if isinstance(env_update, EnvUpdate):
self.env = env_update(self.env)
elif isinstance(env_update, gym.Env):
if self.env is not None:
self.env.close()
self.env = env_update
else:
raise TypeError('Uknown environment update type.')
self.env, _ = _apply_env_update(self.env, env_update)

def start_rollout(self):
"""Begin a new rollout."""