forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tune] clean up population based training prototype (ray-project#1478)
* patch up pbt * Sat Jan 27 01:00:03 PST 2018 * Sat Jan 27 01:04:14 PST 2018 * Sat Jan 27 01:04:21 PST 2018 * Sat Jan 27 01:15:15 PST 2018 * Sat Jan 27 01:15:42 PST 2018 * Sat Jan 27 01:16:14 PST 2018 * Sat Jan 27 01:38:42 PST 2018 * Sat Jan 27 01:39:21 PST 2018 * add pbt * Sat Jan 27 01:41:19 PST 2018 * Sat Jan 27 01:44:21 PST 2018 * Sat Jan 27 01:45:46 PST 2018 * Sat Jan 27 16:54:42 PST 2018 * Sat Jan 27 16:57:53 PST 2018 * clean up test * Sat Jan 27 18:01:15 PST 2018 * Sat Jan 27 18:02:54 PST 2018 * Sat Jan 27 18:11:18 PST 2018 * Sat Jan 27 18:11:55 PST 2018 * Sat Jan 27 18:14:09 PST 2018 * review * try out a ppo example * some tweaks to ppo example * add postprocess hook * Sun Jan 28 15:00:40 PST 2018 * clean up custom explore fn * Sun Jan 28 15:10:21 PST 2018 * Sun Jan 28 15:14:53 PST 2018 * Sun Jan 28 15:17:04 PST 2018 * Sun Jan 28 15:33:13 PST 2018 * Sun Jan 28 15:56:40 PST 2018 * Sun Jan 28 15:57:36 PST 2018 * Sun Jan 28 16:00:35 PST 2018 * Sun Jan 28 16:02:58 PST 2018 * Sun Jan 28 16:29:50 PST 2018 * Sun Jan 28 16:30:36 PST 2018 * Sun Jan 28 16:31:44 PST 2018 * improve tune doc * concepts * update humanoid * Fri Feb 2 18:03:33 PST 2018 * fix example * show error file
- Loading branch information
Showing
22 changed files
with
702 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import json | ||
import os | ||
import random | ||
import time | ||
|
||
import ray | ||
from ray.tune import Trainable, TrainingResult, register_trainable, \ | ||
run_experiments | ||
from ray.tune.pbt import PopulationBasedTraining | ||
|
||
|
||
class MyTrainableClass(Trainable):
    """Fake agent whose reward growth is determined by dummy factors.

    The accumulated reward grows each iteration by an amount that depends on
    the ``factor_1`` and ``factor_2`` entries of ``self.config``, which makes
    this a cheap stand-in for a real model when exercising a scheduler.
    """

    def _setup(self):
        """Reset the iteration counter and the accumulated reward."""
        self.timestep = 0
        self.current_value = 0.0

    def _train(self):
        """Run one fake training iteration.

        Returns:
            A TrainingResult whose ``episode_reward_mean`` is the reward
            accumulated so far, reporting one timestep for this iteration.
        """
        time.sleep(0.1)

        # Advance the counter so that the value written by _save reflects
        # how many iterations have actually been run.
        self.timestep += 1

        # Reward increase is parabolic as a function of factor_1, with a
        # maximum around factor_1=10.0. Negative draws are clamped to zero.
        self.current_value += max(
            0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0))

        # Flat increase by factor_2 (plus unit-variance noise).
        self.current_value += random.gauss(self.config["factor_2"], 1.0)

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy (see tune/result.py).
        return TrainingResult(
            episode_reward_mean=self.current_value, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        """Write the agent state as JSON and return the checkpoint path.

        Args:
            checkpoint_dir: Directory in which the "checkpoint" file is
                created.

        Returns:
            The path of the file that was written.
        """
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write(json.dumps(
                {"timestep": self.timestep, "value": self.current_value}))
        return path

    def _restore(self, checkpoint_path):
        """Load the agent state from a file previously written by _save.

        Args:
            checkpoint_path: Path of the JSON checkpoint file.
        """
        with open(checkpoint_path) as f:
            data = json.loads(f.read())
        self.timestep = data["timestep"]
        self.current_value = data["value"]
|
||
|
||
# Make the fake agent addressable by name in experiment specifications.
register_trainable("my_class", MyTrainableClass)

if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # parse_known_args tolerates flags this script does not define; the
    # leftover arguments are discarded.
    args = cli.parse_known_args()[0]
    ray.init()

    mutations = {
        # Allow for scaling-based perturbations, with a uniform backing
        # distribution for resampling.
        "factor_1": lambda config: random.uniform(0.0, 20.0),
        # Only allows resampling from this list as a perturbation.
        "factor_2": [1, 2],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        reward_attr="episode_reward_mean",
        perturbation_interval=10,
        hyperparam_mutations=mutations)

    # A smoke test stops after two iterations; otherwise run (nearly)
    # indefinitely.
    if args.smoke_test:
        iteration_limit = 2
    else:
        iteration_limit = 99999

    experiment = {
        "run": "my_class",
        "stop": {"training_iteration": iteration_limit},
        "repeat": 10,
        "resources": {"cpu": 1, "gpu": 0},
        "config": {
            "factor_1": 4.0,
            "factor_2": 1.0,
        },
    }

    # Try to find the best factor 1 and factor 2
    run_experiments(
        {"pbt_test": experiment}, scheduler=scheduler, verbose=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python | ||
|
||
"""Example of using PBT with RLlib. | ||
Note that this requires a cluster with at least 8 GPUs in order for all trials | ||
to run concurrently, otherwise PBT will round-robin train the trials which | ||
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead). | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import random | ||
|
||
import ray | ||
from ray.tune import run_experiments | ||
from ray.tune.pbt import PopulationBasedTraining | ||
|
||
if __name__ == "__main__":

    def repair_config(config):
        """Postprocess a perturbed config so that it is still valid.

        Raises ``timesteps_per_batch`` to at least twice ``sgd_batchsize``
        and ``num_sgd_iter`` to at least one, then returns the same dict.
        """
        # ensure we collect enough timesteps to do sgd
        config["timesteps_per_batch"] = max(
            config["timesteps_per_batch"], config["sgd_batchsize"] * 2)
        # ensure we run at least one sgd iter
        config["num_sgd_iter"] = max(config["num_sgd_iter"], 1)
        return config

    # Specifies the resampling distributions of these hyperparams
    resampling = {
        "lambda": lambda config: random.uniform(0.9, 1.0),
        "clip_param": lambda config: random.uniform(0.01, 0.5),
        "sgd_stepsize": lambda config: random.uniform(.00001, .001),
        "num_sgd_iter": lambda config: random.randint(1, 30),
        "sgd_batchsize": lambda config: random.randint(128, 16384),
        "timesteps_per_batch":
            lambda config: random.randint(2000, 160000),
    }
    scheduler = PopulationBasedTraining(
        time_attr="time_total_s",
        reward_attr="episode_reward_mean",
        perturbation_interval=120,
        resample_probability=0.25,
        hyperparam_mutations=resampling,
        custom_explore_fn=repair_config)

    ppo_config = {
        "kl_coeff": 1.0,
        "num_workers": 8,
        "devices": ["/gpu:0"],
        "model": {"free_log_std": True},
        # These params are tuned from their starting value
        "lambda": 0.95,
        "clip_param": 0.2,
        # Start off with several random variations
        "sgd_stepsize": lambda spec: random.uniform(.00001, .001),
        "num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
        "sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
        "timesteps_per_batch":
            lambda spec: random.choice([10000, 20000, 40000]),
    }
    experiment = {
        "run": "PPO",
        "env": "Humanoid-v1",
        "repeat": 8,
        "resources": {"cpu": 4, "gpu": 1},
        "config": ppo_config,
    }

    ray.init()
    run_experiments({"pbt_humanoid_test": experiment}, scheduler=scheduler)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.