Skip to content

Commit

Permalink
Train experts for AIRL environments (HumanCompatibleAI#68)
Browse files Browse the repository at this point in the history
* scripts.config: Organize named configs

* scripts.config: Add custom env named configs

* mujoco_experts.sh: Train custom env experts

* Rename mujoco_experts.sh => train_experts.sh

* Fix lint/Travis/imports

* train_experts: save to train_experts/

* setup.py: Install stable_baselines from master

Includes MPI hang fix

* Accommodate PPO1 import fail

* address comments

* Address comments

* Fix configs

* lint

* Fix bad merge: add back in fast
  • Loading branch information
shwang authored Aug 15, 2019
1 parent 35857d4 commit 98a7929
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 52 deletions.
17 changes: 0 additions & 17 deletions experiments/mujoco_experts.sh

This file was deleted.

25 changes: 25 additions & 0 deletions experiments/train_experts.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Train expert policies for every listed environment config and seed, running
# the jobs concurrently with GNU parallel.
#
# Each job runs `python -m imitation.scripts.expert_demos with <env> seed=<n>`.
# Job output is collected under output/train_experts/<timestamp>/parallel/,
# and the same timestamped directory is passed to the script as `log_root`.

# Names passed after `with` (one per job). The first assignment uses `=`
# rather than `+=` so that an ENVS value inherited from the caller's
# environment cannot leak extra entries into the list.
ENVS="acrobot cartpole mountain_car "
ENVS+="reacher half_cheetah hopper ant humanoid swimmer walker "
ENVS+="two_d_maze custom_ant disabled_ant "
SEEDS="0 1 2"

# Test the exit status of `command -v` directly. Wrapping it in $(...) would
# try to execute its (empty) output as a command, which only works by accident.
if command -v gdate > /dev/null 2>&1; then
  DATE_CMD=gdate  # macOS compatibility
else
  DATE_CMD=date
fi

TIMESTAMP=$("${DATE_CMD}" --iso-8601=seconds)
OUTPUT_DIR="output/train_experts/${TIMESTAMP}/"

echo "Writing logs in ${OUTPUT_DIR}"

# ENVS and SEEDS are deliberately left unquoted: word splitting turns each
# space-separated entry into a separate argument for parallel's `:::` lists.
# OUTPUT_DIR already ends in "/", so "parallel/" is appended without another.
# shellcheck disable=SC2086
parallel -j 25% --header : --progress --results "${OUTPUT_DIR}parallel/" \
  python -m imitation.scripts.expert_demos \
  with \
  {env} \
  seed={seed} \
  log_root="${OUTPUT_DIR}" \
  ::: env ${ENVS} ::: seed ${SEEDS}
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
'numpy>=1.15',
'tqdm',
'scikit-learn>=0.21.2',
'stable-baselines>=2.7.0',
# FIXME: Use stable release instead of tracking master once
# commit 9a760542 is released.
'stable-baselines @ git+https://github.com/hill-a/stable-baselines.git',
'jax!=0.1.37',
'jaxlib~=0.1.20',
# sacred==0.7.5 build is broken without pymongo
Expand Down
17 changes: 11 additions & 6 deletions src/imitation/policies/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Callable, Optional, Type

import gym
import stable_baselines
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.policies import BasePolicy
from stable_baselines.common.vec_env import VecEnv, VecNormalize
Expand Down Expand Up @@ -87,13 +86,19 @@ def f(path: str, env: gym.Env) -> BasePolicy:
value=registry.build_loader_fn_require_space(ZeroPolicy))

STABLE_BASELINES_CLASSES = {
'ppo1': (stable_baselines.PPO1, 'policy_pi'),
'ppo2': (stable_baselines.PPO2, 'act_model'),
'ppo1': ('stable_baselines:PPO1', 'policy_pi'),
'ppo2': ('stable_baselines:PPO2', 'act_model'),
}

for k, (cls, attr) in STABLE_BASELINES_CLASSES.items():
fn = _load_stable_baselines(cls, attr)
policy_registry.register(k, value=fn)
# Register a policy loader for every Stable Baselines algorithm that can be
# imported in this environment; algorithms that cannot be imported are skipped.
for k, (cls_name, attr) in STABLE_BASELINES_CLASSES.items():
    # Guard only the dynamic import. Keeping the registration outside the
    # `try` means an unrelated AttributeError/ImportError raised while
    # building or registering the loader is not silently swallowed.
    try:
        cls = registry.load_attr(cls_name)
    except (AttributeError, ImportError):
        # We expect PPO1 load to fail if mpi4py isn't installed.
        # Stable Baselines can be installed without mpi4py.
        tf.logging.debug(f"Couldn't load {cls_name}. Skipping...")
        continue
    fn = _load_stable_baselines(cls, attr)
    policy_registry.register(k, value=fn)


