[rllib] more user-friendly Optimizer signature + compute_apply (ray-project#2335)

* Move signature of optimizers

* fix

* expose compute_apply for policy_graphs

* dictionaries and such

* test for multiagent
richardliaw authored Jul 7, 2018
1 parent e3534c4 commit e32aed8
Showing 10 changed files with 35 additions and 21 deletions.
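
The core API change: PolicyOptimizer subclasses now take the evaluators first, with the config dict moved to the end and made optional. A minimal before/after sketch of the calling convention, using the A3C agent from this diff (names as in the agent code):

    # before this commit: config dict first, all arguments required
    self.optimizer = AsyncGradientsOptimizer(
        self.config["optimizer"], self.local_evaluator, self.remote_evaluators)

    # after this commit: evaluators first, config dict last and optional
    self.optimizer = AsyncGradientsOptimizer(
        self.local_evaluator, self.remote_evaluators, self.config["optimizer"])

The commit also adds a fused PolicyGraph.compute_apply() helper; see the policy_graph.py hunk below and the sketch after it.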
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/a3c/a3c.py
@@ -94,8 +94,8 @@ def _init(self):
self.env_creator, policy_cls, self.config["num_workers"],
{"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0})
self.optimizer = AsyncGradientsOptimizer(
-           self.config["optimizer"], self.local_evaluator,
-           self.remote_evaluators)
+           self.local_evaluator, self.remote_evaluators,
+           self.config["optimizer"])

def _train(self):
self.optimizer.step()
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/bc/bc.py
@@ -72,8 +72,8 @@ def _init(self):
remote_cls.remote(self.env_creator, self.config, self.logdir)
for _ in range(self.config["num_workers"])]
self.optimizer = AsyncGradientsOptimizer(
-           self.config["optimizer"], self.local_evaluator,
-           self.remote_evaluators)
+           self.local_evaluator, self.remote_evaluators,
+           self.config["optimizer"])

def _train(self):
self.optimizer.step()
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/dqn/dqn.py
@@ -136,8 +136,8 @@ def _init(self):
{"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]})
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
-           self.config["optimizer"], self.local_evaluator,
-           self.remote_evaluators)
+           self.local_evaluator, self.remote_evaluators,
+           self.config["optimizer"])

self.last_target_update_ts = 0
self.num_target_updates = 0
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/pg/pg.py
@@ -45,8 +45,8 @@ def _init(self):
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, PGPolicyGraph, self.config["num_workers"], {})
self.optimizer = SyncSamplesOptimizer(
-           self.config["optimizer"], self.local_evaluator,
-           self.remote_evaluators)
+           self.local_evaluator, self.remote_evaluators,
+           self.config["optimizer"])

def _train(self):
self.optimizer.step()
4 changes: 2 additions & 2 deletions python/ray/rllib/agents/ppo/ppo.py
@@ -74,11 +74,11 @@ def _init(self):
{"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]})
self.optimizer = LocalMultiGPUOptimizer(
+           self.local_evaluator, self.remote_evaluators,
            {"sgd_batch_size": self.config["sgd_batchsize"],
             "sgd_stepsize": self.config["sgd_stepsize"],
             "num_sgd_iter": self.config["num_sgd_iter"],
-            "timesteps_per_batch": self.config["timesteps_per_batch"]},
-           self.local_evaluator, self.remote_evaluators)
+            "timesteps_per_batch": self.config["timesteps_per_batch"]})

def _train(self):
def postprocess_samples(batch):
16 changes: 16 additions & 0 deletions python/ray/rllib/evaluation/policy_graph.py
@@ -106,6 +106,22 @@ def apply_gradients(self, gradients):
"""
raise NotImplementedError

+   def compute_apply(self, samples):
+       """Fused compute gradients and apply gradients call.
+
+       Returns:
+           grad_info: dictionary of extra metadata from compute_gradients().
+           apply_info: dictionary of extra metadata from apply_gradients().
+
+       Examples:
+           >>> batch = ev.sample()
+           >>> grad_info, apply_info = ev.compute_apply(batch)
+       """
+
+       grads, grad_info = self.compute_gradients(samples)
+       apply_info = self.apply_gradients(grads)
+       return grad_info, apply_info

def get_weights(self):
"""Returns model weights.
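
Since the default compute_apply() simply chains compute_gradients() and apply_gradients(), any policy exposing those two hooks gets the fused call for free. A self-contained toy sketch of the pattern (ToyPolicy and its numpy "gradients" are illustrative stand-ins, not rllib classes):

    import numpy as np

    class ToyPolicy:
        """Stand-in mirroring the gradient interface shown above."""

        def __init__(self):
            self.w = np.zeros(4)

        def compute_gradients(self, samples):
            grad = samples.mean(axis=0)                     # pretend gradient
            return grad, {"grad_norm": float(np.linalg.norm(grad))}

        def apply_gradients(self, grad):
            self.w -= 0.01 * grad                           # SGD-style update
            return {}

        def compute_apply(self, samples):
            # same fusion as the new PolicyGraph default
            grads, grad_info = self.compute_gradients(samples)
            apply_info = self.apply_gradients(grads)
            return grad_info, apply_info

    policy = ToyPolicy()
    batch = np.random.randn(32, 4)                          # fake sample batch
    grad_info, apply_info = policy.compute_apply(batch)
    print(grad_info["grad_norm"], policy.w)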
8 changes: 4 additions & 4 deletions python/ray/rllib/optimizers/policy_optimizer.py
@@ -31,7 +31,7 @@ class PolicyOptimizer(object):
evaluators created by this optimizer.
"""

-   def __init__(self, config, local_evaluator, remote_evaluators):
+   def __init__(self, local_evaluator, remote_evaluators=None, config=None):
"""Create an optimizer instance.
Args:
@@ -41,10 +41,10 @@ def __init__(self, config, local_evaluator, remote_evaluators):
evaluators instances. If empty, the optimizer should fall back
to using only the local evaluator.
"""
-       self.config = config
        self.local_evaluator = local_evaluator
-       self.remote_evaluators = remote_evaluators
-       self._init(**config)
+       self.remote_evaluators = remote_evaluators or []
+       self.config = config or {}
+       self._init(**self.config)

# Counters that should be updated by sub-classes
self.num_steps_trained = 0
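
With the new defaults (remote_evaluators or [], config or {}), a subclass can be built from just a local evaluator, and the config keys are still forwarded to _init() as keyword arguments. A toy sketch of that behavior (ToyOptimizer is an illustrative stand-in, not an rllib class; grads_per_step just mirrors the key used in the tests below):

    class ToyOptimizer:
        """Mirrors the constructor logic shown in the hunk above."""

        def __init__(self, local_evaluator, remote_evaluators=None, config=None):
            self.local_evaluator = local_evaluator
            self.remote_evaluators = remote_evaluators or []
            self.config = config or {}
            self._init(**self.config)                       # config keys become kwargs

        def _init(self, grads_per_step=10):
            self.grads_per_step = grads_per_step

    opt_a = ToyOptimizer("local_ev")                        # remote list -> [], config -> {}
    opt_b = ToyOptimizer("local_ev", [], {"grads_per_step": 3})
    print(opt_a.grads_per_step, opt_b.grads_per_step)       # prints: 10 3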
3 changes: 1 addition & 2 deletions python/ray/rllib/optimizers/sync_samples_optimizer.py
@@ -40,8 +40,7 @@ def step(self):
samples = self.local_evaluator.sample()

with self.grad_timer:
-           grad, _ = self.local_evaluator.compute_gradients(samples)
-           self.local_evaluator.apply_gradients(grad)
+           self.local_evaluator.compute_apply(samples)
self.grad_timer.push_units_processed(samples.count)

self.num_steps_sampled += samples.count
4 changes: 2 additions & 2 deletions python/ray/rllib/test/test_multi_agent_env.py
@@ -296,7 +296,7 @@ def _testWithOptimizer(self, optimizer_cls):
batch_steps=50)]
else:
remote_evs = []
-       optimizer = optimizer_cls({}, ev, remote_evs)
+       optimizer = optimizer_cls(ev, remote_evs, {})
for i in range(200):
ev.foreach_policy(
lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
@@ -338,7 +338,7 @@ def testTrainMultiCartpoleManyPolicies(self):
policy_graph=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
-       optimizer = SyncSamplesOptimizer({}, ev, [])
+       optimizer = SyncSamplesOptimizer(ev, [], {})
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
5 changes: 2 additions & 3 deletions python/ray/rllib/test/test_optimizers.py
@@ -21,9 +21,8 @@ def testBasic(self):
local = _MockEvaluator()
remotes = ray.remote(_MockEvaluator)
remote_evaluators = [remotes.remote() for i in range(5)]
-       test_optimizer = AsyncGradientsOptimizer({
-           "grads_per_step": 10
-       }, local, remote_evaluators)
+       test_optimizer = AsyncGradientsOptimizer(
+           local, remote_evaluators, {"grads_per_step": 10})
test_optimizer.step()
self.assertTrue(all(local.get_weights() == 0))

