Commit 8baff69

Remove joblib and old-style samplers (#1353)
This includes OffPolicyVectorizedSampler.
krzentner authored Jun 25, 2020
1 parent 394e1de commit 8baff69
Showing 92 changed files with 323 additions and 2,381 deletions.
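
The recurring change across the example scripts below is that the explicit old-style sampler argument is dropped from runner setup. A minimal before/after sketch of that migration, using the exact calls that appear in the diffs (when no sampler is passed, the runner falls back to its default sampler choice):

    # Before: an old-style sampler class is passed explicitly.
    runner.setup(algo, env, sampler_cls=OnPolicyVectorizedSampler)

    # After: no sampler argument; the runner uses its default.
    runner.setup(algo, env)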
@@ -64,7 +64,6 @@ def gaussian_cnn_baseline(ctxt, env_id, seed):
             gae_lambda=0.95,
             lr_clip_range=0.2,
             policy_ent_coeff=0.0,
-            flatten_input=False,
             optimizer_args=dict(
                 batch_size=32,
                 max_epochs=10,
@@ -57,22 +57,19 @@ def categorical_cnn_policy(ctxt, env_id, seed):
                 hidden_sizes=hyper_params['hidden_sizes'],
                 use_trust_region=hyper_params['use_trust_region']))

-        algo = PPO(
-            env_spec=env.spec,
-            policy=policy,
-            baseline=baseline,
-            max_path_length=100,
-            discount=0.99,
-            gae_lambda=0.95,
-            lr_clip_range=0.2,
-            policy_ent_coeff=0.0,
-            optimizer_args=dict(
-                batch_size=32,
-                max_epochs=10,
-                learning_rate=1e-3,
-            ),
-            flatten_input=False,
-        )
+        algo = PPO(env_spec=env.spec,
+                   policy=policy,
+                   baseline=baseline,
+                   max_path_length=100,
+                   discount=0.99,
+                   gae_lambda=0.95,
+                   lr_clip_range=0.2,
+                   policy_ent_coeff=0.0,
+                   optimizer_args=dict(
+                       batch_size=32,
+                       max_epochs=10,
+                       learning_rate=1e-3,
+                   ))

         runner.setup(algo, env)
         runner.train(n_epochs=hyper_params['n_epochs'],
1 change: 0 additions & 1 deletion docs/requirements.txt
@@ -5,7 +5,6 @@ cma==2.7.0
 dm_env
 dowel==0.0.3
 gym[atari, box2d, classic_control]==0.15.4
-joblib<0.13,>=0.12
 psutil
 pyglet<1.4.0,>=1.3.0
 pyprind
3 changes: 1 addition & 2 deletions examples/np/cem_cartpole.py
@@ -13,7 +13,6 @@
 from garage.experiment.deterministic import set_seed
 from garage.np.algos import CEM
 from garage.np.baselines import LinearFeatureBaseline
-from garage.sampler import OnPolicyVectorizedSampler
 from garage.tf.policies import CategoricalMLPPolicy

@@ -46,7 +45,7 @@ def cem_cartpole(ctxt=None, seed=1):
                 max_path_length=100,
                 n_samples=n_samples)

-    runner.setup(algo, env, sampler_cls=OnPolicyVectorizedSampler)
+    runner.setup(algo, env)
     runner.train(n_epochs=100, batch_size=1000)

3 changes: 1 addition & 2 deletions examples/np/cma_es_cartpole.py
@@ -14,7 +14,6 @@
 from garage.experiment.deterministic import set_seed
 from garage.np.algos import CMAES
 from garage.np.baselines import LinearFeatureBaseline
-from garage.sampler import OnPolicyVectorizedSampler
 from garage.tf.policies import CategoricalMLPPolicy

@@ -46,7 +45,7 @@ def cma_es_cartpole(ctxt=None, seed=1):
                 max_path_length=100,
                 n_samples=n_samples)

-    runner.setup(algo, env, sampler_cls=OnPolicyVectorizedSampler)
+    runner.setup(algo, env)
     runner.train(n_epochs=100, batch_size=1000)

18 changes: 11 additions & 7 deletions examples/sim_policy.py
@@ -3,7 +3,7 @@
 import argparse
 import sys

-import joblib
+import cloudpickle
 import tensorflow as tf

 from garage.sampler.utils import rollout
@@ -12,12 +12,16 @@
 def query_yes_no(question, default='yes'):
     """Ask a yes/no question via raw_input() and return their answer.

-    "question" is a string that is presented to the user.
-    "default" is the presumed answer if the user just hits <Enter>.
-    It must be "yes" (the default), "no" or None (meaning
-    an answer is required of the user).
-
-    The "answer" return value is True for "yes" or False for "no".
+    Args:
+        question (str): Printed to user.
+        default (str or None): Default if user just hits enter.
+
+    Raises:
+        ValueError: If the provided default is invalid.
+
+    Returns:
+        bool: True for "yes"y answers, False for "no".

     """
     valid = {'yes': True, 'y': True, 'ye': True, 'no': False, 'n': False}
     if default is None:
@@ -57,7 +61,7 @@ def query_yes_no(question, default='yes'):
     # with tf.compat.v1.Session():
     #     [rest of the code]
     with tf.compat.v1.Session() as sess:
-        data = joblib.load(args.file)
+        data = cloudpickle.load(args.file)
         policy = data['algo'].policy
         env = data['env']
         while True:
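
One note on the joblib-to-cloudpickle swap above: joblib.load accepts a filename, while cloudpickle.load (like pickle.load) expects an open binary file object. A minimal sketch of the cloudpickle loading pattern, with a hypothetical snapshot path for illustration:

    import cloudpickle

    # 'params.pkl' is a hypothetical snapshot path, not part of this diff.
    with open('params.pkl', 'rb') as snapshot_file:
        data = cloudpickle.load(snapshot_file)

    policy = data['algo'].policy  # same keys as in sim_policy.py above
    env = data['env']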
1 change: 1 addition & 0 deletions examples/tf/ddpg_pendulum.py
@@ -58,6 +58,7 @@ def ddpg_pendulum(ctxt=None, seed=1):
             qf_lr=1e-3,
             qf=qf,
             replay_buffer=replay_buffer,
+            max_path_length=100,
             steps_per_epoch=20,
             target_update_tau=1e-2,
             n_train_steps=50,
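
The same one-line addition recurs in dqn_cartpole.py, td3_pendulum.py, and the torch ddpg_pendulum.py below: with OffPolicyVectorizedSampler gone, the episode-length limit is passed to the algorithm constructor directly. A hedged fragment showing just the changed call shape, with the surrounding arguments abridged from the diff above:

    # Abridged sketch; the remaining DDPG arguments are as shown in the diff.
    algo = DDPG(env_spec=env.spec,
                policy=policy,
                qf=qf,
                qf_lr=1e-3,
                replay_buffer=replay_buffer,
                max_path_length=100,  # now supplied to the algorithm itself
                steps_per_epoch=20,
                target_update_tau=1e-2,
                n_train_steps=50)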
1 change: 1 addition & 0 deletions examples/tf/dqn_cartpole.py
@@ -47,6 +47,7 @@ def dqn_cartpole(ctxt=None, seed=1):
                policy=policy,
                qf=qf,
                exploration_policy=exploration_policy,
+               max_path_length=100,
                replay_buffer=replay_buffer,
                steps_per_epoch=steps_per_epoch,
                qf_lr=1e-4,
4 changes: 2 additions & 2 deletions examples/tf/dqn_pong.py
@@ -27,9 +27,9 @@

 @click.command()
 @click.option('--buffer_size', type=int, default=int(5e4))
-@click.option('--max_path_length', type=int, default=None)
+@click.option('--max_path_length', type=int, default=500)
 @wrap_experiment
-def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_path_length=None):
+def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_path_length=500):
     """Train DQN on PongNoFrameskip-v4 environment.

     Args:
3 changes: 1 addition & 2 deletions examples/tf/ppo_memorize_digits.py
@@ -67,8 +67,7 @@ def ppo_memorize_digits(ctxt=None, seed=1, batch_size=4000):
                    batch_size=32,
                    max_epochs=10,
                    learning_rate=1e-3,
-               ),
-               flatten_input=False)
+               ))

     runner.setup(algo, env)
     runner.train(n_epochs=1000, batch_size=batch_size)
1 change: 1 addition & 0 deletions examples/tf/td3_pendulum.py
@@ -67,6 +67,7 @@ def td3_pendulum(ctxt=None, seed=1):
             qf_lr=1e-3,
             qf=qf,
             qf2=qf2,
+            max_path_length=100,
             replay_buffer=replay_buffer,
             target_update_tau=1e-2,
             steps_per_epoch=20,
60 changes: 0 additions & 60 deletions examples/tf/trpo_cartpole_batch_sampler.py

This file was deleted.

3 changes: 1 addition & 2 deletions examples/tf/trpo_cubecrash.py
@@ -53,8 +53,7 @@ def trpo_cubecrash(ctxt=None, seed=1, batch_size=4000):
                 discount=0.99,
                 gae_lambda=0.95,
                 lr_clip_range=0.2,
-                policy_ent_coeff=0.0,
-                flatten_input=False)
+                policy_ent_coeff=0.0)

     runner.setup(algo, env)
     runner.train(n_epochs=100, batch_size=batch_size)
51 changes: 0 additions & 51 deletions examples/tf/trpois_inverted_pendulum.py

This file was deleted.

53 changes: 0 additions & 53 deletions examples/tf/vpgis_inverted_pendulum.py

This file was deleted.

1 change: 1 addition & 0 deletions examples/torch/ddpg_pendulum.py
@@ -55,6 +55,7 @@ def ddpg_pendulum(ctxt=None, seed=1, lr=1e-4):
                policy=policy,
                qf=qf,
                replay_buffer=replay_buffer,
+               max_path_length=100,
                steps_per_epoch=20,
                n_train_steps=50,
                min_buffer_size=int(1e4),
2 changes: 1 addition & 1 deletion examples/torch/trpo_pendulum_ray_sampler.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """This is an example to train a task with TRPO algorithm (PyTorch).
-Uses Ray sampler instead of OnPolicyVectorizedSampler.
+Uses Ray sampler instead of MultiprocessingSampler.
 Here it runs InvertedDoublePendulum-v2 environment with 100 iterations.
 """

 import numpy as np
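
For reference, the Ray-based setup this docstring refers to looks like the sketch below; the import path and the sampler_cls keyword are assumptions based on the sampler usage elsewhere in this diff:

    from garage.sampler import RaySampler  # assumed import path at this revision

    runner.setup(algo, env, sampler_cls=RaySampler)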
1 change: 0 additions & 1 deletion setup.py
@@ -16,7 +16,6 @@
     'cma==2.7.0',
     'dowel==0.0.3',
     'gym[atari,box2d,classic_control]' + GYM_VERSION,
-    'joblib<0.13,>=0.12',
     'numpy>=1.14.5',
     'psutil',
     # Pyglet 1.4.0 introduces some api change which breaks some