Commit

minor tweaks + bypass_lstm experiment
astooke committed May 21, 2019
1 parent dc2480d, commit d04e2ab
Showing 6 changed files with 33 additions and 17 deletions.
13 changes: 10 additions & 3 deletions README.md
@@ -1,15 +1,17 @@
# rlpyt

## Deep Reinforcement Learning in PyTorch

Runs reinforcement learning algorithms with parallel sampling and GPU training, if available. Highly modular (modifiable) and optimized codebase with functionality for launching large sets of parallel experiments locally on multi-GPU or many-core machines.

-Based on [accel_rl](https://github.com/astooke/accel_rl), which in turn was based on [rllab](https://github.com/rll/rllab).
+Based on [accel_rl](https://github.com/astooke/accel_rl), which in turn was based on [rllab](https://github.com/rll/rllab) (the `logger` is nearly a direct copy).

-Follows the rllab interfaces: agents output `action, agent_info`, environments output `observation, reward, done, env_info`, but introduces new object classes `namedarraytuple` for easier organization. This permits each output to be either an individual numpy array [torch tensor] or an arbitrary collection of numpy arrays [torch tensors], without changing interfaces. In general, agent inputs/outputs are torch tensors, and environment inputs/outputs are numpy arrays, with conversions handled automatically.
+Follows the rllab interfaces: agents output `action, agent_info`, environments output `observation, reward, done, env_info`, but introduces new object classes `namedarraytuple` for easier organization (see `rlpyt/utils/collections.py`). This permits each output to be either an individual numpy array [torch tensor] or an arbitrary collection of numpy arrays [torch tensors], without changing interfaces. In general, agent inputs/outputs are torch tensors, and environment inputs/outputs are numpy arrays, with conversions handled automatically.
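
As a rough illustration of the `namedarraytuple` interface described above (a sketch only; the field names are borrowed from the `AgentInfo` tuple defined later in this commit, and the forwarded-indexing behavior is worth confirming against `rlpyt/utils/collections.py`):

```python
import numpy as np
from rlpyt.utils.collections import namedarraytuple

# Declared the same way as collections.namedtuple.
AgentInfo = namedarraytuple("AgentInfo", ["dist_info", "value"])

# Each field holds an array (or a nested namedarraytuple) with matching leading dims.
info = AgentInfo(dist_info=np.zeros((16, 6)), value=np.zeros(16))

print(info.value.shape)         # (16,) -- attribute access, as with namedtuple
print(info[3].dist_info.shape)  # (6,)  -- indexing is forwarded to every field
```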

Recurrent agents are supported, as training batches are organized with leading indexes as `[Time, Batch]`, and agents receive previous action and previous reward as input, in addition to the observation.

-Start from `rlpyt/experiments/scripts/atari/pg/launch/launch_atari_ff_a2c_cpu.py` as a complete example, and follow the code backwards from there. :)
+Start from `rlpyt/experiments/scripts/atari/pg/launch/launch_atari_ff_a2c_cpu.py` as a complete launch script example, and follow the code backwards from there. :)



## Current Status
@@ -19,11 +21,13 @@ Multi-GPU training within one learning run is not implemented (see [accel_rl](ht
A2C is the first algorithm in place. See [accel_rl](https://github.com/astooke/accel_rl) for similar implementations of other algorithms, including DQN+variants, which could be ported.



## Visualization

This package does not include its own visualization, as the logged data is compatible with that of the codebases listed above (accel_rl / rllab). For more features, use [https://github.com/vitchyr/viskit](https://github.com/vitchyr/viskit).



## Installation

1. Install the anaconda environment appropriate for the machine.
@@ -45,6 +49,7 @@ pip install -e .
3. Install any packages / files pertaining to desired environments. Atari is included.



## Code Organization

The class types perform the following roles (a generic sketch of how they interact follows the list):
@@ -54,8 +59,10 @@ The class types perform the following roles:
* **Collector** - Steps `environments` (and maybe operates `agent`) and records samples, attached to `sampler`.
* **Environment** - The task to be learned.
* **Space** - Interface specifications from `environment` to `agent`.
* **TrajectoryInfo** - Diagnostics logged on a per-trajectory basis.
* **Agent** - Chooses the control action to send to the `environment` in the `sampler`; trained by the `algorithm`. Interface to the `model`.
* **Model** - Neural network module, attached to the `agent`.
* **Distribution** - Samples actions for stochastic `agents` and defines related formulas for use in loss function, attached to the `agent`.
* **Algorithm** - Uses gathered samples to train the `agent` (e.g. defines a loss function and performs gradient descent).
* **Optimizer** - Training update rule (e.g. Adam), attached to the `algorithm`.
* **OptimizationInfo** - Diagnostics logged on a per-training batch basis.
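
For orientation, here is a minimal, self-contained sketch of how these roles interact. The classes below are generic stand-ins written for this summary, not rlpyt's actual classes or method names; follow the launch script referenced above for the real wiring.

```python
import random

class Environment:                        # the task to be learned
    def reset(self):
        return 0.0
    def step(self, action):
        return 0.0, random.random(), False, {}

class Agent:                              # chooses actions; would wrap the torch model
    def step(self, obs):
        return random.choice([0, 1]), {}

class Sampler:                            # steps the environment and records samples
    def __init__(self, env, batch_T):
        self.env, self.batch_T = env, batch_T
    def obtain_samples(self, agent):
        obs, samples = self.env.reset(), []
        for _ in range(self.batch_T):
            action, _ = agent.step(obs)
            obs, reward, done, _ = self.env.step(action)
            samples.append((obs, action, reward, done))
        return samples

class Algorithm:                          # turns a batch of samples into an agent update
    def optimize(self, agent, samples):
        pass                              # e.g. compute a loss and step an optimizer

class Runner:                             # alternates sampling and optimization, handles logging
    def __init__(self, sampler, agent, algo):
        self.sampler, self.agent, self.algo = sampler, agent, algo
    def train(self, n_itr):
        for itr in range(n_itr):
            samples = self.sampler.obtain_samples(self.agent)
            self.algo.optimize(self.agent, samples)

Runner(Sampler(Environment(), batch_T=5), Agent(), Algorithm()).train(n_itr=3)
```
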
6 changes: 3 additions & 3 deletions rlpyt/agents/base.py
@@ -33,9 +33,9 @@ def initialize_cuda(self, cuda_idx=None):
"""Call after initialize() and after forking sampler workers."""
if cuda_idx is None:
return # CPU
if self._memory_is_shared:
self.shared_model = self.model
self.model = self.ModelCls(**self.env_model_kwargs, **self.model_kwargs)
if self.shared_model is not None: # (If model using shared memory.)
self.model = self.ModelCls(**self.env_model_kwargs,
**self.model_kwargs)
self.model.load_state_dict(self.shared_model.state_dict())
self.device = torch.device("cuda", index=cuda_idx)
self.model.to(self.device)
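
The hunk above replaces the `_memory_is_shared` flag with a direct check on `self.shared_model`. The underlying pattern is plain PyTorch and can be sketched outside rlpyt; the variable names below are illustrative only:

```python
import torch
import torch.nn as nn

def ModelCls():                          # stand-in for the agent's real model class
    return nn.Linear(4, 2)

# In the master process, before forking sampler workers:
shared_model = ModelCls()
shared_model.share_memory()              # parameters go to shared memory for the workers

# The training process keeps its own copy, optionally moved to the GPU:
model = ModelCls()
model.load_state_dict(shared_model.state_dict())
if torch.cuda.is_available():            # guard so the sketch also runs on CPU-only machines
    model.to(torch.device("cuda", index=0))

# After an update, new parameters would be copied back for the workers to read:
shared_model.load_state_dict({k: v.cpu() for k, v in model.state_dict().items()})
```
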
5 changes: 2 additions & 3 deletions rlpyt/agents/policy_gradient/base.py
@@ -1,9 +1,8 @@

import torch
-from collections import namedtuple

from rlpyt.utils.collections import namedarraytuple
-from rlpyt.agents.base import BaseAgent, BaseRecurrentAgent
+from rlpyt.agents.base import BaseAgent

AgentInfo = namedarraytuple("AgentInfo", ["dist_info", "value"])

@@ -17,7 +16,7 @@ def initialize(self, env_spec, share_memory=False):
        self.model = self.ModelCls(**env_model_kwargs, **self.model_kwargs)
        if share_memory:
            self.model.share_memory()
-        self._memory_is_shared = share_memory
+            self.shared_model = self.model
        if self.initial_model_state_dict is not None:
            self.model.load_state_dict(self.initial_model_state_dict)
        self.env_spec = env_spec
7 changes: 7 additions & 0 deletions rlpyt/experiments/configs/atari/pg/atari_lstm_a2c.py
@@ -39,3 +39,10 @@
config["sampler"]["batch_B"] = 16
config["algo"]["learning_rate"] = 1e-4
configs["4frame"] = config


+config = copy.deepcopy(config)
+config["algo"]["learning_rate"] = 7e-4
+config["sampler"]["batch_B"] = 32
+config["algo"]["clip_grad_norm"] = 1
+configs["like_ff"] = config
12 changes: 6 additions & 6 deletions (experiment launch script)
@@ -12,27 +12,27 @@
    n_socket=1,
    # cpu_per_run=2,
)
-runs_per_setting = 1
-experiment_title = "lstm_4frame_test"
+runs_per_setting = 2
+experiment_title = "bypass_lstm_test"
variant_levels = list()

-learning_rate = [1e-4] * 4
-entropy_loss_coeff = [0.01, 0.4, 0.04, 0.1]
+learning_rate = [5e-4] * 4
+entropy_loss_coeff = [0.01, 0.02, 0.04, 0.08]
values = list(zip(learning_rate, entropy_loss_coeff))
dir_names = ["test_{}lr_{}ent".format(*v) for v in values]
keys = [("algo", "learning_rate"), ("algo", "entropy_loss_coeff")]
variant_levels.append(VariantLevel(keys, values, dir_names))


games = ["seaquest"]
games = ["pong", "seaquest"]
values = list(zip(games))
dir_names = ["{}".format(*v) for v in values]
keys = [("env", "game")]
variant_levels.append(VariantLevel(keys, values, dir_names))

variants, log_dirs = make_variants(*variant_levels)

default_config_key = "4frame"
default_config_key = "like_ff"

run_experiments(
    script=script,
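
Each `VariantLevel` above zips its lists into per-level settings, and `make_variants` combines the levels; assuming the levels combine as a cross product (worth confirming against the `make_variants` source), the two levels here expand to 4 x 2 = 8 variants, and `runs_per_setting = 2` launches each one twice. A standalone approximation:

```python
# Not rlpyt's make_variants implementation -- just an approximation of what the
# two variant levels in this script represent, assuming a cross product of levels.
from itertools import product

level_1 = list(zip([5e-4] * 4, [0.01, 0.02, 0.04, 0.08]))   # (learning_rate, entropy_loss_coeff)
level_2 = [("pong",), ("seaquest",)]                         # (game,)

variants = [dict(learning_rate=lr, entropy_loss_coeff=ent, game=game)
            for (lr, ent), (game,) in product(level_1, level_2)]
print(len(variants))   # 8
```
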
7 changes: 5 additions & 2 deletions rlpyt/models/policy_gradient/atari_lstm_model.py
@@ -124,8 +124,11 @@ def forward(self, image, prev_action, prev_reward, init_rnn_state):
        next_rnn_state = (hn.transpose(0, 1), cn.transpose(0, 1))  # --> [B,N,H]
        lstm_flat = lstm_out.view(T * B, -1)

-        pi = F.softmax(self.linear_pi(lstm_flat), dim=-1)
-        v = self.linear_v(lstm_flat).squeeze(-1)
+        # pi = F.softmax(self.linear_pi(lstm_flat), dim=-1)
+        # v = self.linear_v(lstm_flat).squeeze(-1)
+        # DEBUG: bypass the LSTM.
+        pi = F.softmax(self.linear_pi(fc_out), dim=-1)
+        v = self.linear_v(fc_out).squeeze(-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        pi, v = restore_leading_dims((pi, v), T, B, has_T, has_B)
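
The `view(T * B, -1)` reshaping around the hunk above is the general pattern for pushing `[Time, Batch]`-leading training data through layers that expect a flat batch dimension; a standalone sketch of the flatten/restore step in plain PyTorch (not the `restore_leading_dims` helper itself):

```python
import torch
import torch.nn.functional as F

T, B, C, H, W = 5, 4, 4, 84, 84            # [Time, Batch] leading dims, then image dims
images = torch.rand(T, B, C, H, W)

flat_in = images.view(T * B, C, H, W)      # flatten leading dims for conv/linear layers
features = flat_in.flatten(start_dim=1)    # stand-in for the conv + fc trunk output
linear_pi = torch.nn.Linear(features.shape[1], 6)

pi = F.softmax(linear_pi(features), dim=-1)    # [T*B, n_actions]
pi = pi.view(T, B, -1)                         # restore the [Time, Batch] leading dims
print(pi.shape)                                # torch.Size([5, 4, 6])
```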
