[rllib] Initial RLLib documentation (ray-project#969)
* initial documentation for RLLib

* more RL documentation

* fix linting

* fix comments

* update

* fix
pcmoritz authored and robertnishihara committed Sep 13, 2017
1 parent 9ec3608 commit 1eb8c83
Showing 6 changed files with 184 additions and 5 deletions.
8 changes: 7 additions & 1 deletion doc/source/conf.py
@@ -18,8 +18,14 @@

# These lines added to enable Sphinx to work without installing Ray.
import mock
MOCK_MODULES = ["pyarrow",
MOCK_MODULES = ["gym",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.slim",
"tensorflow.contrib.rnn",
"pyarrow",
"pyarrow.plasma",
"smart_open",
"ray.local_scheduler",
"ray.plasma",
"ray.core.generated.TaskInfo",
1 change: 1 addition & 0 deletions doc/source/index.rst
@@ -20,6 +20,7 @@ Ray
    api.rst
    actors.rst
    using-ray-with-gpus.rst
+   rllib.rst
 
 .. toctree::
    :maxdepth: 1
159 changes: 159 additions & 0 deletions doc/source/rllib.rst
@@ -0,0 +1,159 @@
RLLib: Ray's scalable reinforcement learning library
====================================================

This document describes Ray's reinforcement learning library.
It currently supports the following algorithms:

- `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`__, which
  is a proximal variant of `TRPO <https://arxiv.org/abs/1502.05477>`__.

- Evolution Strategies, which is described in `this
  paper <https://arxiv.org/abs/1703.03864>`__. Our implementation
  borrows code from
  `here <https://github.com/openai/evolution-strategies-starter>`__.

- `The Asynchronous Advantage Actor-Critic <https://arxiv.org/abs/1602.01783>`__,
  based on `the OpenAI starter agent <https://github.com/openai/universe-starter-agent>`__.

Proximal Policy Optimization scales to hundreds of cores and several GPUs,
Evolution Strategies scales to clusters with thousands of cores, and
the Asynchronous Advantage Actor-Critic scales to dozens of cores
on a single node.

These algorithms can be run on any OpenAI gym MDP, including custom ones written
and registered by the user.
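
For example, a custom environment can be registered with gym and then
selected by its ID (a minimal sketch; ``SimpleCorridor`` and its module
path are hypothetical):

::

import gym
from gym.envs.registration import register

# Register a user-defined environment under a gym-style ID.
register(
    id="SimpleCorridor-v0",
    entry_point="my_envs.corridor:SimpleCorridor")

# The new ID can now be passed to train.py via --env or used directly.
env = gym.make("SimpleCorridor-v0")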

Getting Started
---------------

You can run training with

::

python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --config '{"timesteps_per_batch": 10000}'

By default, the results will be logged to a subdirectory of ``/tmp/ray``.
This subdirectory will contain a file ``config.json`` with the
hyperparameters, a file ``result.json`` with a training summary
for each episode, and a TensorBoard file that can be used to visualize
the training process by running

::

tensorboard --logdir=/tmp/ray


The ``train.py`` script has a number of options, which you can list by running

::

python ray/python/ray/rllib/train.py --help

The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment, including ones registered by the
user, can be used) and for choosing the algorithm with ``--alg``
(available options are ``PPO``, ``A3C``, ``ES`` and ``DQN``). Each algorithm
has specific hyperparameters that can be set with ``--config``; see the
``DEFAULT_CONFIG`` variable in
`PPO <https://github.com/ray-project/ray/blob/master/python/ray/rllib/ppo/ppo.py>`__,
`A3C <https://github.com/ray-project/ray/blob/master/python/ray/rllib/a3c/a3c.py>`__,
`ES <https://github.com/ray-project/ray/blob/master/python/ray/rllib/es/es.py>`__ and
`DQN <https://github.com/ray-project/ray/blob/master/python/ray/rllib/dqn/dqn.py>`__.
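
For example, an entry of ``DEFAULT_CONFIG`` can be overridden on the
command line (assuming ``num_workers`` is one of A3C's config keys):

::

python ray/python/ray/rllib/train.py --env CartPole-v0 --alg A3C --config '{"num_workers": 8}'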


Examples
--------

Some good hyperparameters and settings are available in
`the repository <https://github.com/ray-project/ray/blob/master/python/ray/rllib/test/tuned_examples.sh>`__
(some of them are tuned to run on GPUs). If you find better settings or tune
an algorithm on a different domain, consider submitting a Pull Request!

The User API
------------

You will be using this part of the API if you run the existing algorithms
on a new problem. Note that the API is not yet considered stable.
Here is an example of how to use it:

::

import ray
import ray.rllib.ppo as ppo

ray.init()

config = ppo.DEFAULT_CONFIG.copy()
alg = ppo.PPOAgent("CartPole-v1", config)

# Can optionally call alg.restore(path) to load a checkpoint.

for i in range(10):
    # Perform one iteration of the algorithm.
    result = alg.train()
    print("result: {}".format(result))
    print("checkpoint saved at path: {}".format(alg.save()))

The Developer API
-----------------

This part of the API will be useful if you need to modify existing RL algorithms
or implement new ones. Note that the API is not yet considered stable.

Agents
~~~~~~

Agents implement a particular algorithm and can be used to run
some number of iterations of the algorithm, save and load the state
of training, and evaluate the current policy. All agents inherit from
a common base class:

.. autoclass:: ray.rllib.common.Agent
:members:
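
To give a feel for this contract, here is a toy stand-in with the same
``train``/``save``/``restore`` surface as the example above (a standalone
sketch only; a real algorithm would subclass ``ray.rllib.common.Agent``):

::

import pickle

import gym

class RandomAgent(object):
    """Toy stand-in illustrating the Agent surface."""

    def __init__(self, env_name):
        self.env = gym.make(env_name)
        self.iteration = 0

    def train(self):
        # One "iteration" here is a single episode under a random policy.
        self.env.reset()
        total_reward, done = 0, False
        while not done:
            _, reward, done, _ = self.env.step(
                self.env.action_space.sample())
            total_reward += reward
        self.iteration += 1
        return {"episode_reward": total_reward,
                "training_iteration": self.iteration}

    def save(self):
        path = "/tmp/random_agent.ckpt"
        with open(path, "wb") as f:
            pickle.dump(self.iteration, f)
        return path

    def restore(self, path):
        with open(path, "rb") as f:
            self.iteration = pickle.load(f)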

Models
~~~~~~

Models are subclasses of the Model class:

.. autoclass:: ray.rllib.models.Model

Currently we support fully connected policies, convolutional policies,
and LSTMs:

.. autofunction:: ray.rllib.models.FullyConnectedNetwork
.. autofunction:: ray.rllib.models.ConvolutionalNetwork
.. autofunction:: ray.rllib.models.LSTM
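
A custom model can be added by subclassing ``Model`` and implementing
``_init`` (a minimal sketch; we assume, as in the built-in networks, that
``_init`` receives the input tensor, the number of outputs and an options
dict, and returns the output tensor together with the last hidden layer):

::

import tensorflow as tf

from ray.rllib.models.model import Model

class TwoLayerNetwork(Model):
    """Hypothetical fully connected model with two hidden layers."""

    def _init(self, inputs, num_outputs, options):
        # Layer sizes are illustrative; options could carry them instead.
        hidden1 = tf.layers.dense(inputs, 64, activation=tf.nn.tanh)
        hidden2 = tf.layers.dense(hidden1, 64, activation=tf.nn.tanh)
        outputs = tf.layers.dense(hidden2, num_outputs)
        return outputs, hidden2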

Action Distributions
~~~~~~~~~~~~~~~~~~~~

Actions can be sampled from different distributions, which share a common
base class:

.. autoclass:: ray.rllib.models.ActionDistribution
:members:

Currently we support the following action distributions:

.. autofunction:: ray.rllib.models.Categorical
.. autofunction:: ray.rllib.models.DiagGaussian
.. autofunction:: ray.rllib.models.Deterministic
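
A new distribution can be added by subclassing ``ActionDistribution``,
whose constructor stores the parameterizing tensor in ``self.inputs`` (a
sketch for a single binary action; the built-in classes also implement
``kl`` and reduce over action dimensions):

::

import tensorflow as tf

from ray.rllib.models import ActionDistribution

class Bernoulli(ActionDistribution):
    """Hypothetical distribution over a binary action,
    parameterized by a logit in self.inputs."""

    def logp(self, x):
        # sigmoid_cross_entropy_with_logits is the negative
        # log-likelihood of x under sigmoid(self.inputs).
        return -tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(x, tf.float32), logits=self.inputs)

    def entropy(self):
        p = tf.sigmoid(self.inputs)
        return -(p * tf.log(p) + (1 - p) * tf.log(1 - p))

    def sample(self):
        p = tf.sigmoid(self.inputs)
        return tf.cast(tf.random_uniform(tf.shape(p)) < p, tf.int32)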

The Model Catalog
~~~~~~~~~~~~~~~~~

To make it easier to pick the right model and action distribution, there
is a mechanism that picks good default values for various gym environments.

.. autoclass:: ray.rllib.models.ModelCatalog
:members:
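
For instance, an algorithm might query the catalog as follows (a rough
sketch; ``get_action_dist``, ``get_model`` and ``model.outputs`` are
assumptions about the API, so check the class reference above for the
actual methods):

::

import gym
import tensorflow as tf

from ray.rllib.models import ModelCatalog

env = gym.make("CartPole-v0")

# Assumed API: returns a distribution class and the number of model
# outputs needed to parameterize it.
dist_class, dist_dim = ModelCatalog.get_action_dist(env.action_space)

observations = tf.placeholder(
    tf.float32, [None] + list(env.observation_space.shape))

# Assumed API: builds a default model producing dist_dim outputs.
model = ModelCatalog.get_model(observations, dist_dim)
dist = dist_class(model.outputs)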

Using RLLib on a cluster
------------------------

First, create a cluster as described in `managing a cluster with parallel ssh`_.
You can then run RLLib on this cluster by passing the address of the main Redis
shard to ``train.py`` with ``--redis-address``.
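
For example (with ``<head-ip>:6379`` standing for the address of the main
Redis shard printed when the cluster is started):

::

python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --redis-address <head-ip>:6379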

.. _`managing a cluster with parallel ssh`: http://ray.readthedocs.io/en/latest/using-ray-on-a-large-cluster.html
11 changes: 10 additions & 1 deletion python/ray/rllib/models/__init__.py
@@ -1,3 +1,12 @@
 from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.models.action_dist import (ActionDistribution, Categorical,
+                                          DiagGaussian, Deterministic)
+from ray.rllib.models.model import Model
+from ray.rllib.models.fcnet import FullyConnectedNetwork
+from ray.rllib.models.convnet import ConvolutionalNetwork
+from ray.rllib.models.lstm import LSTM
 
-__all__ = ["ModelCatalog"]
+
+__all__ = ["ActionDistribution", "Categorical",
+           "DiagGaussian", "Deterministic", "ModelCatalog", "Model",
+           "FullyConnectedNetwork", "ConvolutionalNetwork", "LSTM"]
4 changes: 4 additions & 0 deletions python/ray/rllib/models/action_dist.py
@@ -17,15 +17,19 @@ def __init__(self, inputs):
         self.inputs = inputs
 
     def logp(self, x):
+        """The log-likelihood of the action distribution."""
         raise NotImplementedError
 
     def kl(self, other):
+        """The KL-divergence between two action distributions."""
         raise NotImplementedError
 
     def entropy(self):
+        """The entropy of the action distribution."""
         raise NotImplementedError
 
     def sample(self):
+        """Draw a sample from the action distribution."""
         raise NotImplementedError


6 changes: 3 additions & 3 deletions python/ray/rllib/models/lstm.py
@@ -11,13 +11,13 @@
                                          normc_initializer)
 from ray.rllib.models.model import Model
 
-use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
-                 distutils.version.LooseVersion("1.0.0"))
-
 
 class LSTM(Model):
     # TODO(rliaw): Add LSTM code for other algorithms
     def _init(self, inputs, num_outputs, options):
+        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
+                         distutils.version.LooseVersion("1.0.0"))
+
         self.x = x = inputs
         for i in range(4):
             x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