def load_policy(policy_type: str, policy_path: str,
Expand Down
56 changes: 43 additions & 13 deletions src/imitation/scripts/config/expert_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,29 @@ def logging(env_name, log_root):
util.make_unique_timestamp())


@expert_demos_ex.named_config
def ant():
env_name = "Ant-v2"
make_blank_policy_kwargs = dict(
n_steps=2048, # batch size of 2048*8=16384 due to num_vec
)
total_timesteps = int(5e6) # OK after 2e6, but continues improving
# Shared settings

# Settings common to every Ant-based named config in this module; the `ant`,
# `custom_ant` and `disabled_ant` configs merge this mapping into their own
# config values via `locals().update(**ant_shared_locals)`.
ant_shared_locals = {
    "make_blank_policy_kwargs": {
        "n_steps": 2048,  # batch size of 2048*8=16384 due to num_vec
    },
    "total_timesteps": int(5e6),
}


# Standard Gym env configs

# Named config selecting the "Acrobot-v1" environment; it overrides only
# `env_name`.
@expert_demos_ex.named_config
def acrobot():
env_name = "Acrobot-v1"


# Named config selecting the "Ant-v2" environment plus the shared Ant
# settings held in `ant_shared_locals`.
# NOTE(review): `locals().update(...)` only injects those settings if Sacred
# evaluates this body with a real dict as its local namespace; in an
# ordinarily-called Python function, writes to locals() are discarded.
# Confirm against Sacred's config-scope implementation.
@expert_demos_ex.named_config
def ant():
env_name = "Ant-v2"
locals().update(**ant_shared_locals)


@expert_demos_ex.named_config
def cartpole():
env_name = "CartPole-v1"
Expand All @@ -60,6 +69,12 @@ def half_cheetah():
total_timesteps = int(5e6) # does OK after 1e6, but continues improving


# Named config selecting the "Hopper-v2" environment; it overrides only
# `env_name`.
@expert_demos_ex.named_config
def hopper():
# TODO(adam): upgrade to Hopper-v3?
env_name = "Hopper-v2"


@expert_demos_ex.named_config
def humanoid():
env_name = "Humanoid-v2"
Expand All @@ -69,12 +84,6 @@ def humanoid():
total_timesteps = int(10e6) # fairly discontinuous, needs at least 5e6


@expert_demos_ex.named_config
def hopper():
# TODO(adam): upgrade to Hopper-v3?
env_name = "Hopper-v2"


@expert_demos_ex.named_config
def mountain_car():
env_name = "MountainCar-v0"
Expand All @@ -100,6 +109,27 @@ def walker():
env_name = "Walker2d-v2"


# Custom env configs

# Named config selecting the custom "imitation/CustomAnt-v0" environment plus
# the shared Ant settings held in `ant_shared_locals`.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
# NOTE(review): `locals().update(...)` relies on Sacred evaluating this body
# with a real dict as its local namespace; a plain Python function call would
# discard these writes. Confirm against Sacred's config-scope implementation.
@expert_demos_ex.named_config
def custom_ant():
env_name = "imitation/CustomAnt-v0"
locals().update(**ant_shared_locals)


# Named config selecting the custom "imitation/DisabledAnt-v0" environment
# plus the shared Ant settings held in `ant_shared_locals`.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
# NOTE(review): `locals().update(...)` relies on Sacred evaluating this body
# with a real dict as its local namespace; a plain Python function call would
# discard these writes. Confirm against Sacred's config-scope implementation.
@expert_demos_ex.named_config
def disabled_ant():
env_name = "imitation/DisabledAnt-v0"
locals().update(**ant_shared_locals)


# Named config selecting the custom "imitation/TwoDMaze-v0" environment; it
# overrides only `env_name`.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
@expert_demos_ex.named_config
def two_d_maze():
env_name = "imitation/TwoDMaze-v0"


# Debug configs

@expert_demos_ex.named_config
def fast():
"""Intended for testing purposes: small # of updates, ends quickly."""
Expand Down
44 changes: 35 additions & 9 deletions src/imitation/scripts/config/train_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def paths(env_name, log_root):
"*", "rollouts", "final.pkl")


# Training algorithm configs

@train_ex.named_config
def gail():
init_trainer_kwargs = dict(
Expand All @@ -72,17 +74,20 @@ def airl():
)


@train_ex.named_config
def ant():
env_name = "Ant-v2"
n_epochs = 2000
# Standard Gym env configs


# Named config selecting the "Acrobot-v1" environment for adversarial
# training; it overrides only `env_name`.
@train_ex.named_config
def acrobot():
env_name = "Acrobot-v1"


