initial release commit
jannerm committed Jul 29, 2019
0 parents commit e39fae3
Showing 109 changed files with 10,261 additions and 0 deletions.
120 changes: 120 additions & 0 deletions .gitignore
@@ -0,0 +1,120 @@
*.pkl
*.stl

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# soft learning specific things
*.swp
.idea
*.mp4
data/
vis/
tmp/
vendor/*
.pkl


.mujoco/
.vscode/
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "viskit"]
path = viskit
url = https://github.com/vitchyr/viskit.git
72 changes: 72 additions & 0 deletions README.md
@@ -0,0 +1,72 @@
# Model-Based Policy Optimization

Code to reproduce the experiments in [When to Trust Your Model: Model-Based Policy Optimization](https://arxiv.org/abs/1906.08253).

<p align="center">
<!-- <img src="https://drive.google.com/uc?export=view&id=19KA7zIjo4HVEqrJNRRgNvkpUwZ6AWGMD" width="80%"> -->
<img src="https://drive.google.com/uc?export=view&id=1siZA55atJi8Tgeefvv28WOqk7pFSynJP" width="80%">
</p>

## Installation
1. Install [MuJoCo 1.50](https://www.roboti.us/index.html) at `~/.mujoco/mjpro150` and copy your license key to `~/.mujoco/mjkey.txt`
2. Clone `mbpo`
```
git clone --recursive https://github.com/jannerm/mbpo.git
```
3. Create a conda environment and install mbpo
```
cd mbpo
conda env create -f environment/gpu-env.yml
conda activate mbpo
pip install -e viskit
pip install -e .
```

## Usage
Configuration files can be found in `examples/config/`.

```
mbpo run_local examples.development --config=examples.config.halfcheetah.0 \
--checkpoint-frequency=1000 --gpus=1 --trial-gpus=1
```

Currently only running locally is supported.

#### New environments
To run on a different environment, you can modify the provided [template](examples/config/custom/0.py). You will also need to provide the termination function for the environment in [`mbpo/static`](mbpo/static). If you name the file the lowercase version of the environment name, it will be found automatically. See [`hopper.py`](mbpo/static/hopper.py) for an example.
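The shape of such a termination function can be sketched as follows. This is an illustrative example modeled on a typical Hopper fall-detection condition, not a verbatim copy of `mbpo/static/hopper.py`; the exact interface the repo expects may differ. The function is batched: `obs`, `act`, and `next_obs` are `(batch, dim)` arrays, and it returns a `(batch, 1)` boolean array marking terminal transitions.

```python
import numpy as np

class StaticFns:

    @staticmethod
    def termination_fn(obs, act, next_obs):
        # A transition is non-terminal while the state stays finite and
        # the torso stays upright (height above 0.7, small torso angle).
        height = next_obs[:, 0]
        angle = next_obs[:, 1]
        not_done = (
            np.isfinite(next_obs).all(axis=-1)
            & (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
            & (height > 0.7)
            & (np.abs(angle) < 0.2)
        )
        # Return done flags with a trailing singleton dimension: (batch, 1).
        return (~not_done)[:, None]
```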

#### Logging

This codebase contains [viskit](https://github.com/vitchyr/viskit) as a submodule. You can view saved runs with:
```
viskit ~/ray_mbpo --port 6008
```
assuming you used the default [`log_dir`](examples/config/halfcheetah/0.py#L7).

#### Hyperparameters

The rollout length schedule is defined by a length-4 list in a [config file](examples/config/halfcheetah/0.py#L31). The format is `[start_epoch, end_epoch, start_length, end_length]`, so the following:
```
'rollout_schedule': [20, 100, 1, 5]
```
corresponds to a model rollout length linearly increasing from 1 to 5 over epochs 20 to 100.
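The interpolation this schedule describes can be sketched with a small helper (a hypothetical function for illustration; the codebase computes this internally):

```python
def rollout_length(schedule, epoch):
    """Linearly interpolate the model rollout length over training epochs.

    schedule = [start_epoch, end_epoch, start_length, end_length]
    """
    start_epoch, end_epoch, start_len, end_len = schedule
    # Fraction of the way through the ramp, clamped to [0, 1] so the
    # length holds at start_length before start_epoch and at end_length after.
    frac = (epoch - start_epoch) / (end_epoch - start_epoch)
    frac = min(max(frac, 0.0), 1.0)
    return int(round(start_len + frac * (end_len - start_len)))
```

With `[20, 100, 1, 5]`, rollouts stay at length 1 until epoch 20, reach length 3 at epoch 60, and hold at 5 from epoch 100 onward.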

If you want to speed up training in terms of wall clock time (but possibly make the runs less sample-efficient), you can set a timeout for model training ([`max_model_t`](examples/config/halfcheetah/0.py#L30), in seconds) or train the model less frequently (every [`model_train_freq`](examples/config/halfcheetah/0.py#L22) steps).
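For instance, a config could trade some sample efficiency for wall-clock time by overriding these two keys (the values here are illustrative, not recommendations from the paper):

```python
# Hypothetical overrides in a config file's params dict:
params['kwargs'].update({
    'model_train_freq': 1000,  # retrain the model every 1000 steps instead of 250
    'max_model_t': 600,        # cap each model-training phase at 10 minutes
})
```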


## Reference
If you find this code useful in an academic setting, please cite:

```
@article{janner2019mbpo,
author = {Michael Janner and Justin Fu and Marvin Zhang and Sergey Levine},
title = {When to Trust Your Model: Model-Based Policy Optimization},
journal = {arXiv preprint arXiv:1906.08253},
year = {2019}
}
```

## Acknowledgments
The underlying soft actor-critic implementation in MBPO comes from [Tuomas Haarnoja](https://scholar.google.com/citations?user=VT7peyEAAAAJ&hl=en) and [Kristian Hartikainen's](https://hartikainen.github.io/) [softlearning](https://github.com/rail-berkeley/softlearning) codebase. The modeling code is a slightly modified version of [Kurtland Chua's](https://kchua.github.io/) [PETS](https://github.com/kchua/handful-of-trials) implementation.


13 changes: 13 additions & 0 deletions environment/gpu-env.yml
@@ -0,0 +1,13 @@
## copy of https://github.com/rail-berkeley/softlearning/blob/master/environment.yml

name: mbpo
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.6.5
  - pip>=18.0
  - conda>=4.5.9
  - patchelf=0.9
  - pip:
    - -r requirements.txt
98 changes: 98 additions & 0 deletions environment/requirements.txt
@@ -0,0 +1,98 @@
tqdm
gpflow
flask
tensorboardX
absl-py==0.6.1
asn1crypto==0.24.0
astor==0.7.1
atomicwrites==1.2.1
attrs==18.2.0
awscli==1.16.67
boto3==1.9.57
botocore==1.12.57
cachetools==3.0.0
cffi==1.11.5
chardet==3.0.4
Click==7.0
cloudpickle==0.6.1
colorama==0.3.9
cryptography==2.3.1
cycler==0.10.0
Cython==0.29.1
dask==1.0.0
decorator==4.3.0
docutils==0.14
dotmap==1.3.8
deepdiff==3.3.0
flatbuffers==1.10
funcsigs==1.0.2
future==0.17.1
gast==0.2.0
gitdb2==2.0.5
GitPython==2.1.11
glfw==1.7.0
google-api-python-client==1.7.5
google-auth==1.6.1
google-auth-httplib2==0.0.3
grpcio==1.16.1
gtimer==1.0.0b5
gym==0.12.0
h5py==2.8.0
httplib2==0.12.0
idna==2.7
imageio==2.4.1
jmespath==0.9.3
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
kiwisolver==1.0.1
lockfile==0.12.2
Markdown==3.0.1
matplotlib==3.0.2
more-itertools==4.3.0
mujoco-py==1.50.1.68
git+https://github.com/vitchyr/multiworld.git@d76b3dae2e8cbca02924f93d6cc0239c552f6408
networkx==2.2
numpy==1.15.4
pandas==0.23.4
Pillow==5.3.0
plotly==1.9.6
pluggy==0.8.0
protobuf==3.6.1
py==1.7.0
pyasn1==0.4.4
pyasn1-modules==0.2.2
pycosat==0.6.3
pycparser==2.19
pygame==1.9.4
pyglet==1.3.2
pyOpenSSL==18.0.0
pyparsing==2.3.0
PySocks==1.6.8
pytest==4.0.1
python-dateutil==2.7.5
pytz==2018.7
PyWavelets==1.0.1
PyYAML==4.2b4
psutil==5.4.8
ray[rllib,debug]==0.6.4
redis==3.0.1
requests==2.20.1
rsa==3.4.2
ruamel-yaml==0.15.46
s3transfer==0.1.13
scikit-image==0.14.2
scikit-learn==0.20.1
scipy==1.1.0
git+https://github.com/hartikainen/serializable.git@76516385a3a716ed4a2a9ad877e2d5cbcf18d4e6
setproctitle==1.1.10
six==1.11.0
smmap2==2.0.5
tensorboard==1.13.1
tensorflow-gpu==1.13.1
tensorflow-estimator==1.13.0
tensorflow-probability==0.6.0
termcolor==1.1.0
toolz==0.9.0
uritemplate==3.0.0
urllib3==1.23
Werkzeug==0.14.1
Empty file added examples/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions examples/config/__init__.py
@@ -0,0 +1,40 @@
params = {
    'type': 'MBPO',
    'universe': 'gym',
    'domain': 'Hopper',
    'task': 'v2',

    'log_dir': '~/ray_mbpo/',
    'exp_name': 'defaults',

    'kwargs': {
        'epoch_length': 1000,
        'train_every_n_steps': 1,
        'n_train_repeat': 2, #20,
        'eval_render_mode': None,
        'eval_n_episodes': 1,
        'eval_deterministic': True,

        'discount': 0.99,
        'tau': 5e-3,
        'reward_scale': 1.0,
        ####
        'model_reset_freq': 1000,
        'model_train_freq': 250, # 250
        # 'retain_model_epochs': 2,
        'model_pool_size': 2e6,
        'rollout_batch': 100e3, # 40e3
        'rollout_length': 1,
        'deterministic': False,
        'num_networks': 7,
        'num_elites': 5,
        'real_ratio': 0.05,
        'entropy_mult': 0.5,
        # 'target_entropy': -1.5,
        'max_model_t': 1e10,
        # 'max_dev': 0.25,
        # 'marker': 'early-stop_10rep_stochastic',
        'rollout_length_params': [20, 150, 1, 1], ## epoch, loss, length
        # 'marker': 'dump',
    }
}
33 changes: 33 additions & 0 deletions examples/config/ant/0.py
@@ -0,0 +1,33 @@
params = {
    'type': 'MBPO',
    'universe': 'gym',
    'domain': 'Ant',
    'task': 'v2',

    'log_dir': '~/ray_mbpo/',
    'exp_name': 'defaults',

    'kwargs': {
        'epoch_length': 1000,
        'train_every_n_steps': 1,
        'n_train_repeat': 20,
        'eval_render_mode': None,
        'eval_n_episodes': 1,
        'eval_deterministic': True,

        'discount': 0.99,
        'tau': 5e-3,
        'reward_scale': 1.0,

        'model_train_freq': 250,
        'model_retain_epochs': 1,
        'rollout_batch_size': 100e3,
        'deterministic': False,
        'num_networks': 7,
        'num_elites': 5,
        'real_ratio': 0.05,
        'target_entropy': -4,
        'max_model_t': None,
        'rollout_schedule': [20, 100, 1, 25],
    }
}