[rllib] PPO and A3C unification (ray-project#1253)
richardliaw authored Dec 14, 2017
1 parent 2f750e9 commit c5c83a4
Showing 19 changed files with 290 additions and 349 deletions.
14 changes: 9 additions & 5 deletions python/ray/rllib/a3c/a3c.py
@@ -8,8 +8,8 @@

import ray
from ray.rllib.agent import Agent
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteRunner
from ray.rllib.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteA3CEvaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.tune.result import TrainingResult
@@ -49,7 +49,8 @@ def _init(self):
self.env.observation_space.shape)
self.rew_filter = get_filter(self.config["reward_filter"], ())
self.agents = [
RemoteRunner.remote(self.env_creator, self.config, self.logdir)
RemoteA3CEvaluator.remote(
self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.parameters = self.policy.get_weights()

@@ -105,6 +106,7 @@ def _fetch_metrics_from_workers(self):
return result

def _save(self):
# TODO(rliaw): extend to also support saving worker state?
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [self.parameters, self.obs_filter, self.rew_filter]
@@ -118,6 +120,8 @@ def _restore(self, checkpoint_path):
self.rew_filter = objects[2]
self.policy.set_weights(self.parameters)

# TODO(rliaw): augment to support LSTM
def compute_action(self, observation):
actions = self.policy.compute_action(observation)
return actions[0]
obs = self.obs_filter(observation, update=False)
action, info = self.policy.compute(obs)
return action
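
For orientation, a minimal sketch (not code from this commit; DummyPolicy and the values are made up) of the compute() contract the new compute_action above relies on: compute() takes a single filtered observation and returns an (action, info) pair, where info carries the extra per-step outputs named in the policy's other_output list, e.g. "vf_preds".

class DummyPolicy(object):
    # Names of the extra per-step outputs returned alongside the action.
    other_output = ["vf_preds"]

    def compute(self, observation):
        # Pretend we sampled an action and predicted a state value.
        action = 0
        return action, {"vf_preds": 0.5}


policy = DummyPolicy()
filtered_obs = [0.1, 0.2, 0.3]  # already passed through the observation filter
action, info = policy.compute(filtered_obs)
print(action, info["vf_preds"])
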
35 changes: 0 additions & 35 deletions python/ray/rllib/a3c/common.py
@@ -2,37 +2,6 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy.signal
from collections import namedtuple


def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]


def process_rollout(rollout, reward_filter, gamma, lambda_=1.0):
"""Given a rollout, compute its returns and the advantage.
TODO(rliaw): generalize this"""
batch_si = np.asarray(rollout.data["state"])
batch_a = np.asarray(rollout.data["action"])
rewards = np.asarray(rollout.data["reward"])
vpred_t = np.asarray(rollout.data["value"] + [rollout.last_r])

rewards_plus_v = np.asarray(rollout.data["reward"] + [rollout.last_r])
batch_r = discount(rewards_plus_v, gamma)[:-1]
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes from "Generalized Advantage Estimation":
# https://arxiv.org/abs/1506.02438
batch_adv = discount(delta_t, gamma * lambda_)
for i in range(batch_adv.shape[0]):
batch_adv[i] = reward_filter(batch_adv[i])

features = rollout.data["features"][0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.is_terminal(),
features)


def get_policy_cls(config):
if config["use_lstm"]:
@@ -45,7 +14,3 @@ def get_policy_cls(config):
from ray.rllib.a3c.shared_model import SharedModel
policy_cls = SharedModel
return policy_cls


Batch = namedtuple(
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
2 changes: 1 addition & 1 deletion python/ray/rllib/a3c/policy.py
@@ -20,7 +20,7 @@ def set_weights(self, weights):
def compute_gradients(self, batch):
raise NotImplementedError

def compute_action(self, observations):
def compute(self, observations):
"""Compute action for a _single_ observation"""
raise NotImplementedError

40 changes: 22 additions & 18 deletions python/ray/rllib/a3c/runner.py
@@ -3,13 +3,15 @@
from __future__ import print_function

import ray
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.common import process_rollout, get_policy_cls
from ray.rllib.envs import create_and_wrap
from ray.rllib.evaluator import Evaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.sampler import AsyncSampler
from ray.rllib.utils.process_rollout import process_rollout


class Runner(object):
class A3CEvaluator(Evaluator):
"""Actor object to start running simulation on workers.
The gradient computation is also executed from this object.
@@ -29,19 +31,16 @@ def __init__(self, env_creator, config, logdir):
obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())

self.sampler = AsyncSampler(env, self.policy, config["batch_size"],
obs_filter)
self.sampler = AsyncSampler(env, self.policy, obs_filter,
config["batch_size"])
self.logdir = logdir

def get_data(self):
def sample(self):
"""
Returns:
trajectory: trajectory information
obs_filter: Current state of observation filter
rew_filter: Current state of reward filter"""
rollout, obs_filter = self.sampler.get_data()
return rollout, obs_filter, self.rew_filter
trajectory (PartialRollout): Experience Samples from evaluator"""
rollout = self.sampler.get_data()
return rollout

def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
@@ -51,14 +50,19 @@ def get_completed_rollout_metrics(self):
return self.sampler.get_metrics()

def compute_gradient(self):
rollout, obsf_snapshot = self.sampler.get_data()
batch = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0)
gradient, info = self.policy.compute_gradients(batch)
info["obs_filter"] = obsf_snapshot
rollout = self.sampler.get_data()
obs_filter = self.sampler.get_obs_filter(flush=True)

traj = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
gradient, info = self.policy.compute_gradients(traj)
info["obs_filter"] = obs_filter
info["rew_filter"] = self.rew_filter
return gradient, info

def apply_gradient(self, grads):
self.policy.apply_gradients(grads)

def set_weights(self, params):
self.policy.set_weights(params)

@@ -70,4 +74,4 @@ def update_filters(self, obs_filter=None, rew_filter=None):
self.sampler.update_obs_filter(obs_filter)


RemoteRunner = ray.remote(Runner)
RemoteA3CEvaluator = ray.remote(A3CEvaluator)
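
The Evaluator base class imported above (ray.rllib.evaluator) is the interface this commit unifies PPO and A3C around. A rough sketch of its shape, inferred from the methods A3CEvaluator implements here; the exact abstract methods and docstrings are assumptions, not the actual module.

class Evaluator(object):
    """Common worker interface shared by the PPO and A3C agents (sketch only)."""

    def sample(self):
        """Returns a batch of experience, e.g. a PartialRollout."""
        raise NotImplementedError

    def compute_gradient(self):
        """Returns (gradient, info) computed on a freshly sampled batch."""
        raise NotImplementedError

    def apply_gradient(self, grads):
        """Applies a gradient to the local policy."""
        raise NotImplementedError

    def set_weights(self, params):
        """Sets the local policy weights."""
        raise NotImplementedError
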
19 changes: 8 additions & 11 deletions python/ray/rllib/a3c/shared_model.py
@@ -10,7 +10,7 @@

class SharedModel(TFPolicy):

other_output = ["value"]
other_output = ["vf_preds"]
is_recurrent = False

def __init__(self, ob_space, ac_space, **kwargs):
@@ -35,13 +35,13 @@ def _setup_graph(self, ob_space, ac_space):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)

def compute_gradients(self, batch):
def compute_gradients(self, trajectory):
info = {}
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
}
self.grads = [g for g in self.grads if g is not None]
self.local_steps += 1
@@ -53,14 +53,11 @@ def compute_gradients(self, batch):
grad = self.sess.run(self.grads, feed_dict=feed_dict)
return grad, info

def compute_action(self, ob, *args):
def compute(self, ob, *args):
action, vf = self.sess.run([self.sample, self.vf],
{self.x: [ob]})
return action[0], {"value": vf[0]}
return action[0], {"vf_preds": vf[0]}

def value(self, ob, *args):
vf = self.sess.run(self.vf, {self.x: [ob]})
return vf[0]

def get_initial_features(self):
return []
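
With the Batch namedtuple gone, compute_gradients now receives a plain dict of arrays. A sketch of its layout, with key names taken from the feed_dict above and shapes invented for illustration:

import numpy as np

trajectory = {
    "observations": np.zeros((32, 4), dtype=np.float32),   # filtered observations
    "actions": np.zeros((32,), dtype=np.int64),            # sampled actions
    "advantages": np.zeros((32,), dtype=np.float32),       # GAE advantages
    "value_targets": np.zeros((32,), dtype=np.float32),    # discounted returns
}
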
21 changes: 11 additions & 10 deletions python/ray/rllib/a3c/shared_model_lstm.py
@@ -18,7 +18,7 @@ class SharedModelLSTM(TFPolicy):
to be tracked).
"""

other_output = ["value", "features"]
other_output = ["vf_preds", "features"]
is_recurrent = True

def __init__(self, ob_space, ac_space, **kwargs):
@@ -48,19 +48,20 @@ def _setup_graph(self, ob_space, ac_space):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)

def compute_gradients(self, batch):
def compute_gradients(self, trajectory):
"""Computing the gradient is actually model-dependent.
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
features = trajectory["features"][0]
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.state_in[0]: batch.features[0],
self.state_in[1]: batch.features[1]
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
self.state_in[0]: features[0],
self.state_in[1]: features[1]
}
info = {}
self.local_steps += 1
@@ -72,11 +73,11 @@ def compute_gradients(self, batch):
grad = self.sess.run(self.grads, feed_dict=feed_dict)
return grad, info

def compute_action(self, ob, c, h):
def compute(self, ob, c, h):
action, vf, c, h = self.sess.run(
[self.sample, self.vf] + self.state_out,
{self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
return action[0], {"value": vf[0], "features": (c, h)}
return action[0], {"vf_preds": vf[0], "features": (c, h)}

def value(self, ob, c, h):
vf = self.sess.run(self.vf, {self.x: [ob],
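
A small illustrative sketch (shapes and the layout are assumptions, not code from this file) of why only trajectory["features"][0] is fed back into state_in: the recurrent policy replays the whole batch from the hidden state it had at the first step and recomputes the later states in its own forward pass.

import numpy as np

cell_size = 256
num_steps = 5

# One (c, h) pair is recorded per timestep while sampling.
features_per_step = [
    (np.zeros((1, cell_size), dtype=np.float32),
     np.zeros((1, cell_size), dtype=np.float32))
    for _ in range(num_steps)
]
trajectory = {"features": features_per_step}

# The gradient pass only needs the state at the start of the batch; the LSTM
# regenerates the subsequent hidden states when it replays the rollout.
c0, h0 = trajectory["features"][0]
print(c0.shape, h0.shape)
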
9 changes: 3 additions & 6 deletions python/ray/rllib/a3c/shared_torch_policy.py
@@ -14,7 +14,7 @@
class SharedTorchPolicy(TorchPolicy):
"""Assumes nonrecurrent."""

other_output = ["value"]
other_output = ["vf_preds"]
is_recurrent = False

def __init__(self, ob_space, ac_space, **kwargs):
@@ -26,14 +26,14 @@ def _setup_graph(self, ob_space, ac_space):
self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)

def compute_action(self, ob, *args):
def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), {"value": var_to_np(values)}
return var_to_np(samples), {"vf_preds": var_to_np(values)}

def compute_logits(self, ob, *args):
with self.lock:
@@ -71,6 +71,3 @@ def _backward(self, batch):
overall_err = 0.5 * value_err + pi_err - entropy * 0.01
overall_err.backward()
torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)

def get_initial_features(self):
return [None]
2 changes: 1 addition & 1 deletion python/ray/rllib/a3c/tfpolicy.py
@@ -92,7 +92,7 @@ def set_weights(self, weights):
def compute_gradients(self, batch):
raise NotImplementedError

def compute_action(self, observations):
def compute(self, observation):
raise NotImplementedError

def value(self, ob):
3 changes: 0 additions & 3 deletions python/ray/rllib/a3c/torchpolicy.py
@@ -73,6 +73,3 @@ def _backward(self, batch):
This function regenerates the backward trace and
calculates the gradient."""
raise NotImplementedError

def get_initial_features(self):
return []
File renamed without changes.
20 changes: 12 additions & 8 deletions python/ray/rllib/models/pytorch/misc.py
@@ -8,19 +8,23 @@
from torch.autograd import Variable


def convert_batch(batch, has_features=False):
"""Convert batch from numpy to PT variable"""
states = Variable(torch.from_numpy(batch.si).float())
acs = Variable(torch.from_numpy(batch.a))
advs = Variable(torch.from_numpy(batch.adv.copy()).float())
def convert_batch(trajectory, has_features=False):
"""Convert trajectory from numpy to PT variable"""
states = Variable(torch.from_numpy(
trajectory["observations"]).float())
acs = Variable(torch.from_numpy(
trajectory["actions"]))
advs = Variable(torch.from_numpy(
trajectory["advantages"].copy()).float())
advs = advs.view(-1, 1)
rs = Variable(torch.from_numpy(batch.r.copy()).float())
rs = Variable(torch.from_numpy(
trajectory["value_targets"]).float())
rs = rs.view(-1, 1)
if has_features:
features = [Variable(torch.from_numpy(f))
for f in batch.features]
for f in trajectory["features"]]
else:
features = batch.features
features = trajectory["features"]
return states, acs, advs, rs, features


17 changes: 11 additions & 6 deletions python/ray/rllib/ppo/loss.py
@@ -10,9 +10,12 @@

class ProximalPolicyLoss(object):

other_output = ["vf_preds", "logprobs"]
is_recurrent = False

def __init__(
self, observation_space, action_space,
observations, returns, advantages, actions,
observations, value_targets, advantages, actions,
prev_logits, prev_vf_preds, logit_dim,
kl_coeff, distribution_class, config, sess):
assert (isinstance(action_space, gym.spaces.Discrete) or
@@ -55,11 +58,11 @@ def __init__(
# We use a huber loss here to be more robust against outliers,
# which seem to occur when the rollouts get longer (the variance
# scales superlinearly with the length of the rollout)
self.vf_loss1 = tf.square(self.value_function - returns)
self.vf_loss1 = tf.square(self.value_function - value_targets)
vf_clipped = prev_vf_preds + tf.clip_by_value(
self.value_function - prev_vf_preds,
-config["clip_param"], config["clip_param"])
self.vf_loss2 = tf.square(vf_clipped - returns)
self.vf_loss2 = tf.square(vf_clipped - value_targets)
self.vf_loss = tf.minimum(self.vf_loss1, self.vf_loss2)
self.mean_vf_loss = tf.reduce_mean(self.vf_loss)
self.loss = tf.reduce_mean(
@@ -82,9 +85,11 @@ def __init__(
self.policy_results = [
self.sampler, self.curr_logits, tf.constant("NA")]

def compute(self, observations):
return self.sess.run(self.policy_results,
feed_dict={self.observations: observations})
def compute(self, observation):
action, logprobs, vf = self.sess.run(
self.policy_results,
feed_dict={self.observations: [observation]})
return action[0], {"vf_preds": vf[0], "logprobs": logprobs[0]}

def loss(self):
return self.loss
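
For reference, an illustrative NumPy sketch (not the TF graph above; the default clip_param value is invented) of the clipped value loss the constructor builds: the squared error is limited relative to the previous value prediction, and the element-wise minimum of the two terms is taken, mirroring tf.minimum above.

import numpy as np


def clipped_value_loss(vf_preds, prev_vf_preds, value_targets, clip_param=0.3):
    # Unclipped squared error against the value targets.
    loss1 = np.square(vf_preds - value_targets)
    # Limit how far the new value prediction may move from the old one.
    vf_clipped = prev_vf_preds + np.clip(
        vf_preds - prev_vf_preds, -clip_param, clip_param)
    loss2 = np.square(vf_clipped - value_targets)
    # Element-wise minimum of the two terms, as in the TF code above.
    return np.minimum(loss1, loss2).mean()


print(clipped_value_loss(
    np.array([1.0, 2.0]), np.array([0.5, 2.5]), np.array([1.5, 1.0])))
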
(Diffs for the remaining changed files are not shown.)
