[rllib] A3C Configurations (ray-project#1370)
* initial introduction of a3c configs

* fix sample batch

* flake but need to check save

* save, restore

* fix

* pickles

* entropy

* fix

* moving ppo

* results

* jenkins
richardliaw authored and ericl committed Dec 24, 2017
1 parent b217a5e commit 4bb5b6b
Showing 15 changed files with 164 additions and 113 deletions.
123 changes: 62 additions & 61 deletions python/ray/rllib/a3c/a3c.py
@@ -8,85 +8,83 @@

import ray
from ray.rllib.agent import Agent
from ray.rllib.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteA3CEvaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.optimizers import AsyncOptimizer
from ray.rllib.a3c.base_evaluator import A3CEvaluator, RemoteA3CEvaluator
from ray.tune.result import TrainingResult


DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
"num_batches_per_iteration": 100,

# Size of rollout batch
"batch_size": 10,
"use_lstm": True,
# Use LSTM model - only applicable for image states
"use_lstm": False,
# Use PyTorch as backend - no LSTM support
"use_pytorch": False,
# Which observation filter to apply to the observation
"observation_filter": "NoFilter",
# Which reward filter to apply to the reward
"reward_filter": "NoFilter",

"model": {"grayscale": True,
"zero_mean": False,
"dim": 42,
"channel_major": False}
# Discount factor of MDP
"gamma": 0.99,
# GAE(gamma) parameter
"lambda": 1.0,
# Max global norm for each gradient calculated by worker
"grad_clip": 40.0,
# Learning rate
"lr": 0.0001,
# Value Function Loss coefficient
"vf_loss_coeff": 0.5,
# Entropy coefficient
"entropy_coeff": -0.01,
# Preprocessing for environment
"preprocessing": {
# (Image statespace) - Converts image to Channels = 1
"grayscale": True,
# (Image statespace) - Zero-mean each pixel instead of scaling to [0, 1]
"zero_mean": False,
# (Image statespace) - Converts image to (dim, dim, C)
"dim": 80,
# (Image statespace) - Converts image shape to (C, dim, dim)
"channel_major": False
},
# Configuration for model specification
"model": {},
# Arguments to pass to the rllib optimizer
"optimizer": {
# Number of gradients applied for each `train` step
"grads_per_step": 100,
},
}


class A3CAgent(Agent):
_agent_name = "A3C"
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["model", "optimizer"]

def _init(self):
self.env = create_and_wrap(self.env_creator, self.config["model"])
policy_cls = get_policy_cls(self.config)
self.policy = policy_cls(
self.env.observation_space.shape, self.env.action_space)
self.obs_filter = get_filter(
self.config["observation_filter"],
self.env.observation_space.shape)
self.rew_filter = get_filter(self.config["reward_filter"], ())
self.agents = [
self.local_evaluator = A3CEvaluator(
self.env_creator, self.config, self.logdir, start_sampler=False)
self.remote_evaluators = [
RemoteA3CEvaluator.remote(
self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.parameters = self.policy.get_weights()
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
self.remote_evaluators)

def _train(self):
remote_params = ray.put(self.parameters)
ray.get([agent.set_weights.remote(remote_params)
for agent in self.agents])

gradient_list = {agent.compute_gradient.remote(): agent
for agent in self.agents}
max_batches = self.config["num_batches_per_iteration"]
batches_so_far = len(gradient_list)
while gradient_list:
[done_id], _ = ray.wait(list(gradient_list))
gradient, info = ray.get(done_id)
agent = gradient_list.pop(done_id)
self.obs_filter.update(info["obs_filter"])
self.rew_filter.update(info["rew_filter"])
self.policy.apply_gradients(gradient)
self.parameters = self.policy.get_weights()

if batches_so_far < max_batches:
batches_so_far += 1
agent.update_filters.remote(
obs_filter=self.obs_filter,
rew_filter=self.rew_filter)
agent.set_weights.remote(self.parameters)
gradient_list[agent.compute_gradient.remote()] = agent
res = self._fetch_metrics_from_workers()
self.optimizer.step()
res = self._fetch_metrics_from_remote_evaluators()
return res

def _fetch_metrics_from_workers(self):
def _fetch_metrics_from_remote_evaluators(self):
episode_rewards = []
episode_lengths = []
metric_lists = [
a.get_completed_rollout_metrics.remote() for a in self.agents]
metric_lists = [a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
@@ -106,22 +104,25 @@ def _fetch_metrics_from_workers(self):
return result

def _save(self):
# TODO(rliaw): extend to also support saving worker state?
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [self.parameters, self.obs_filter, self.rew_filter]
pickle.dump(objects, open(checkpoint_path, "wb"))
# self.saver.save
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path

def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.parameters = objects[0]
self.obs_filter = objects[1]
self.rew_filter = objects[2]
self.policy.set_weights(self.parameters)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get(
[a.restore.remote(o) for a, o in zip(
self.remote_evaluators, extra_data["remote_state"])])
self.local_evaluator.restore(extra_data["local_state"])

# TODO(rliaw): augment to support LSTM
def compute_action(self, observation):
obs = self.obs_filter(observation, update=False)
action, info = self.policy.compute(obs)
obs = self.local_evaluator.obs_filter(observation, update=False)
action, info = self.local_evaluator.policy.compute(obs)
return action
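
The hunks above flatten the A3C hyperparameters into a single DEFAULT_CONFIG (learning rate, gradient clipping, GAE parameters, loss coefficients, plus preprocessing and optimizer sub-dicts). A minimal sketch of overriding the new top-level keys before constructing A3CAgent, assuming the module path shown in this diff; the merge helper and the chosen values are illustrative only, since the Agent base class already merges user configs against _default_config:

# Sketch only: overriding the new flat A3C config keys.
# merged_a3c_config is a hypothetical helper, not rllib API.
from ray.rllib.a3c.a3c import DEFAULT_CONFIG


def merged_a3c_config(**overrides):
    """Return a copy of DEFAULT_CONFIG with top-level keys overridden."""
    config = dict(DEFAULT_CONFIG)
    config.update(overrides)
    return config


config = merged_a3c_config(
    num_workers=2,        # fewer remote evaluators
    lr=5e-4,              # learning rate is now read from the config
    grad_clip=40.0,       # max global norm for each worker gradient
    entropy_coeff=-0.01,  # negative coefficient keeps the entropy bonus
)
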
python/ray/rllib/a3c/base_evaluator.py
@@ -2,6 +2,8 @@
from __future__ import division
from __future__ import print_function

import pickle

import ray
from ray.rllib.envs import create_and_wrap
from ray.rllib.optimizers import Evaluator
@@ -23,24 +25,33 @@ class A3CEvaluator(Evaluator):
rollouts.
logdir: Directory for logging.
"""
def __init__(self, env_creator, config, logdir):
self.env = env = create_and_wrap(env_creator, config["model"])
def __init__(self, env_creator, config, logdir, start_sampler=True):
self.env = env = create_and_wrap(env_creator, config["preprocessing"])
policy_cls = get_policy_cls(config)
# TODO(rliaw): should change this to be just env.observation_space
self.policy = policy_cls(env.observation_space.shape, env.action_space)
obs_filter = get_filter(
self.policy = policy_cls(
env.observation_space.shape, env.action_space, config)
self.config = config

# Technically not needed when not remote
self.obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())
self.sampler = AsyncSampler(env, self.policy, obs_filter,
self.sampler = AsyncSampler(env, self.policy, self.obs_filter,
config["batch_size"])
if start_sampler and self.sampler.async:
self.sampler.start()
self.logdir = logdir

def sample(self):
"""
Returns:
samples: Processed rollout (advantages already computed) from the sampler."""
rollout = self.sampler.get_data()
return rollout
samples = process_rollout(
rollout, self.rew_filter, gamma=self.config["gamma"],
lambda_=self.config["lambda"], use_gae=True)
return samples

def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
@@ -49,20 +60,16 @@ def get_completed_rollout_metrics(self):
"""
return self.sampler.get_metrics()

def compute_gradient(self):
rollout = self.sampler.get_data()
obs_filter = self.sampler.get_obs_filter(flush=True)

traj = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
gradient, info = self.policy.compute_gradients(traj)
info["obs_filter"] = obs_filter
info["rew_filter"] = self.rew_filter
return gradient, info
def compute_gradients(self, samples):
gradient, info = self.policy.compute_gradients(samples)
return gradient

def apply_gradient(self, grads):
def apply_gradients(self, grads):
self.policy.apply_gradients(grads)

def get_weights(self):
return self.policy.get_weights()

def set_weights(self, params):
self.policy.set_weights(params)

@@ -73,5 +80,13 @@ def update_filters(self, obs_filter=None, rew_filter=None):
if obs_filter:
self.sampler.update_obs_filter(obs_filter)

def save(self):
weights = self.get_weights()
return pickle.dumps({"weights": weights})

def restore(self, objs):
objs = pickle.loads(objs)
self.set_weights(objs["weights"])


RemoteA3CEvaluator = ray.remote(A3CEvaluator)
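
With this refactor the evaluator exposes the generic Evaluator interface (sample, compute_gradients, apply_gradients, get_weights, set_weights) that AsyncOptimizer drives, and sample() now returns rollouts already processed with GAE. A simplified, synchronous sketch of that contract; the function name is hypothetical, and Ray remote calls, filter synchronization, and gradient pipelining are deliberately left out:

# Simplified, synchronous sketch of the data flow AsyncOptimizer drives.
# The real optimizer issues these calls to Ray actors and keeps a queue
# of in-flight gradients instead of looping serially.
def naive_optimizer_step(local_evaluator, remote_evaluators, grads_per_step):
    for _ in range(grads_per_step):
        weights = local_evaluator.get_weights()
        for ev in remote_evaluators:
            ev.set_weights(weights)               # broadcast current parameters
            samples = ev.sample()                 # processed rollout (GAE applied)
            grad = ev.compute_gradients(samples)  # gradient computed on the worker
            local_evaluator.apply_gradients(grad)  # applied to the driver's policy
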
7 changes: 4 additions & 3 deletions python/ray/rllib/a3c/shared_model.py
@@ -13,13 +13,14 @@ class SharedModel(TFPolicy):
other_output = ["vf_preds"]
is_recurrent = False

def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModel, self).__init__(ob_space, ac_space, config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ModelCatalog.get_model(self.x, self.logit_dim)
self._model = ModelCatalog.get_model(
self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.curr_dist = dist_class(self.logits)
# with tf.variable_scope("vf"):
7 changes: 4 additions & 3 deletions python/ray/rllib/a3c/shared_model_lstm.py
@@ -13,16 +13,17 @@ class SharedModelLSTM(TFPolicy):
"""
Attributes:
other_output (list): Other than `action`, the other return values from
`compute_gradient`.
`compute_gradients`.
is_recurrent (bool): True if is a recurrent network (requires features
to be tracked).
"""

other_output = ["vf_preds", "features"]
is_recurrent = True

def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModelLSTM, self).__init__(
ob_space, ac_space, config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
17 changes: 11 additions & 6 deletions python/ray/rllib/a3c/shared_torch_policy.py
@@ -17,14 +17,16 @@ class SharedTorchPolicy(TorchPolicy):
other_output = ["vf_preds"]
is_recurrent = False

def __init__(self, ob_space, ac_space, **kwargs):
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedTorchPolicy, self).__init__(
ob_space, ac_space, **kwargs)
ob_space, ac_space, config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
self._model = ModelCatalog.get_torch_model(
ob_space, self.logit_dim, self.config["model"])
self.optimizer = torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])

def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
@@ -68,6 +70,9 @@ def _backward(self, batch):
value_err = 0.5 * (values - rs).pow(2).sum()

self.optimizer.zero_grad()
overall_err = 0.5 * value_err + pi_err - entropy * 0.01
overall_err = (pi_err +
value_err * self.config["vf_loss_coeff"] +
entropy * self.config["entropy_coeff"])
overall_err.backward()
torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
torch.nn.utils.clip_grad_norm(
self._model.parameters(), self.config["grad_clip"])
13 changes: 9 additions & 4 deletions python/ray/rllib/a3c/tfpolicy.py
@@ -10,8 +10,10 @@

class TFPolicy(Policy):
"""The policy base class."""
def __init__(self, ob_space, action_space, name="local", summarize=True):
def __init__(self, ob_space, action_space, config,
name="local", summarize=True):
self.local_steps = 0
self.config = config
self.summarize = summarize
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
self.g = tf.Graph()
@@ -52,13 +54,15 @@ def setup_loss(self, action_space):
delta = self.vf - self.r
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
self.entropy = tf.reduce_sum(self.curr_dist.entropy())
self.loss = self.pi_loss + 0.5 * self.vf_loss - self.entropy * 0.01
self.loss = (self.pi_loss +
self.vf_loss * self.config["vf_loss_coeff"] +
self.entropy * self.config["entropy_coeff"])

def setup_gradients(self):
grads = tf.gradients(self.loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
grads_and_vars = list(zip(self.grads, self.var_list))
opt = tf.train.AdamOptimizer(1e-4)
opt = tf.train.AdamOptimizer(self.config["lr"])
self._apply_gradients = opt.apply_gradients(grads_and_vars)

def initialize(self):
@@ -71,6 +75,7 @@ def initialize(self):
tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
self.summary_op = tf.summary.merge_all()

# TODO(rliaw): Can consider exposing these parameters
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
self.variables = ray.experimental.TensorFlowVariables(self.loss,
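
Both policy updates in this commit (the Torch _backward and the TF setup_loss above) replace hard-coded loss constants with the config coefficients. Written with those names as symbols, the objective each policy builds is

    L(\theta) = L_{\pi}(\theta) + c_{\mathrm{vf}} \, L_{\mathrm{vf}}(\theta) + c_{\mathrm{ent}} \, H[\pi_{\theta}]

where c_vf is vf_loss_coeff (default 0.5) and c_ent is entropy_coeff (default -0.01). Because the default entropy coefficient is negative, the entropy term remains an exploration bonus, so the defaults reproduce the previously hard-coded pi_loss + 0.5 * vf_loss - 0.01 * entropy.
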
4 changes: 3 additions & 1 deletion python/ray/rllib/a3c/torchpolicy.py
@@ -15,8 +15,10 @@ class TorchPolicy(Policy):
The model is a separate object than the policy. This could be changed
in the future."""

def __init__(self, ob_space, action_space, name="local", summarize=True):
def __init__(self, ob_space, action_space, config,
name="local", summarize=True):
self.local_steps = 0
self.config = config
self.summarize = summarize
self._setup_graph(ob_space, action_space)
torch.set_num_threads(2)
Expand Down
3 changes: 1 addition & 2 deletions python/ray/rllib/optimizers/async.py
@@ -35,8 +35,7 @@ def step(self):
# Note: can't use wait: https://github.com/ray-project/ray/issues/1128
while gradient_queue:
with self.wait_timer:
fut, e = gradient_queue[0]
gradient_queue = gradient_queue[1:]
fut, e = gradient_queue.pop(0)
gradient = ray.get(fut)

if gradient is not None:
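
The async.py tweak above replaces list slicing with an in-place pop(0) when rotating the gradient queue. A small sketch with hypothetical placeholder values showing the same rotation with collections.deque, whose popleft() would be the O(1) equivalent if the queue ever grew large:

# Illustrative only: the queue-rotation pattern from AsyncOptimizer.step,
# written with a deque. popleft() matches the list.pop(0) used in the diff.
from collections import deque

gradient_queue = deque([("future-1", "evaluator-1"),
                        ("future-2", "evaluator-2")])
while gradient_queue:
    fut, ev = gradient_queue.popleft()
    # Here the real optimizer would ray.get(fut), apply the gradient,
    # and possibly re-enqueue new work for ev.
    print("processed", fut, "from", ev)
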