Commit ac0c987

still debugging lstm; and many additions and cleaning

astooke committed May 21, 2019
1 parent c05ba8d commit ac0c987
Showing 36 changed files with 885 additions and 295 deletions.
23 changes: 18 additions & 5 deletions README.md
@@ -3,20 +3,20 @@

Runs reinforcement learning algorithms with parallel sampling and GPU training, if available. Highly modular (modifiable) and optimized codebase with functionality for launching large sets of parallel experiments locally on multi-GPU or many-core machines.

Based on [accel_rl](https://github.com/astooke/accel_rl), which in turn was based on [rllab](https://github.com/rll/rllab).

Follows the rllab interfaces: agents output `action, agent_info`, environments output `observation, reward, done, env_info`, but introduces a new object class, `namedarraytuple`, for easier organization. This permits each output to be either an individual numpy array [torch tensor] or an arbitrary collection of numpy arrays [torch tensors], without changing interfaces. In general, agent inputs/outputs are torch tensors, and environment inputs/outputs are numpy arrays, with conversions handled automatically.
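
A minimal sketch of how a `namedarraytuple` behaves (the field names here are made up for illustration): it is constructed like a `namedtuple`, but indexing applies to every leaf array, so a structured observation can pass through the same interfaces as a single array would.

```python
import numpy as np
from rlpyt.utils.collections import namedarraytuple

# Hypothetical composite observation; "image" and "state" are illustrative fields.
Observation = namedarraytuple("Observation", ["image", "state"])

obs = Observation(
    image=np.zeros((5, 84, 84), dtype=np.uint8),   # leading (batch) dim shared
    state=np.zeros((5, 6), dtype=np.float32),
)

# Indexing the namedarraytuple indexes every field the same way, so code
# downstream does not care whether it holds one array or several.
first = obs[0]          # first.image -> (84, 84), first.state -> (6,)
```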

Recurrent agents are supported, as training batches are organized with leading indexes as `[Time, Batch]`, and agents receive previous action and previous reward as input, in addition to the observation.
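
A rough sketch of the resulting data layout (shapes are illustrative, loosely following the Atari defaults used elsewhere in this commit: 4 stacked 104x80 frames, `batch_T=5`, `batch_B=64`):

```python
import numpy as np

# Illustrative shapes only (T = sampler batch_T, B = number of parallel envs).
T, B = 5, 64
observation = np.zeros((T, B, 4, 104, 80), dtype=np.uint8)  # leading [T, B] dims
prev_action = np.zeros((T, B), dtype=np.int64)              # action chosen at t - 1
prev_reward = np.zeros((T, B), dtype=np.float32)            # reward received at t - 1

# A recurrent agent's training forward pass consumes all three; slicing one
# time step keeps the batch dimension intact:
obs_t, act_tm1, rew_tm1 = observation[0], prev_action[0], prev_reward[0]
```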

Start from `rlpyt/experiments/scripts/atari/pg/launch/launch_atari_ff_a2c_cpu.py` as a complete example, and follow the code backwards from there. :)


## Current Status

Multi-GPU training within one learning run is not implemented (see [accel_rl](https://github.com/astooke/accel_rl) for a hint of how it might be done). Stacking multiple experiments per machine is more effective for testing multiple runs / variations.
Multi-GPU training within one learning run is not implemented (see [accel_rl](https://github.com/astooke/accel_rl) for a hint of how it might be done, or maybe easier with PyTorch's data parallel functionality). Stacking multiple experiments per machine is more effective for multiple runs / variations.

A2C is the first algorithm in place. See [accel_rl](https://github.com/astooke/accel_rl) for similar implementations of other algorithms, including DQN, which could be ported.
A2C is the first algorithm in place. See [accel_rl](https://github.com/astooke/accel_rl) for similar implementations of other algorithms, including DQN+variants, which could be ported.


## Visualization
@@ -45,4 +45,17 @@ pip install -e .
3. Install any packages / files pertaining to desired environments. Atari is included.


## Code Organization

The class types perform the following roles (a toy sketch of how they fit together appears after the list):

* `Runner` - Connects the sampler, agent, and algorithm; manages the training loop and logging of diagnostics.
* `Sampler` - Manages agent / environment interaction to collect training data; can initialize parallel workers.
* `Collector` - Steps environments (and optionally operates the agent) and records samples; attached to the sampler.
* `Environment` - The task to be learned.
* `Space` - Interface specifications from environment to agent.
* `Agent` - Chooses the control action sent to the environment; trained by the algorithm. Interface to the model.
* `Model` - Neural network module, attached to the agent.
* `Distribution` - Samples actions for stochastic agents and defines related formulas for use in the loss function; attached to the agent.
* `Algorithm` - Uses gathered samples to train the agent (e.g. defines a loss function and calls gradient descent).
* `Optimizer` - Training update rule (e.g. Adam), attached to the algorithm.
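
Below is a self-contained toy sketch of the control flow this decomposition implies; it is NOT the rlpyt API, and every class name and signature here is invented purely for illustration. The runner alternates between the sampler collecting a batch and the algorithm using that batch to update the agent.

```python
# Toy illustration of the role decomposition above -- not the rlpyt API.
import random


class ToyEnv:                                    # "Environment": the task
    def reset(self):
        self.t = 0
        return 0.0                               # observation

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= 10
        return 0.0, reward, done                 # observation, reward, done


class ToyAgent:                                  # "Agent": chooses actions
    def step(self, observation):
        return random.choice([0, 1])             # stand-in for Model + Distribution

    def update(self, signal):
        pass                                     # parameter update would go here


class ToyAlgo:                                   # "Algorithm": trains the agent
    def optimize_agent(self, agent, samples, itr):
        mean_reward = sum(r for _, _, r in samples) / len(samples)
        agent.update(mean_reward)                # stand-in for loss + gradient step
        return mean_reward


class ToySampler:                                # "Sampler": agent/env interaction
    def __init__(self, env, agent, batch_T):
        self.env, self.agent, self.batch_T = env, agent, batch_T
        self.obs = env.reset()

    def obtain_samples(self):
        samples = []
        for _ in range(self.batch_T):
            action = self.agent.step(self.obs)
            obs, reward, done = self.env.step(action)
            samples.append((self.obs, action, reward))
            self.obs = self.env.reset() if done else obs
        return samples


class ToyRunner:                                 # "Runner": training loop + logging
    def __init__(self, sampler, algo, agent, n_itr):
        self.sampler, self.algo, self.agent, self.n_itr = sampler, algo, agent, n_itr

    def train(self):
        for itr in range(self.n_itr):
            samples = self.sampler.obtain_samples()
            info = self.algo.optimize_agent(self.agent, samples, itr)
            print(f"itr {itr}: mean_reward={info:.2f}")


if __name__ == "__main__":
    env, agent, algo = ToyEnv(), ToyAgent(), ToyAlgo()
    ToyRunner(ToySampler(env, agent, batch_T=20), algo, agent, n_itr=3).train()
```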
6 changes: 3 additions & 3 deletions rlpyt/agents/base.py
@@ -5,7 +5,7 @@
from rlpyt.utils.collections import namedarraytuple
from rlpyt.utils.logging import logger

AgentInput = namedarraytuple("AgentInput",
AgentInputs = namedarraytuple("AgentInputs",
["observation", "prev_action", "prev_reward"])
AgentStep = namedarraytuple("AgentStep", ["action", "agent_info"])

@@ -73,8 +73,8 @@ def reset_one(self, idx):
self._reset_one(idx, self._prev_rnn_state)

def _reset_one(self, idx, prev_rnn_state):
"""Assume each state is of shape: [B, ...], but can be nested
list/tuple. Reset chosen index in the Batch dimension."""
"""Assume each state is of shape: [B, ...], but can be nested tuples.
Reset chosen index in the Batch dimension."""
if isinstance(prev_rnn_state, tuple):
for prev_state in prev_rnn_state:
self._reset_one(idx, prev_state)
21 changes: 10 additions & 11 deletions rlpyt/agents/policy_gradient/atari/atari_lstm_agent.py
@@ -6,7 +6,7 @@
AgentInfo)
from rlpyt.models.policy_gradient.atari_lstm_model import AtariLstmModel
from rlpyt.distributions.categorical import Categorical, DistInfo
from rlpyt.utils.buffer import buffer_to
from rlpyt.utils.buffer import buffer_to, buffer_func
from rlpyt.algos.policy_gradient.base import AgentTrain


@@ -38,26 +38,25 @@ def step(self, observation, prev_action, prev_reward):
prev_action = self.distribution.to_onehot(prev_action)
agent_inputs = buffer_to((observation, prev_action, prev_reward),
device=self.device)
prev_rnn_state = self.prev_rnn_state
if prev_rnn_state is not None:
prev_rnn_state = buffer_to(prev_rnn_state, device=self.device)
prev_rnn_state = buffer_to(self.prev_rnn_state, # Model handles None.
device=self.device) if self.prev_rnn_state is not None else None
pi, value, rnn_state = self.model(*agent_inputs, prev_rnn_state)
dist_info = DistInfo(pi)
action = self.distribution.sample(dist_info)
self.advance_rnn_state(rnn_state) # Keep on device?
prev_rnn_state = buffer_func(rnn_state, # Buffer does not handle None.
torch.zeros_like) if prev_rnn_state is None else self.prev_rnn_state
action, agent_info = buffer_to(
(action, AgentInfo(dist_info, value, self.prev_rnn_state)),
(action, AgentInfo(dist_info, value, prev_rnn_state)),
device="cpu")
self.advance_rnn_state(rnn_state) # Do this last. Keep on device?
return AgentStep(action, agent_info)

@torch.no_grad()
def value(self, observation, prev_action, prev_reward):
prev_action = self.distribution.to_onehot(prev_action)
agent_inputs = buffer_to((observation, prev_action, prev_reward),
device=self.device)
prev_rnn_state = self.prev_rnn_state
if prev_rnn_state is not None:
prev_rnn_state = buffer_to(prev_rnn_state, device=self.device)
_pi, value, _rnn_state = self.model(observation, prev_action,
prev_reward, self.prev_rnn_state)
prev_rnn_state = buffer_to(self.prev_rnn_state,
device=self.device) if self.prev_rnn_state is not None else None
_pi, value, _rnn_state = self.model(*agent_inputs, prev_rnn_state)
return value.to("cpu")
8 changes: 4 additions & 4 deletions rlpyt/algos/policy_gradient/a2c.py
@@ -53,9 +53,9 @@ def optimize_agent(self, train_samples, itr):
self.agent.model.parameters(), self.clip_grad_norm)
self.optimizer.step()
opt_info = OptInfo(
Loss=loss.item(),
GradNorm=grad_norm,
Entropy=entropy.item(),
Perplexity=perplexity.item(),
loss=loss.item(),
gradNorm=grad_norm,
entropy=entropy.item(),
perplexity=perplexity.item(),
)
return opt_data, opt_info
8 changes: 4 additions & 4 deletions rlpyt/algos/policy_gradient/base.py
@@ -7,7 +7,8 @@
from rlpyt.algos.utils import discount_return

OptData = namedarraytuple("OptData", ["return_", "advantage", "valid"])
OptInfo = namedtuple("OptInfo", ["Loss", "GradNorm", "Entropy", "Perplexity"])
# Convention: traj_info fields CamelCase, opt_info fields lowerCamelCase
OptInfo = namedtuple("OptInfo", ["loss", "gradNorm", "entropy", "perplexity"])
AgentTrain = namedtuple("AgentTrain", ["dist_info", "value"])


@@ -31,8 +32,7 @@ def process_samples(self, samples):
samples.agent.bootstrap_value, self.discount)
advantage = return_ - samples.agent.agent_info.value
if self.mid_batch_reset:
valid = torch.ones_like(done)
valid = torch.ones_like(done) # or None.
else:
valid = 1 - torch.clamp(torch.cumsum(done, dim=0),
max=1)
valid = 1 - torch.clamp(torch.cumsum(done, dim=0), max=1)
return return_, advantage, valid
29 changes: 22 additions & 7 deletions rlpyt/algos/utils.py
@@ -4,11 +4,26 @@

def discount_return(reward, done, bootstrap_value, discount, return_dest=None):
"""Time-major inputs, optional batch dimension: [T] or [T, B]."""
return_ = torch.zeros(reward.shape, dtype=reward.dtype) \
if return_dest is None else return_dest
last_return = bootstrap_value.clone() # (clone, I think?)
for t in reversed(range(len(reward))):
last_return *= discount * (1 - done[t])
last_return += reward[t]
return_[t] = last_return
return_ = return_dest if return_dest is not None else torch.zeros(
reward.shape, dtype=reward.dtype)
not_done = 1 - done
return_[-1] = reward[-1] + discount * bootstrap_value * not_done[-1]
for t in reversed(range(len(reward) - 1)):
return_[t] = reward[t] + return_[t + 1] * discount * not_done[t]
return return_


def generalized_advantage_estimation(reward, value, done, bootstrap_value,
discount, gae_lambda, advantage_dest=None, return_dest=None):
"""Time-major inputs, optional batch dimension: [T] or [T, B]."""
advantage = advantage_dest if advantage_dest is not None else torch.zeros(
reward.shape, dtype=reward.dtype)
return_ = return_dest if return_dest is not None else torch.zeros(
reward.shape, dtype=reward.dtype)
nd = 1 - done
advantage[-1] = reward[-1] + discount * bootstrap_value * nd[-1] - value[-1]
for t in reversed(range(len(reward) - 1)):
delta = reward[t] + discount * value[t + 1] * nd[t] - value[t]
advantage[t] = delta + discount * gae_lambda * nd[t] * advantage[t + 1]
return_[:] = advantage + value
return advantage, return_
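
For reference, the recursions the two helpers above implement, written out directly from the code (here $d_t$ is the done flag, $V_{\mathrm{boot}}$ the bootstrap value, and $T$ the number of time steps):

```latex
\begin{aligned}
&\textbf{discount\_return:} && R_{T-1} = r_{T-1} + \gamma (1 - d_{T-1})\, V_{\mathrm{boot}},
  \qquad R_t = r_t + \gamma (1 - d_t)\, R_{t+1} \\
&\textbf{generalized\_advantage\_estimation:} && \delta_t = r_t + \gamma (1 - d_t)\, V_{t+1} - V_t, \qquad
  A_{T-1} = r_{T-1} + \gamma (1 - d_{T-1})\, V_{\mathrm{boot}} - V_{T-1}, \\
& && A_t = \delta_t + \gamma \lambda (1 - d_t)\, A_{t+1}, \qquad R_t = A_t + V_t
\end{aligned}
```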
23 changes: 11 additions & 12 deletions rlpyt/distributions/categorical.py
@@ -3,7 +3,8 @@

from rlpyt.distributions.base import Distribution
from rlpyt.utils.collections import namedarraytuple
from rlpyt.utils import tensor
from rlpyt.utils.tensor import (valid_mean, select_at_indexes, to_onehot,
from_onehot)

EPS = 1e-8

@@ -24,11 +25,11 @@ def dim(self):

def kl(self, old_dist_info, new_dist_info):
p = old_dist_info.prob # TODO: check order of p and q.
q = new_dist_info.prob # TODO: check numerically safe implementation.
q = new_dist_info.prob
return torch.sum(p * (torch.log(p + EPS) - torch.log(q + EPS)), dim=-1)

def mean_kl(self, old_dist_info, new_dist_info, valid):
return tensor.valid_mean(self.kl(old_dist_info, new_dist_info), valid)
return valid_mean(self.kl(old_dist_info, new_dist_info), valid)

def sample(self, dist_info):
p = dist_info.prob
@@ -43,26 +44,24 @@ def perplexity(self, dist_info):
return torch.exp(self.entropy(dist_info))

def mean_entropy(self, dist_info, valid=None):
return tensor.valid_mean(self.entropy(dist_info), valid)
return valid_mean(self.entropy(dist_info), valid)

def mean_perplexity(self, dist_info, valid=None):
return tensor.valid_mean(self.perplexity(dist_info), valid)
return valid_mean(self.perplexity(dist_info), valid)

def log_likelihood(self, indexes, dist_info):
selected_likelihood = tensor.select_at_indexes(indexes, dist_info.prob)
selected_likelihood = select_at_indexes(indexes, dist_info.prob)
return torch.log(selected_likelihood + EPS)

def likelihood_ratio(self, indexes, old_dist_info, new_dist_info):
num = tensor.select_at_indexes(indexes, new_dist_info.prob)
den = tensor.select_at_indexes(indexes, old_dist_info.prob)
num = select_at_indexes(indexes, new_dist_info.prob)
den = select_at_indexes(indexes, old_dist_info.prob)
return (num + EPS) / (den + EPS)

def to_onehot(self, indexes, dtype=None):
dtype = self.onehot_dtype if dtype is None else dtype
return tensor.to_onehot(indexes, self._dim, dtype=dtype)
return to_onehot(indexes, self._dim, dtype=dtype)

def from_onehot(self, onehot, dtype=None):
dtype = self.dtype if dtype is None else dtype
return tensor.from_onehot(onehot, dtype=dtype)


return from_onehot(onehot, dtype=dtype)
66 changes: 43 additions & 23 deletions rlpyt/envs/atari/atari_env.py
@@ -8,12 +8,24 @@
from rlpyt.envs.base import Env, EnvStep
from rlpyt.spaces.int_box import IntBox
from rlpyt.utils.quick_args import save__init__args
from rlpyt.samplers.collections import TrajInfo


W, H = (80, 104) # Fixed size: crop two rows, then downsample by 2x.


EnvInfo = namedtuple("EnvInfo", ["raw_reward", "need_reset"])
EnvInfo = namedtuple("EnvInfo", ["game_score", "need_reset"])


class AtariTrajInfo(TrajInfo):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.GameScore = 0

def step(self, _observation, _action, reward, _agent_info, env_info):
super().step(_observation, _action, reward, _agent_info, env_info)
self.GameScore += getattr(env_info, "game_score", 0)


class AtariEnv(Env):
@@ -26,6 +38,7 @@ def __init__(self,
episodic_lives=True,
max_start_noops=30,
repeat_action_probability=0.,
horizon=27000,
):
save__init__args(locals(), underscore=True)
# ALE
@@ -41,8 +54,8 @@ def __init__(self,
self._action_set = self.ale.getMinimalActionSet()
self._action_space = IntBox(low=0, high=len(self._action_set))
obs_shape = (num_img_obs, H, W)
self._observation_space = IntBox(low=0, high=255,
shape=obs_shape, dtype="uint8")
self._observation_space = IntBox(low=0, high=255, shape=obs_shape,
dtype="uint8")
self._max_frame = self.ale.getScreenGrayscale()
self._raw_frame_1 = self._max_frame.copy()
self._raw_frame_2 = self._max_frame.copy()
@@ -53,21 +66,31 @@ def __init__(self,
self._has_up = "UP" in self.get_action_meanings()
self._done = self._done_episodic_lives if episodic_lives else \
self._done_no_epidosic_lives
self._horizon = int(horizon)

# Get ready
self.reset()
def reset(self):
"""Call reset when first using AtariEnv, before first step."""
self.ale.reset_game()
self._reset_obs()
self._life_reset()
for _ in range(np.random.randint(0, self._max_start_noops + 1)):
self.ale.act(0)
self._update_obs() # (don't bother to populate any frame history)
self._step_counter = 0
return self.get_obs()

def step(self, action):
a = self._action_set[action]
raw_reward = np.array(0., dtype="float32")
game_score = np.array(0., dtype="float32")
for _ in range(self._frame_skip - 1):
raw_reward += self.ale.act(a)
game_score += self.ale.act(a)
self._get_screen(1)
raw_reward += self.ale.act(a)
game_score += self.ale.act(a)
self._update_obs()
reward = np.sign(raw_reward) if self._clip_reward else raw_reward
reward = np.sign(game_score) if self._clip_reward else game_score
done, need_reset = self._done()
info = EnvInfo(raw_reward, need_reset)
info = EnvInfo(game_score=game_score, need_reset=need_reset)
self._step_counter += 1
return EnvStep(self.get_obs(), reward, done, info)

def render(self, wait=10, show_full_obs=False):
@@ -83,15 +106,6 @@ def render(self, wait=10, show_full_obs=False):
def get_obs(self):
return self._obs.copy()

def reset(self):
self.ale.reset_game()
self._reset_obs()
self._life_reset()
for _ in range(np.random.randint(0, self._max_start_noops + 1)):
self.ale.act(0)
self._update_obs() # (don't bother to populate any frame history)
return self.get_obs()

###########################################################################
# Helpers

@@ -104,7 +118,7 @@ def _update_obs(self):
self._get_screen(2)
np.maximum(self._raw_frame_1, self._raw_frame_2, self._max_frame)
img = cv2.resize(self._max_frame[1:-1], (W, H), cv2.INTER_NEAREST)
# NOTE: this order--oldest to newest--needed for ReplayFrameBuffer in DQN!!
# NOTE: Order oldest to newest needed for ReplayFrameBuffer in DQN.
self._obs = np.concatenate([self._obs[1:], img[np.newaxis]])

def _reset_obs(self):
@@ -131,16 +145,17 @@ def _life_reset(self):

def _done_no_epidosic_lives(self):
self._check_life()
done = self.ale.game_over()
done = self.ale.game_over() or self._step_counter >= self.horizon
return done, done

def _done_episodic_lives(self):
need_reset = self.ale.game_over()
need_reset = self.ale.game_over() or self._step_counter >= self.horizon
lost_life = self._check_life()
done = lost_life or need_reset
if lost_life:
self._reset_obs() # (reset here, so sampler does NOT call reset)
self._update_obs() # (will have already advanced in check_life)
return lost_life or need_reset, need_reset
return done, need_reset

###########################################################################
# Properties
@@ -173,6 +188,10 @@ def episodic_lives(self):
def repeat_action_probability(self):
return self._repeat_action_probability

@property
def horizon(self):
return self._horizon

def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self._action_set]

@@ -199,3 +218,4 @@ def get_action_meanings(self):
}

ACTION_INDEX = {v: k for k, v in ACTION_MEANING.items()}

1 change: 0 additions & 1 deletion rlpyt/experiments/configs/atari/pg/atari_ff_a2c.py
@@ -21,7 +21,6 @@
sampler=dict(
batch_T=5,
batch_B=64,
max_path_length=27000,
max_decorrelation_steps=1000,
),
)