[rllib] Add custom value functions, fix up and document multi-agent variable sharing (ray-project#3151)
ericl authored Oct 30, 2018
1 parent e49839c commit a221f55
Showing 18 changed files with 199 additions and 46 deletions.
42 changes: 42 additions & 0 deletions doc/source/rllib-env.rst
@@ -136,6 +136,48 @@ Here is a simple `example training script <https://github.com/ray-project/ray/bl

To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``.
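
For instance, here is a minimal sketch of a training run that turns on this auto-vectorization (the environment name is assumed to be registered elsewhere, and the ``multiagent`` policy configuration is omitted for brevity):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run_experiments({
        "multiagent_pg_vectorized": {
            "run": "PG",
            "env": "my_multiagent_env",  # assumed to be registered elsewhere
            "config": {
                "num_envs_per_worker": 4,  # batch 4 env copies per worker
                # "multiagent": {...},  # policy graphs / mapping fn would go here
            },
        },
    })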

Variable-Sharing Between Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RLlib will create each policy's model in a separate ``tf.variable_scope``. However, variables can still be shared between policies by explicitly entering a globally shared variable scope with ``tf.VariableScope(reuse=tf.AUTO_REUSE)``:

.. code-block:: python

    with tf.variable_scope(
            tf.VariableScope(tf.AUTO_REUSE, "name_of_global_shared_scope"),
            reuse=tf.AUTO_REUSE,
            auxiliary_name_scope=False):
        <create the shared layers here>

There is a full example of this in the `example training script <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_cartpole.py>`__.
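
As a minimal sketch of this pattern (the helper name and layer sizes below are illustrative assumptions, not part of RLlib), the shared block can be wrapped in a function that any number of policy models call; because the variables live in a single ``AUTO_REUSE`` scope, every caller reuses the same weights:

.. code-block:: python

    import tensorflow as tf

    def shared_layers(obs):
        # Every policy model that calls this enters the same global scope,
        # so the dense layers below are created once and then reused.
        with tf.variable_scope(
                tf.VariableScope(tf.AUTO_REUSE, "global_shared"),
                reuse=tf.AUTO_REUSE,
                auxiliary_name_scope=False):
            hidden = tf.layers.dense(
                obs, 64, activation=tf.nn.tanh, name="shared_fc1")
            return tf.layers.dense(
                hidden, 64, activation=tf.nn.tanh, name="shared_fc2")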

Implementing a Centralized Critic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Implementing a shared critic between multiple policies requires the definition of custom policy graphs. It can be done as follows:

1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. This assumes you use variable sharing to access the critic network from multiple policies. The critic predictions can then be added to the postprocessed trajectory. Here's an example:

.. code-block:: python

    def postprocess_trajectory(self, sample_batch, other_agent_batches, episode):
        agents = ["agent_1", "agent_2", "agent_3"]  # simple example of 3 agents
        global_obs_batch = np.stack(
            [other_agent_batches[agent_id][1]["obs"] for agent_id in agents],
            axis=1)
        # add the global obs and global critic value
        sample_batch["global_obs"] = global_obs_batch
        sample_batch["global_vf"] = self.sess.run(
            self.global_critic_network, feed_dict={"obs": global_obs_batch})
        # metrics like "global reward" can be retrieved from the info return
        # of the environment
        sample_batch["global_rewards"] = [
            info["global_reward"] for info in sample_batch["infos"]]
        return sample_batch

2. Updating the critic: the centralized critic loss can be added to the loss of some arbitrary policy graph. The policy graph that is chosen must add the inputs for the critic loss to its postprocessed trajectory batches.

For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
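
As a hedged sketch of what step 2 could look like (the placeholder shapes, the toy critic, and the simplified regression target below are illustrative assumptions, not RLlib API), the columns written in ``postprocess_trajectory`` above become ``(column name, placeholder)`` pairs among the chosen policy graph's loss inputs:

.. code-block:: python

    import tensorflow as tf

    num_agents, obs_dim = 3, 4  # assumed sizes for illustration

    # Placeholders matching the extra columns added in postprocess_trajectory()
    global_obs = tf.placeholder(
        tf.float32, [None, num_agents, obs_dim], name="global_obs")
    global_rewards = tf.placeholder(tf.float32, [None], name="global_rewards")

    # A toy centralized value head standing in for the shared critic network;
    # regressing it toward the global reward is a simplification of a proper
    # return-based target.
    central_vf = tf.reshape(
        tf.layers.dense(tf.layers.flatten(global_obs), 1), [-1])
    critic_loss = tf.reduce_mean(tf.square(central_vf - global_rewards))

    # These pairs would be appended to the policy graph's existing loss inputs
    # so that the postprocessed columns are fed in during loss computation.
    extra_loss_inputs = [
        ("global_obs", global_obs),
        ("global_rewards", global_rewards),
    ]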

Agent-Driven
------------

14 changes: 13 additions & 1 deletion doc/source/rllib-models.rst
@@ -30,7 +30,7 @@ The following is a list of the built-in model hyperparameters:
Custom Models
-------------

Custom models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. The model can then be registered and used in place of a built-in model:
Custom models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. The model can then be registered and used in place of a built-in model:

.. code-block:: python
@@ -74,6 +74,18 @@ Custom models should subclass the common RLlib `model class <https://github.com/
        ...
        return layerN, layerN_minus_1

    def value_function(self):
        """Builds the value function output.

        This method can be overridden to customize the implementation of the
        value function (e.g., not sharing hidden layers).

        Returns:
            Tensor of size [BATCH_SIZE] for the value function.
        """
        return tf.reshape(
            linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])

ModelCatalog.register_custom_model("my_model", MyModelClass)
ray.init()
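
A hedged sketch of the kind of override this enables, with a value branch that does not share hidden layers with the policy (the class name, layer sizes, and flat-observation assumption are illustrative, not taken from the commit):

.. code-block:: python

    import tensorflow as tf

    from ray.rllib.models import Model, ModelCatalog
    from ray.rllib.models.misc import linear, normc_initializer

    class MyModelWithSeparateVF(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # Stash the (assumed flat) observation so the value branch can
            # build its own layers from it later.
            self._obs = input_dict["obs"]
            hidden = tf.layers.dense(
                self._obs, 64, activation=tf.nn.tanh, name="policy_fc1")
            output = tf.layers.dense(hidden, num_outputs, name="policy_out")
            return output, hidden

        def value_function(self):
            # Build a separate hidden layer for the value branch: no weight
            # sharing with the policy layers created above.
            vf_hidden = tf.layers.dense(
                self._obs, 64, activation=tf.nn.tanh, name="vf_fc1")
            return tf.reshape(
                linear(vf_hidden, 1, "value", normc_initializer(1.0)), [-1])

    ModelCatalog.register_custom_model("my_separate_vf_model", MyModelWithSeparateVF)
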
10 changes: 5 additions & 5 deletions python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@ -13,7 +13,6 @@
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.catalog import ModelCatalog


@@ -57,9 +56,7 @@ def __init__(self, observation_space, action_space, config):
"prev_rewards": prev_rewards
}, observation_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
self.vf = tf.reshape(
linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
[-1])
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

@@ -144,7 +141,10 @@ def extra_compute_grad_fetches(self):
def get_initial_state(self):
return self.model.state_init

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
5 changes: 4 additions & 1 deletion python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py
@@ -62,7 +62,10 @@ def extra_action_out(self, model_out):
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
5 changes: 4 additions & 1 deletion python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -332,7 +332,10 @@ def extra_compute_grad_fetches(self):
"td_error": self.loss.td_error,
}

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return _postprocess_dqn(self, sample_batch)

def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
5 changes: 4 additions & 1 deletion python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -414,7 +414,10 @@ def extra_compute_grad_fetches(self):
"td_error": self.loss.td_error,
}

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return _postprocess_dqn(self, sample_batch)

def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
10 changes: 5 additions & 5 deletions python/ray/rllib/agents/impala/vtrace_policy_graph.py
@@ -14,7 +14,6 @@
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance

@@ -140,9 +139,7 @@ def __init__(self,
state_in=existing_state_in,
seq_lens=existing_seq_lens)
action_dist = dist_class(self.model.outputs)
values = tf.reshape(
linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
[-1])
values = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

@@ -251,7 +248,10 @@ def extra_compute_action_fetches(self):
def extra_compute_grad_fetches(self):
return self.stats_fetches

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
del sample_batch.data["new_obs"] # not used, so save some bandwidth
return sample_batch

24 changes: 18 additions & 6 deletions python/ray/rllib/agents/pg/pg_policy_graph.py
@@ -11,21 +11,27 @@


class PGLoss(object):
"""Simple policy gradient loss."""

def __init__(self, action_dist, actions, advantages):
self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages)


class PGPolicyGraph(TFPolicyGraph):
"""Simple policy gradient example of defining a policy graph."""

def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config)
self.config = config

# Setup policy
# Setup placeholders
obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")

# Create the model network and action outputs
self.model = ModelCatalog.get_model({
"obs": obs,
"prev_actions": prev_actions,
@@ -38,17 +44,19 @@ def __init__(self, obs_space, action_space, config):
advantages = tf.placeholder(tf.float32, [None], name="adv")
loss = PGLoss(action_dist, actions, advantages).loss

# Initialize TFPolicyGraph
sess = tf.get_default_session()
# Mapping from sample batch keys to placeholders
# Mapping from sample batch keys to placeholders. These keys will be
# read from postprocessed sample batches and fed into the specified
# placeholders during loss computation.
loss_in = [
("obs", obs),
("actions", actions),
("prev_actions", prev_actions),
("prev_rewards", prev_rewards),
("advantages", advantages),
("advantages", advantages), # added during postprocessing
]

# Initialize TFPolicyGraph
sess = tf.get_default_session()
TFPolicyGraph.__init__(
self,
obs_space,
@@ -66,7 +74,11 @@ def __init__(self, obs_space, action_space, config):
max_seq_len=config["model"]["max_seq_len"])
sess.run(tf.global_variables_initializer())

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
# This adds the "advantages" column to the sample batch
return compute_advantages(
sample_batch, 0.0, self.config["gamma"], use_gae=False)

10 changes: 5 additions & 5 deletions python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -9,7 +9,6 @@
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.explained_variance import explained_variance


@@ -180,9 +179,7 @@ def __init__(self,
self.sampler = curr_action_dist.sample()
if self.config["use_gae"]:
if self.config["vf_share_layers"]:
self.value_function = tf.reshape(
linear(self.model.last_layer, 1, "value",
normc_initializer(1.0)), [-1])
self.value_function = self.model.value_function()
else:
vf_config = self.config["model"].copy()
# Do not split the last layer of the value function into
@@ -286,7 +283,10 @@ def value(self, ob, *args):
vf = self.sess.run(self.value_function, feed_dict)
return vf[0]

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
8 changes: 4 additions & 4 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -17,10 +17,11 @@
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder


@@ -299,8 +300,7 @@ def _build_policy_map(self, policy_dict, policy_config):
policy_map = {}
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
merged_conf = policy_config.copy()
merged_conf.update(conf)
merged_conf = merge_dicts(policy_config, conf)
with tf.variable_scope(name):
if isinstance(obs_space, gym.spaces.Dict):
raise ValueError(
10 changes: 8 additions & 2 deletions python/ray/rllib/evaluation/policy_graph.py
@@ -83,7 +83,7 @@ def compute_single_action(self,
is_training (bool): whether we are training the policy
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multiagent algorithms.
multi-agent algorithms.
Returns:
actions (obj): single action
@@ -96,7 +96,10 @@ def compute_single_action(self,
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}

def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy
@@ -108,6 +111,9 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
other_agent_batches (dict): In a multi-agent env, this contains a
mapping of agent ids to (policy_graph, agent_batch) tuples
containing the policy graph and experiences of the other agent.
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
Returns:
SampleBatch: postprocessed sample batch.
14 changes: 10 additions & 4 deletions python/ray/rllib/evaluation/sample_batch.py
@@ -99,11 +99,14 @@ def add_values(self, agent_id, policy_id, **values):
builder = self.agent_builders[agent_id]
builder.add_values(**values)

def postprocess_batch_so_far(self):
def postprocess_batch_so_far(self, episode):
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Arguments:
episode: current MultiAgentEpisode object or None
"""

# Materialize the batches so far
@@ -128,7 +131,7 @@ def postprocess_batch_so_far(self):
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches)
pre_batch, other_batches, episode)

# Append into policy batches and reset
for agent_id, post_batch in sorted(post_batches.items()):
@@ -137,14 +140,17 @@ def postprocess_batch_so_far(self):
self.agent_builders.clear()
self.agent_to_policy.clear()

def build_and_reset(self):
def build_and_reset(self, episode):
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Arguments:
episode: current MultiAgentEpisode object or None
"""

self.postprocess_batch_so_far()
self.postprocess_batch_so_far(episode)
policy_batches = {}
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
4 changes: 2 additions & 2 deletions python/ray/rllib/evaluation/sampler.py
@@ -317,10 +317,10 @@ def new_episode():
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
yield episode.batch_builder.build_and_reset()
yield episode.batch_builder.build_and_reset(episode)
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far()
episode.batch_builder.postprocess_batch_so_far(episode)

if all_done:
# Handle episode termination
4 changes: 3 additions & 1 deletion python/ray/rllib/evaluation/tf_policy_graph.py
@@ -64,7 +64,9 @@ def __init__(self,
loss_inputs (list): a (name, placeholder) tuple for each loss
input argument. Each placeholder name must correspond to a
SampleBatch column key returned by postprocess_trajectory(),
and has shape [BATCH_SIZE, data...].
and has shape [BATCH_SIZE, data...]. These keys will be read
from postprocessed sample batches and fed into the specified
placeholders during loss computation.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
prev_action_input (Tensor): placeholder for previous actions