Skip to content

Commit

Permalink
[rllib] TD3/DDPG improvements and MuJoCo benchmarks (ray-project#4694)
Browse files Browse the repository at this point in the history
* [rllib] Separate optimisers for DDPG actor & crit.

* [rllib] Better names for DDPG variables & options

Config changes:

- noise_scale -> exploration_ou_noise_scale
- exploration_theta -> exploration_ou_theta
- exploration_sigma -> exploration_ou_sigma
- act_noise -> exploration_gaussian_sigma
- noise_clip -> target_noise_clip

* [rllib] Make DDPG less class-y

Used functions to replace three classes with only an __init__ method & a
handful of unrelated attributes.

* [rllib] Refactor DDPG noise

* [rllib] Unify DDPG exploration annealing

Added option "exploration_should_anneal" to enable linear annealing of
exploration noise. By default this is off, for consistency with DDPG &
TD3 papers. Also renamed "exploration_final_eps" to
"exploration_final_scale" (that name seems to have been carried over
from DQN, and doesn't really make sense here). Finally, tried to rename
"eps" to "noise_scale" wherever possible.
  • Loading branch information
qxcv authored and ericl committed Apr 27, 2019
1 parent 05c896d commit 663e92a
Show file tree
Hide file tree
Showing 16 changed files with 559 additions and 400 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib-algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Deep Deterministic Policy Gradients (DDPG, TD3)
`[paper] <https://arxiv.org/abs/1509.02971>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ddpg/ddpg.py>`__
DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers, switching to AsyncGradientsOptimizer, or using Ape-X. The improvements from `TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`__ are available though not enabled by default.

Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml>`__, `TD3 configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-td3.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml>`__
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml>`__, `TD3 Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pendulum-td3.yaml>`__, `TD3 InvertedPendulum-v2 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/invertedpendulum-td3.yaml>`__, `TD3 Mujoco suite (Ant-v2, HalfCheetah-v2, Hopper-v2, Walker2d-v2) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/mujoco-td3.yaml>`__.

**DDPG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

Expand Down
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/ddpg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ddpg.td3 import TD3Trainer
from ray.rllib.utils import renamed_class

ApexDDPGAgent = renamed_class(ApexDDPGTrainer)
DDPGAgent = renamed_class(DDPGTrainer)

__all__ = [
"DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer",
"DEFAULT_CONFIG"
"TD3Trainer", "DEFAULT_CONFIG"
]
126 changes: 84 additions & 42 deletions python/ray/rllib/agents/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@
DEFAULT_CONFIG = with_common_config({
# === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
# TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
# In addition to settings below, you can use "exploration_noise_type" and
# "exploration_gauss_act_noise" to get IID Gaussian exploration noise
# instead of OU exploration noise.
# twin Q-net
"twin_q": False,
# delayed policy update
"policy_delay": 1,
# target policy smoothing
# this also forces the use of gaussian instead of OU noise for exploration
# (this also replaces OU exploration noise with IID Gaussian exploration
# noise, for now)
"smooth_target_policy": False,
# gaussian stddev of act noise
"act_noise": 0.1,
# gaussian stddev of target noise
# gaussian stddev of target action noise for smoothing
"target_noise": 0.2,
# target noise limit (bound)
"noise_clip": 0.5,
"target_noise_clip": 0.5,

# === Evaluation ===
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
Expand All @@ -37,42 +39,64 @@
"evaluation_num_episodes": 10,

# === Model ===
# Postprocess the policy network model output with these hidden layers
"actor_hiddens": [64, 64],
# Hidden layers activation of the policy network
# Apply a state preprocessor with spec given by the "model" config option
# (like other RL algorithms). This is mostly useful if you have a weird
# observation shape, like an image. Disabled by default.
"use_state_preprocessor": False,
# Postprocess the policy network model output with these hidden layers. If
# use_state_preprocessor is False, then these will be the *only* hidden
# layers in the network.
"actor_hiddens": [400, 300],
# Hidden layers activation of the postprocessing stage of the policy
# network
"actor_hidden_activation": "relu",
# Postprocess the critic network model output with these hidden layers
"critic_hiddens": [64, 64],
# Hidden layers activation of the critic network
# Postprocess the critic network model output with these hidden layers;
# again, if use_state_preprocessor is True, then the state will be
# preprocessed by the model specified with the "model" config option first.
"critic_hiddens": [400, 300],
# Hidden layers activation of the postprocessing state of the critic.
"critic_hidden_activation": "relu",
# N-step Q learning
"n_step": 1,

# === Exploration ===
# Max num timesteps for annealing schedules. Exploration is annealed from
# 1.0 to exploration_fraction over this number of timesteps scaled by
# exploration_fraction
# Turns on annealing schedule for exploration noise. Exploration is
# annealed from 1.0 to exploration_final_eps over schedule_max_timesteps
# scaled by exploration_fraction. Original DDPG and TD3 papers do not
# anneal noise, so this is False by default.
"exploration_should_anneal": False,
# Max num timesteps for annealing schedules.
"schedule_max_timesteps": 100000,
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,
# Fraction of entire training period over which the exploration rate is
# annealed
"exploration_fraction": 0.1,
# Final value of random action probability
"exploration_final_eps": 0.02,
# OU-noise scale
"noise_scale": 0.1,
# theta
"exploration_theta": 0.15,
# sigma
"exploration_sigma": 0.2,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.002,
# Final scaling multiplier for action noise (initial is 1.0)
"exploration_final_scale": 0.02,
# valid values: "ou" (time-correlated, like original DDPG paper),
# "gaussian" (IID, like TD3 paper)
"exploration_noise_type": "ou",
# OU-noise scale; this can be used to scale down magnitude of OU noise
# before adding to actions (requires "exploration_noise_type" to be "ou")
"exploration_ou_noise_scale": 0.1,
# theta for OU
"exploration_ou_theta": 0.15,
# sigma for OU
"exploration_ou_sigma": 0.2,
# gaussian stddev of act noise for exploration (requires
# "exploration_noise_type" to be "gaussian")
"exploration_gaussian_sigma": 0.1,
# If True parameter space noise will be used for exploration
# See https://blog.openai.com/better-exploration-with-parameter-noise/
"parameter_noise": False,
# Until this many timesteps have elapsed, the agent's policy will be
# ignored & it will instead take uniform random actions. Can be used in
# conjunction with learning_starts (which controls when the first
# optimization step happens) to decrease dependence of exploration &
# optimization on initial policy parameters. Note that this will be
# disabled when the action noise scale is set to 0 (e.g during evaluation).
"pure_exploration_steps": 1000,

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
Expand All @@ -90,11 +114,14 @@
"compress_observations": False,

# === Optimization ===
# Learning rate for adam optimizer.
# Instead of using two optimizers, we use two different loss coefficients
"lr": 1e-3,
"actor_loss_coeff": 0.1,
"critic_loss_coeff": 1.0,
# Learning rate for the critic (Q-function) optimizer.
"critic_lr": 1e-3,
# Learning rate for the actor (policy) optimizer.
"actor_lr": 1e-3,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.002,
# If True, use huber loss instead of squared loss for critic network
# Conventionally, no need to clip gradients if using a huber loss
"use_huber": False,
Expand All @@ -117,7 +144,7 @@
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
# you're using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncReplayOptimizer",
Expand All @@ -138,26 +165,41 @@ class DDPGTrainer(DQNTrainer):
_default_config = DEFAULT_CONFIG
_policy_graph = DDPGPolicyGraph

@override(DQNTrainer)
def _train(self):
pure_expl_steps = self.config["pure_exploration_steps"]
if pure_expl_steps:
# tell workers whether they should do pure exploration
only_explore = self.global_timestep < pure_expl_steps
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.set_pure_exploration_phase(only_explore))
for e in self.remote_evaluators:
e.foreach_trainable_policy.remote(
lambda p, _: p.set_pure_exploration_phase(only_explore))
return super(DDPGTrainer, self)._train()

@override(DQNTrainer)
def _make_exploration_schedule(self, worker_index):
# Override DQN's schedule to take into account `noise_scale`
# Override DQN's schedule to take into account
# `exploration_ou_noise_scale`
if self.config["per_worker_exploration"]:
assert self.config["num_workers"] > 1, \
"This requires multiple workers"
if worker_index >= 0:
exponent = (
1 +
worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(
self.config["noise_scale"] * 0.4**exponent)
# FIXME: what do magic constants mean? (0.4, 7)
max_index = float(self.config["num_workers"] - 1)
exponent = 1 + worker_index / max_index * 7
return ConstantSchedule(0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
else:
elif self.config["exploration_should_anneal"]:
return LinearSchedule(
schedule_timesteps=int(self.config["exploration_fraction"] *
self.config["schedule_max_timesteps"]),
initial_p=self.config["noise_scale"] * 1.0,
final_p=self.config["noise_scale"] *
self.config["exploration_final_eps"])
initial_p=1.0,
final_p=self.config["exploration_final_scale"])
else:
# *always* add exploration noise
return ConstantSchedule(1.0)
Loading

0 comments on commit 663e92a

Please sign in to comment.