Exploration with Parameter Space Noise (ray-project#4048)
* enable parameter space noise for exploration

* enable parameter space noise for exploration

* yapf formatted

* remove the usage of scipy softmax, which is available only in the latest version

* enable subclasses that have no parameter_noise in the config

* run user-specified callbacks and test parameter space noise in a multi-node setting

* formatted by yapf

* Update dqn.py

* lint
joneswong authored and ericl committed Feb 21, 2019
1 parent bcd5af7 commit 3ac8fd7
Showing 6 changed files with 234 additions and 12 deletions.
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/ddpg/ddpg.py
@@ -70,6 +70,9 @@
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.002,
# If True parameter space noise will be used for exploration
# See https://blog.openai.com/better-exploration-with-parameter-noise/
"parameter_noise": False,

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
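For context, a minimal usage sketch of the new flag from user code. This is an editorial illustration, not part of the commit; the DDPGAgent import path and the Pendulum-v0 environment name are assumptions based on the RLlib API around this release.

import ray
from ray.rllib.agents.ddpg import DDPGAgent  # import path assumed for this release

ray.init()
agent = DDPGAgent(
    env="Pendulum-v0",  # any continuous-control Gym env
    config={
        # New option from this commit: perturb the policy weights for
        # exploration instead of only adding noise in action space.
        "parameter_noise": True,
        # Sigma adaptation happens in postprocess_trajectory over sampled
        # episodes; complete-episode batching keeps that comparison clean.
        "batch_mode": "complete_episodes",
    })
print(agent.train())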
94 changes: 87 additions & 7 deletions python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -29,13 +29,20 @@ class PNetwork(object):
"""Maps an observations (i.e., state) to an action where each entry takes
value from (0, 1) due to the sigmoid function."""

def __init__(self, model, dim_actions, hiddens=[64, 64],
activation="relu"):
def __init__(self,
model,
dim_actions,
hiddens=[64, 64],
activation="relu",
parameter_noise=False):
action_out = model.last_layer
activation = tf.nn.__dict__[activation]
for hidden in hiddens:
action_out = layers.fully_connected(
action_out, num_outputs=hidden, activation_fn=activation)
action_out,
num_outputs=hidden,
activation_fn=activation,
normalizer_fn=layers.layer_norm if parameter_noise else None)
# Use sigmoid layer to bound values within (0, 1)
# shape of action_scores is [batch_size, dim_actions]
self.action_scores = layers.fully_connected(
@@ -60,7 +67,8 @@ def __init__(self,
act_noise=0.1,
is_target=False,
target_noise=0.2,
noise_clip=0.5):
noise_clip=0.5,
parameter_noise=False):

# shape is [None, dim_action]
deterministic_actions = (
@@ -97,8 +105,9 @@ def __init__(self,
eps * (high_action - low_action) * exploration_value,
low_action, high_action)

self.actions = tf.cond(stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
self.actions = tf.cond(
tf.logical_and(stochastic, not parameter_noise),
lambda: stochastic_actions, lambda: deterministic_actions)


class QNetwork(object):
@@ -210,6 +219,12 @@ def __init__(self, observation_space, action_space, config):
self.cur_observations, observation_space)
self.p_func_vars = _scope_vars(scope.name)

# Noise vars for P network except for layer normalization vars
if self.config["parameter_noise"]:
self._build_parameter_noise([
var for var in self.p_func_vars if "LayerNorm" not in var.name
])

# Action outputs
with tf.variable_scope(A_SCOPE):
self.output_actions = self._build_action_network(
@@ -429,6 +444,29 @@ def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
if self.config["parameter_noise"]:
# adjust the sigma of parameter space noise
states, noisy_actions = [
list(x) for x in sample_batch.columns(["obs", "actions"])
]
self.sess.run(self.remove_noise_op)
clean_actions = self.sess.run(
self.output_actions,
feed_dict={
self.cur_observations: states,
self.stochastic: False,
self.eps: .0
})
distance_in_action_space = np.sqrt(
np.mean(np.square(clean_actions - noisy_actions)))
self.pi_distance = distance_in_action_space
if distance_in_action_space < self.config["exploration_sigma"]:
self.parameter_noise_sigma_val *= 1.01
else:
self.parameter_noise_sigma_val /= 1.01
self.parameter_noise_sigma.load(
self.parameter_noise_sigma_val, session=self.sess)

return _postprocess_dqn(self, sample_batch)

@override(TFPolicyGraph)
@@ -465,7 +503,8 @@ def _build_p_network(self, obs, obs_space):
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), self.dim_actions,
self.config["actor_hiddens"],
self.config["actor_hidden_activation"])
self.config["actor_hidden_activation"],
self.config["parameter_noise"])
return policy_net.action_scores, policy_net.model

def _build_action_network(self, p_values, stochastic, eps,
@@ -491,6 +530,43 @@ def _build_actor_critic_loss(self,
self.config["use_huber"], self.config["huber_threshold"],
self.config["twin_q"])

def _build_parameter_noise(self, pnet_params):
self.parameter_noise_sigma_val = self.config["exploration_sigma"]
self.parameter_noise_sigma = tf.get_variable(
initializer=tf.constant_initializer(
self.parameter_noise_sigma_val),
name="parameter_noise_sigma",
shape=(),
trainable=False,
dtype=tf.float32)
self.parameter_noise = list()
# No need to add any noise on LayerNorm parameters
for var in pnet_params:
noise_var = tf.get_variable(
name=var.name.split(':')[0] + "_noise",
shape=var.shape,
initializer=tf.constant_initializer(.0),
trainable=False)
self.parameter_noise.append(noise_var)
remove_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
remove_noise_ops.append(tf.assign_add(var, -var_noise))
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
generate_noise_ops = list()
for var_noise in self.parameter_noise:
generate_noise_ops.append(
tf.assign(
var_noise,
tf.random_normal(
shape=var_noise.shape,
stddev=self.parameter_noise_sigma)))
with tf.control_dependencies(generate_noise_ops):
add_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
add_noise_ops.append(tf.assign_add(var, var_noise))
self.add_noise_op = tf.group(*tuple(add_noise_ops))
self.pi_distance = None

def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err = self.sess.run(
@@ -508,6 +584,10 @@ def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
def reset_noise(self, sess):
sess.run(self.reset_noise_op)

def add_parameter_noise(self):
if self.config["parameter_noise"]:
self.sess.run(self.add_noise_op)

# support both hard and soft sync
def update_target(self, tau=None):
return self.sess.run(
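The postprocess_trajectory change above implements the adaptive-sigma rule of the parameter-noise approach: measure how far the perturbed policy drifted from the clean policy in action space, then grow sigma when the drift is below the target scale and shrink it otherwise. A standalone sketch of that rule, as an editorial illustration rather than code from the commit:

import numpy as np

def adapt_ddpg_sigma(sigma, clean_actions, noisy_actions, target_sigma,
                     factor=1.01):
    # Root-mean-square gap between the actions of the unperturbed and the
    # perturbed policy on the same batch of observations.
    distance = np.sqrt(np.mean(np.square(clean_actions - noisy_actions)))
    # Too little drift -> larger perturbations; too much drift -> smaller ones.
    return sigma * factor if distance < target_sigma else sigma / factor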
52 changes: 52 additions & 0 deletions python/ray/rllib/agents/dqn/dqn.py
@@ -5,6 +5,7 @@
import logging
import time

from ray import tune
from ray.rllib import optimizers
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
@@ -73,6 +74,9 @@
# Softmax temperature. Q values are divided by this value prior to softmax.
# Softmax approaches argmax as the temperature drops to zero.
"softmax_temp": 1.0,
# If True parameter space noise will be used for exploration
# See https://blog.openai.com/better-exploration-with-parameter-noise/
"parameter_noise": False,

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
@@ -139,6 +143,8 @@ class DQNAgent(Agent):

@override(Agent)
def _init(self):
self._validate_config()

# Update effective batch size to include n-step
adjusted_batch_size = max(self.config["sample_batch_size"],
self.config.get("n_step", 1))
@@ -160,6 +166,41 @@ def _init(self):
if k not in self.config["optimizer"]:
self.config["optimizer"][k] = self.config[k]

if self.config.get("parameter_noise", False):
if self.config["callbacks"]["on_episode_start"]:
start_callback = self.config["callbacks"]["on_episode_start"]
else:
start_callback = None

def on_episode_start(info):
# as a callback function to sample and pose parameter space
# noise on the parameters of network
policies = info["policy"]
for pi in policies.values():
pi.add_parameter_noise()
if start_callback:
start_callback(info)

self.config["callbacks"]["on_episode_start"] = tune.function(
on_episode_start)
if self.config["callbacks"]["on_episode_end"]:
end_callback = self.config["callbacks"]["on_episode_end"]
else:
end_callback = None

def on_episode_end(info):
# as a callback function to monitor the distance
# between noisy policy and original policy
policies = info["policy"]
episode = info["episode"]
episode.custom_metrics["policy_distance"] = policies[
"default"].pi_distance
if end_callback:
end_callback(info)

self.config["callbacks"]["on_episode_end"] = tune.function(
on_episode_end)

self.local_evaluator = self.make_local_evaluator(
self.env_creator, self._policy_graph)

@@ -296,3 +337,14 @@ def __setstate__(self, state):
Agent.__setstate__(self, state)
self.num_target_updates = state["num_target_updates"]
self.last_target_update_ts = state["last_target_update_ts"]

def _validate_config(self):
if self.config.get("parameter_noise", False):
if self.config["batch_mode"] != "complete_episodes":
raise ValueError(
"Exploration with parameter space noise requires "
"batch_mode to be complete_episodes.")
if self.config.get("noisy", False):
raise ValueError(
"Exploration with parameter space noise and noisy network "
"cannot be used at the same time.")
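
Because _init wraps any user-supplied episode callbacks, custom callbacks keep working alongside the noise injection, while _validate_config rejects configurations that truncate episodes or also enable noisy nets. A hedged usage sketch; the DQNAgent import path, environment name, and callback info keys are assumptions based on the RLlib API around this release:

import ray
from ray import tune
from ray.rllib.agents.dqn import DQNAgent  # import path assumed for this release

def my_episode_start(info):
    # Still runs, after the wrapper has already perturbed the policy weights.
    print("episode {} started".format(info["episode"].episode_id))

ray.init()
agent = DQNAgent(
    env="CartPole-v0",
    config={
        "parameter_noise": True,
        # Required by _validate_config: sigma adaptation needs full episodes.
        "batch_mode": "complete_episodes",
        # Must stay off: noisy nets and parameter noise cannot be combined.
        "noisy": False,
        "callbacks": {
            "on_episode_start": tune.function(my_episode_start),
        },
    })
print(agent.train())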
85 changes: 81 additions & 4 deletions python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -4,6 +4,7 @@

from gym.spaces import Discrete
import numpy as np
from scipy.stats import entropy
import tensorflow as tf
import tensorflow.contrib.layers as layers

@@ -28,7 +29,8 @@ def __init__(self,
num_atoms=1,
v_min=-10.0,
v_max=10.0,
sigma0=0.5):
sigma0=0.5,
parameter_noise=False):
self.model = model
with tf.variable_scope("action_value"):
if hiddens:
@@ -41,7 +43,9 @@
action_out = layers.fully_connected(
action_out,
num_outputs=hiddens[i],
activation_fn=tf.nn.relu)
activation_fn=tf.nn.relu,
normalizer_fn=layers.layer_norm
if parameter_noise else None)
else:
# Avoid postprocessing the outputs. This enables custom models
# to be used for parametric action DQN.
@@ -89,7 +93,9 @@ def __init__(self,
state_out = layers.fully_connected(
state_out,
num_outputs=hiddens[i],
activation_fn=tf.nn.relu)
activation_fn=tf.nn.relu,
normalizer_fn=layers.layer_norm
if parameter_noise else None)
if use_noisy:
state_score = self.noisy_layer(
"dueling_output",
@@ -310,6 +316,13 @@ def __init__(self, observation_space, action_space, config):
self.q_values = q_values
self.q_func_vars = _scope_vars(scope.name)

# Noise vars for Q network except for layer normalization vars
if self.config["parameter_noise"]:
self._build_parameter_noise([
var for var in self.q_func_vars if "LayerNorm" not in var.name
])
self.action_probs = tf.nn.softmax(self.q_values)

# Action outputs
self.output_actions, self.action_prob = self._build_q_value_policy(
q_values)
@@ -448,6 +461,28 @@ def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
if self.config["parameter_noise"]:
# adjust the sigma of parameter space noise
states = [list(x) for x in sample_batch.columns(["obs"])][0]

noisy_action_distribution = self.sess.run(
self.action_probs, feed_dict={self.cur_observations: states})
self.sess.run(self.remove_noise_op)
clean_action_distribution = self.sess.run(
self.action_probs, feed_dict={self.cur_observations: states})
distance_in_action_space = np.mean(
entropy(clean_action_distribution.T,
noisy_action_distribution.T))
self.pi_distance = distance_in_action_space
if (distance_in_action_space <
-np.log(1 - self.cur_epsilon +
self.cur_epsilon / self.num_actions)):
self.parameter_noise_sigma_val *= 1.01
else:
self.parameter_noise_sigma_val /= 1.01
self.parameter_noise_sigma.load(
self.parameter_noise_sigma_val, session=self.sess)

return _postprocess_dqn(self, sample_batch)

@override(PolicyGraph)
@@ -459,6 +494,43 @@ def set_state(self, state):
TFPolicyGraph.set_state(self, state[0])
self.set_epsilon(state[1])

def _build_parameter_noise(self, pnet_params):
self.parameter_noise_sigma_val = 1.0
self.parameter_noise_sigma = tf.get_variable(
initializer=tf.constant_initializer(
self.parameter_noise_sigma_val),
name="parameter_noise_sigma",
shape=(),
trainable=False,
dtype=tf.float32)
self.parameter_noise = list()
# No need to add any noise on LayerNorm parameters
for var in pnet_params:
noise_var = tf.get_variable(
name=var.name.split(':')[0] + "_noise",
shape=var.shape,
initializer=tf.constant_initializer(.0),
trainable=False)
self.parameter_noise.append(noise_var)
remove_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
remove_noise_ops.append(tf.assign_add(var, -var_noise))
self.remove_noise_op = tf.group(*tuple(remove_noise_ops))
generate_noise_ops = list()
for var_noise in self.parameter_noise:
generate_noise_ops.append(
tf.assign(
var_noise,
tf.random_normal(
shape=var_noise.shape,
stddev=self.parameter_noise_sigma)))
with tf.control_dependencies(generate_noise_ops):
add_noise_ops = list()
for var, var_noise in zip(pnet_params, self.parameter_noise):
add_noise_ops.append(tf.assign_add(var, var_noise))
self.add_noise_op = tf.group(*tuple(add_noise_ops))
self.pi_distance = None

def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
td_err = self.sess.run(
@@ -473,6 +545,10 @@ def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
})
return td_err

def add_parameter_noise(self):
if self.config["parameter_noise"]:
self.sess.run(self.add_noise_op)

def update_target(self):
return self.sess.run(self.update_target_expr)

@@ -487,7 +563,8 @@ def _build_q_network(self, obs, space):
}, space, self.num_actions, self.config["model"]),
self.num_actions, self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"], self.config["sigma0"])
self.config["v_min"], self.config["v_max"], self.config["sigma0"],
self.config["parameter_noise"])
return qnet.value, qnet.logits, qnet.dist, qnet.model

def _build_q_value_policy(self, q_values):
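For discrete actions the policy distance is the mean KL divergence between the softmaxed Q-values of the clean and the noisy network, and the threshold -log(1 - eps + eps/num_actions) equals the KL divergence between a greedy policy and its epsilon-greedy counterpart, so the perturbation is tuned to explore about as much as epsilon-greedy would. A standalone sketch of that rule, as an editorial illustration rather than code from the commit:

import numpy as np
from scipy.stats import entropy

def adapt_dqn_sigma(sigma, clean_probs, noisy_probs, cur_epsilon, num_actions,
                    factor=1.01):
    # clean_probs / noisy_probs: [batch, num_actions] softmax distributions.
    # entropy(p, q) along axis 0 gives the per-sample KL divergence KL(p || q).
    distance = np.mean(entropy(clean_probs.T, noisy_probs.T))
    # KL distance an epsilon-greedy policy would show against the greedy one.
    threshold = -np.log(1 - cur_epsilon + cur_epsilon / num_actions)
    return sigma * factor if distance < threshold else sigma / factor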
(Diffs for the remaining two changed files are not shown here.)