bqzhu922/rlpyt (forked from astooke/rlpyt)

rlpyt

Deep Reinforcement Learning in PyTorch

Modular, optimized implementations of common deep RL algorithms in PyTorch, with unified infrastructure supporting all three major families: policy gradient, deep Q-learning, and Q-function policy gradient. Intended to be a high-throughput code base for small- to medium-scale research (large-scale meaning something like OpenAI Dota with hundreds of GPUs). Key capabilities/features include:

  • Run experiments in serial mode (helpful for debugging during development, and possibly sufficient for some experiments).
  • Run experiments fully parallelized, with options for parallel sampling and/or multi-GPU optimization.
    • Multi-GPU optimization uses PyTorch's DistributedDataParallel, which supports gradient reduction concurrent with backprop.
  • Use CPU or GPU for training and/or batched action selection during environment sampling.
  • Full support for recurrent agents.
    • All agents receive observation, prev_action, prev_reward.
    • Training data always organized with leading indexes as [Time, Batch].
  • Launching utilities for stacking/queueing sets of experiments in parallel on given local hardware resources (e.g. run 40 experiments on an 8-GPU machine with 1 experiment per GPU at a time).
  • Compatible with the OpenAI Gym environment interface.[1]
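As an illustration of the [Time, Batch] layout and of where prev_action and prev_reward come from, the sketch below builds toy [T, B] arrays with NumPy and derives the previous-step inputs by shifting one step along the time axis. The helper name shift_back_in_time is hypothetical, not part of rlpyt, and the sketch ignores episode boundaries, where real code would also need to reset the previous action and reward.

```python
import numpy as np

T, B = 4, 2  # leading dimensions: time steps, parallel environments

# Toy samples laid out as [T, B]: entry [t, b] is environment b at time t.
action = np.arange(1, T * B + 1).reshape(T, B)
reward = np.arange(1, T * B + 1, dtype=float).reshape(T, B) / 10


def shift_back_in_time(x, fill=0):
    """Return prev_x with prev_x[t] = x[t - 1] and prev_x[0] = fill.

    Hypothetical helper; ignores episode resets (done flags) for brevity.
    """
    prev = np.empty_like(x)
    prev[0] = fill
    prev[1:] = x[:-1]
    return prev


prev_action = shift_back_in_time(action)
prev_reward = shift_back_in_time(reward)

# Time slices and batch slices both come from plain indexing on [T, B].
one_step = action[2]    # all environments at t = 2, shape (B,)
one_env = action[:, 1]  # full time series of environment 1, shape (T,)
```

Because every field shares the same leading [T, B] dimensions, the same index (a time step, an environment, or a slice of either) selects consistent data from all of them.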

Implemented Algorithms

Policy Gradient: A2C, PPO.

Replay Buffers (supporting both DQN and QPG): non-sequence and sequence (for recurrent) replay, n-step returns, uniform or prioritized replay, full-observation or frame-based buffers (e.g. for Atari, the buffer stores only unique frames to save memory and reconstructs multi-frame observations).

Deep Q-Learning: DQN plus the Double, Dueling, Categorical (up to Rainbow minus Noisy Nets), and Recurrent (R2D2-style) variants.

Q-Function Policy Gradient: DDPG, TD3, SAC.
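To make the n-step returns and the Double DQN variant above concrete, here is a minimal NumPy sketch of an n-step Double DQN bootstrap target computed from [T, B] samples. The function names are hypothetical and this is not rlpyt's implementation; it assumes a done flag marks the last step of an episode, and it leaves out prioritized sampling and frame reconstruction.

```python
import numpy as np


def n_step_returns(reward, done, n_step, discount):
    """Discounted n-step reward sums from [T, B] arrays.

    return_[t] = sum_{k < n_step} discount**k * reward[t + k], stopping after
    the step where an episode ends. done_n[t] is True if the window hit a done,
    in which case no bootstrap value should be added. Only the first
    T - n_step + 1 time steps have a full window inside the buffer.
    """
    T, B = reward.shape
    L = T - n_step + 1
    return_ = np.zeros((L, B))
    done_n = np.zeros((L, B), dtype=bool)
    for k in range(n_step):
        # Rewards after an episode end (done_n already True) are masked out.
        return_ += np.where(done_n, 0.0, discount ** k * reward[k:k + L])
        done_n |= done[k:k + L].astype(bool)
    return return_, done_n


def double_dqn_target(return_, done_n, next_q_online, next_q_target, n_step, discount):
    """n-step Double DQN target; next_q_* have shape [L, B, num_actions]."""
    # Double DQN: choose the bootstrap action with the online network...
    next_action = next_q_online.argmax(axis=-1)
    # ...but evaluate that action with the target network.
    next_value = np.take_along_axis(next_q_target, next_action[..., None], axis=-1)[..., 0]
    return return_ + discount ** n_step * np.where(done_n, 0.0, next_value)


# Example: one environment, an episode ending at t = 1, n_step = 3.
reward = np.ones((5, 1))
done = np.array([[0], [1], [0], [0], [0]])
return_, done_n = n_step_returns(reward, done, n_step=3, discount=0.5)
next_q_online = np.zeros((3, 1, 2))
next_q_online[..., 1] = 1.0                        # online net prefers action 1
next_q_target = np.tile([10.0, 20.0], (3, 1, 1))  # target net action values
target = double_dqn_target(return_, done_n, next_q_online, next_q_target, 3, 0.5)
```

Selecting the action with the online network and evaluating it with the target network is what separates Double DQN from the plain DQN target, which would take the max of next_q_target directly.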

All implementations, except possibly recurrent DQN, produce learning curves that verify performance against published results; these curves will be posted soon (docs are still under construction).

Getting Started

Follow the installation instructions below, and then get started in the examples folder. Example scripts are ordered by increasing complexity.

For newcomers to deep RL, we recommend first getting familiar with the algorithms through a different resource, such as the excellent OpenAI Spinning Up docs and code.

New data structure: namedarraytuple

Rlpyt introduces a new object class, namedarraytuple, for easier organization of collections of numpy arrays / torch tensors (see rlpyt/utils/collections.py). A namedarraytuple is essentially a namedtuple that exposes indexed or sliced reads/writes into the structure. For example, consider writing into a (possibly nested) dictionary of arrays:

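def write_nested(dest, src, slice_or_indexes):
    # Walk matching dict structures, writing each leaf array at the index.
    for k, v in src.items():
        if isinstance(dest[k], dict):
            write_nested(dest[k], v, slice_or_indexes)  # recurse into sub-dict
        else:
            dest[k][slice_or_indexes] = v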

This code is replaced by the following:

dest[slice_or_indexes] = src

Importantly, this syntax looks the same whether dest and src are individual numpy arrays or arbitrarily structured collections of arrays (the structures of dest and src must match). Rlpyt uses this data structure extensively: different elements of training data are organized with the same leading dimensions, making it easy to interact with the desired time or batch dimensions.
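To show the mechanism, the following is a minimal sketch of a namedarraytuple-like type built on collections.namedtuple: it overrides __getitem__ and __setitem__ so an index or slice is applied to every field, recursing into nested instances. The factory name make_namedarraytuple and the Obs/Sample structures are hypothetical; rlpyt's own version in rlpyt/utils/collections.py is more complete, so treat this as an illustration rather than a drop-in replacement.

```python
from collections import namedtuple

import numpy as np


def make_namedarraytuple(typename, field_names):
    """Hypothetical, simplified namedarraytuple-like factory.

    Indexing or slicing an instance applies the same location to every field
    (recursing into nested instances) instead of indexing the tuple itself.
    Fields are still read by name, e.g. obs.image (on Python 3.8+, namedtuple
    attribute access does not go through __getitem__).
    """
    base = namedtuple(typename, field_names)

    def __getitem__(self, loc):
        # Build a new instance whose fields are each field indexed at loc.
        return type(self)(*(field[loc] for field in self))

    def __setitem__(self, loc, value):
        # value must have the same structure (same fields, same nesting).
        for dest_field, src_field in zip(self, value):
            dest_field[loc] = src_field

    return type(typename, (base,), {
        "__slots__": (),
        "__getitem__": __getitem__,
        "__setitem__": __setitem__,
    })


Obs = make_namedarraytuple("Obs", ["joint", "image"])
Sample = make_namedarraytuple("Sample", ["observation", "reward"])

T, B = 5, 2
dest = Sample(
    observation=Obs(joint=np.zeros((T, B, 7)), image=np.zeros((T, B, 3, 8, 8))),
    reward=np.zeros((T, B)),
)
src = Sample(
    observation=Obs(joint=np.ones((B, 7)), image=np.ones((B, 3, 8, 8))),
    reward=np.ones(B),
)

dest[2] = src   # one line writes every nested array at time step 2
step = dest[2]  # a Sample whose fields have leading shape (B,)
```

The write dest[2] = src reads the same as it would for a single array, even though it touches three arrays at two levels of nesting.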

This is also intended to support environments with multi-modal observations or actions. For example, consider an environment with joint-angle and camera-image observations. Rather than flattening and merging these into one observation vector, the environment can store them as-is in a namedarraytuple for the observation, and in the forward method of the model, observation.joint and observation.image can be fed into the desired layers. Intermediate infrastructure code need not change.
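A sketch of such a forward method is below. To stay dependency-light it uses NumPy matrices in place of torch.nn layers and a plain namedtuple in place of a namedarraytuple; the class name TwoStreamModel, the layer sizes, and the tanh features are illustrative assumptions. Each modality reaches its own layer without any flattening or merging upstream, and the leading [T, B] dimensions pass through unchanged.

```python
from collections import namedtuple

import numpy as np

# Hypothetical observation structure with two modalities.
Observation = namedtuple("Observation", ["joint", "image"])


class TwoStreamModel:
    """Toy model sketch: NumPy matrices stand in for torch.nn layers."""

    def __init__(self, joint_dim, image_shape, hidden=16, seed=0):
        rng = np.random.default_rng(seed)
        image_dim = int(np.prod(image_shape))
        self.w_joint = rng.standard_normal((joint_dim, hidden)) * 0.1
        self.w_image = rng.standard_normal((image_dim, hidden)) * 0.1

    def forward(self, observation):
        # Each modality goes to its own layer; leading dims (e.g. [T, B]) kept.
        lead = observation.joint.shape[:-1]
        joint_feat = np.tanh(observation.joint @ self.w_joint)
        flat_image = observation.image.reshape(*lead, -1)
        image_feat = np.tanh(flat_image @ self.w_image)
        return np.concatenate([joint_feat, image_feat], axis=-1)


T, B = 4, 3
obs = Observation(joint=np.zeros((T, B, 7)), image=np.zeros((T, B, 3, 8, 8)))
features = TwoStreamModel(joint_dim=7, image_shape=(3, 8, 8)).forward(obs)
```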

Future Developments

Overall the code is fairly stable but still developing; expect changes.

My list

  • Asynchronous sampling/optimization mode. This should yield major speedups for e.g. Atari DQN, since the sampler and optimizer would both always run at full speed, whereas currently the code is either sampling or optimizing.
  • Alternating GPU sampler. Like the old accel_rl code, this should speed up sampling by using two sets of environments: one steps while the other waits for actions back from the GPU.

Welcome contributions

  • Utilities for running experiments in the cloud and/or with Docker. Maybe in a separate "cloud" folder outside of rlpyt?
  • Other established algorithms (but not trying to clutter the master repo with countless variations). Feel free to suggest.
  • Other ideas?

Visualization

This package does not include its own visualization, as the logged data is compatible with previous editions (see below). For more features, use https://github.com/vitchyr/viskit.

Installation

  1. Clone this repository to the local machine.

  2. Install the anaconda environment appropriate for the machine.

conda env create -f linux_[cpu|cuda9|cuda10].yml
source activate rlpyt
  3. Either A) edit the PYTHONPATH to include the rlpyt directory, or B) install rlpyt as an editable Python package:
#A
export PYTHONPATH=path_to_rlpyt:$PYTHONPATH

#B
pip install -e .
  4. Install any packages or files required by the desired environments (e.g. gym, mujoco). Atari is included.

Hint: for easy access, add the following to your ~/.bashrc (with newer conda versions, you may need conda activate instead of source activate).

alias rlpyt="source activate rlpyt; cd path_to_rlpyt"

Extended Notes

Code Organization

The class types perform the following roles:

  • Runner - Connects the sampler, agent, and algorithm; manages the training loop and logging of diagnostics.
    • Sampler - Manages agent / environment interaction to collect training data, can initialize parallel workers.
      • Collector - Steps environments (and possibly operates the agent) and records samples; attached to the sampler.
        • Environment - The task to be learned.
          • Space - Interface specifications from environment to agent.
        • TrajectoryInfo - Diagnostics logged on a per-trajectory basis.
    • Agent - Chooses control actions for the environment in the sampler; trained by the algorithm. Interface to the model.
      • Model - Torch neural network module, attached to the agent.
      • Distribution - Samples actions for stochastic agents and defines related formulas for use in loss function, attached to the agent.
    • Algorithm - Uses gathered samples to train the agent (e.g. defines a loss function and performs gradient descent).
      • Optimizer - Training update rule (e.g. Adam), attached to the algorithm.
      • OptimizationInfo - Diagnostics logged on a per-training batch basis.
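The division of labor above can be illustrated with a toy, dependency-free training loop. Every class and method name here (ToyRunner, obtain_samples, optimize_agent, and so on) is hypothetical and only mirrors the roles listed, not rlpyt's actual signatures. The "algorithm" is a simple update that nudges the agent's action probability toward rewarded actions, not a real RL algorithm.

```python
import random


class ToyEnv:
    """Environment: reward 1.0 when the action matches a hidden target bit."""

    def __init__(self, target=1):
        self.target = target

    def step(self, action):
        return 1.0 if action == self.target else 0.0


class ToyAgent:
    """Agent: picks actions from one learnable parameter, P(action = 1)."""

    def __init__(self):
        self.p_one = 0.5

    def step(self, rng):
        return 1 if rng.random() < self.p_one else 0


class ToySampler:
    """Sampler: runs agent/environment interaction and returns a batch."""

    def __init__(self, env, agent, batch_size, seed=0):
        self.env, self.agent, self.batch_size = env, agent, batch_size
        self.rng = random.Random(seed)

    def obtain_samples(self):
        actions = [self.agent.step(self.rng) for _ in range(self.batch_size)]
        rewards = [self.env.step(a) for a in actions]
        return actions, rewards


class ToyAlgorithm:
    """Algorithm: uses the batch to update the agent's parameter."""

    def __init__(self, agent, lr=0.1):
        self.agent, self.lr = agent, lr

    def optimize_agent(self, samples):
        actions, rewards = samples
        # Move P(action = 1) toward the mean of the rewarded actions.
        rewarded = [a for a, r in zip(actions, rewards) if r > 0]
        if rewarded:
            target = sum(rewarded) / len(rewarded)
            self.agent.p_one += self.lr * (target - self.agent.p_one)
        return {"avg_reward": sum(rewards) / len(rewards)}  # diagnostics


class ToyRunner:
    """Runner: connects the pieces and owns the training loop and logging."""

    def __init__(self, sampler, algo):
        self.sampler, self.algo = sampler, algo

    def train(self, n_iters):
        log = []
        for _ in range(n_iters):
            samples = self.sampler.obtain_samples()
            log.append(self.algo.optimize_agent(samples))
        return log


agent = ToyAgent()
runner = ToyRunner(ToySampler(ToyEnv(target=1), agent, batch_size=32),
                   ToyAlgorithm(agent))
log = runner.train(n_iters=50)
```

Because only the runner holds references to all the pieces, a sampler or algorithm can in principle be swapped without touching the others.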

Historical, Scaling, Interfaces

This code is a revision of accel_rl, which explored scaling RL in the Atari domain using Theano. Scaling results were reported in A. Stooke & P. Abbeel, "Accelerated Methods for Deep Reinforcement Learning". For a broader and deeper study of scaling, see S. McCandlish, et al., "An Empirical Model of Large-Batch Training".

Accel_rl was inspired by rllab (the logger here is nearly a direct copy). Rlpyt follows the rllab interfaces: agents output action, agent_info; environments output observation, reward, done, env_info. In general in rlpyt, agent inputs/outputs are torch tensors, and environment inputs/outputs are numpy arrays, with conversions handled automatically.

[1] Regarding OpenAI Gym compatibility, rlpyt uses a namedtuple for env_info rather than a dict. This makes data recording easier but requires the same fields to be output at every environment step. An environment wrapper is provided.
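As a sketch of what such a wrapper can look like (the class names here are hypothetical and this is not rlpyt's provided wrapper), the code below converts a Gym-style info dict into a namedtuple with a fixed set of fields. It assumes the older Gym step API that returns (observation, reward, done, info). Fields missing from a step's dict fall back to the namedtuple defaults, and undeclared keys are dropped, so every step yields the same structure.

```python
from collections import namedtuple

# Fixed env_info fields with defaults (illustrative choice of fields).
EnvInfo = namedtuple("EnvInfo", ["timeout", "lives"], defaults=[False, -1])
EnvStep = namedtuple("EnvStep", ["observation", "reward", "done", "env_info"])


class InfoToNamedtupleWrapper:
    """Hypothetical wrapper: Gym-style info dict -> fixed-field namedtuple."""

    def __init__(self, env, info_cls=EnvInfo):
        self.env = env
        self.info_cls = info_cls

    def reset(self):
        return self.env.reset()

    def step(self, action):
        # Assumes the older Gym API: step() returns a 4-tuple.
        observation, reward, done, info = self.env.step(action)
        # Keep declared fields only; absent ones take the namedtuple defaults.
        known = {k: v for k, v in info.items() if k in self.info_cls._fields}
        return EnvStep(observation, reward, done, self.info_cls(**known))


class DummyGymEnv:
    """Stand-in environment whose info dict changes keys between steps."""

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        info = {"lives": 3, "extra": "x"} if self.t % 2 else {}
        return self.t, 1.0, self.t >= 4, info


env = InfoToNamedtupleWrapper(DummyGymEnv())
env.reset()
steps = [env.step(action=0) for _ in range(4)]
```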

Acknowledgements

Thank you to the Fannie & John Hertz Foundation and the NVIDIA Corporation for generous support of graduate studies. Thanks to Pieter Abbeel for extensive support and advising, Max Jaderberg for further mentoring, the BAIR community, and OpenAI for patient support during finishing stages. And thanks in advance to any contributors! Happy reinforcement learning!