# Named config selecting the "Ant-v2" environment for adversarial training
# and setting `n_epochs` to 2000.
@train_ex.named_config
def ant():
env_name = "Ant-v2"
n_epochs = 2000


@train_ex.named_config
def cartpole():
env_name = "CartPole-v1"
Expand Down Expand Up @@ -128,11 +133,6 @@ def reacher():
env_name = "Reacher-v2"


@train_ex.named_config
def walker():
env_name = "Walker2d-v2"


@train_ex.named_config
def swimmer():
env_name = "Swimmer-v2"
Expand All @@ -144,6 +144,32 @@ def swimmer():
)


# Named config selecting the "Walker2d-v2" environment for adversarial
# training; it overrides only `env_name`.
@train_ex.named_config
def walker():
env_name = "Walker2d-v2"


# Custom env configs

# Named config selecting the custom "imitation/TwoDMaze-v0" environment for
# adversarial training; it overrides only `env_name`.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
@train_ex.named_config
def two_d_maze():
env_name = "imitation/TwoDMaze-v0"


# Named config selecting the custom "imitation/CustomAnt-v0" environment for
# adversarial training, with the same `n_epochs` value (2000) as the `ant`
# config in this module.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
@train_ex.named_config
def custom_ant():
env_name = "imitation/CustomAnt-v0"
n_epochs = 2000


# Named config selecting the custom "imitation/DisabledAnt-v0" environment
# for adversarial training, with the same `n_epochs` value (2000) as the
# `ant` config in this module.
# NOTE(review): the "imitation/..." env ID is presumably registered by
# importing `imitation.envs.examples` in the entry-point script — confirm.
@train_ex.named_config
def disabled_ant():
env_name = "imitation/DisabledAnt-v0"
n_epochs = 2000


# Debug configs

@train_ex.named_config
def fast():
"""Minimize the amount of computation. Useful for test cases."""
Expand Down
1 change: 1 addition & 0 deletions src/imitation/scripts/eval_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from stable_baselines.common.vec_env import VecEnvWrapper
import tensorflow as tf

import imitation.envs.examples # noqa: F401
from imitation.policies import serialize
from imitation.scripts.config.eval_policy import eval_policy_ex
from imitation.util import rollout, util
Expand Down
1 change: 1 addition & 0 deletions src/imitation/scripts/expert_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from stable_baselines.common.vec_env import VecNormalize
import tensorflow as tf

import imitation.envs.examples # noqa: F401
from imitation.policies import serialize
from imitation.rewards.discrim_net import DiscrimNetAIRL
from imitation.scripts.config.expert_demos import expert_demos_ex
Expand Down
1 change: 1 addition & 0 deletions src/imitation/scripts/train_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tqdm

from imitation.algorithms.adversarial import init_trainer
import imitation.envs.examples # noqa: F401
from imitation.policies import serialize
from imitation.scripts.config.train_adversarial import train_ex
import imitation.util as util
Expand Down
4 changes: 2 additions & 2 deletions src/imitation/util/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""The type stored in Registry is commonly an instance of LoaderFn."""


def load(name):
def load_attr(name):
"""Load an attribute in format path.to.module:attribute."""
module_name, attr_name = name.split(":")
module = importlib.import_module(module_name)
Expand Down Expand Up @@ -42,7 +42,7 @@ def get(self, key: str) -> T:
raise KeyError(f"Key '{key}' is not registered.")

if key not in self._values:
self._values[key] = load(self._indirect[key])
self._values[key] = load_attr(self._indirect[key])
return self._values[key]

def register(self, key: str, *,
Expand Down
15 changes: 11 additions & 4 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from stable_baselines.common.vec_env import VecNormalize

from imitation.policies import serialize
from imitation.util import rollout, util
from imitation.util import registry, rollout, util

SIMPLE_ENVS = [
"CartPole-v0", # Discrete(2) action space
"MountainCarContinuous-v0", # Box(1) action space
]
HARDCODED_TYPES = ["random", "zero"]
BASELINE_MODELS = [(name, cls)
for name, (cls, attr) in
BASELINE_MODELS = [(name, cls_name)
for name, (cls_name, attr) in
serialize.STABLE_BASELINES_CLASSES.items()]


Expand All @@ -42,7 +42,14 @@ def test_serialize_identity(env_name, model_cfg, normalize):
vec_normalize = None
if normalize:
venv = vec_normalize = VecNormalize(venv)
model_name, model_cls = model_cfg

model_name, model_cls_name = model_cfg
try:
model_cls = registry.load_attr(model_cls_name)
except (AttributeError, ImportError):
pytest.skip("Couldn't load stable baselines class. "
"(Probably because mpi4py not installed.)")

model = model_cls('MlpPolicy', venv)
model.learn(1000)

Expand Down

0 comments on commit 98a7929

Please sign in to comment.