Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
eloialonso committed May 20, 2024
0 parents commit a1311e9
Show file tree
Hide file tree
Showing 42 changed files with 3,806 additions and 0 deletions.
18 changes: 18 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Custom
wandb
.vscode
outputs
results/figures
slurm*
experiments
checkpoints
saved_runs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Jupyter Notebook
.ipynb_checkpoints

21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Eloi Alonso

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
232 changes: 232 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Diffusion for World Modeling: Visual Details Matter in Atari

**TL;DR** We introduce DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model.

>[Install](#installation), then try our [pretrained world models](#try)!
>
>```bash
>python src/play.py --pretrained
>```
<div align='center'>
Autoregressive imagination with DIAMOND on a subset of Atari games
<img alt="DIAMOND's world model in Breakout, Pong, KungFuMaster, Boxing, Asterix" src="assets/main.gif">
</div>
## Quick Links
- [Installation](#installation)
- [Try our playable diffusion world models](#try)
- [Launch a training run](#launch)
- [Configuration](#configuration)
- [Visualization](#visualization)
- [Play mode (default)](#play_mode)
- [Dataset mode (add `-d`)](#dataset_mode)
- [Other options, common to play/dataset modes](#other_options)
- [Run folder structure](#structure)
- [Results](#results)
- [Citation](#citation)
- [Credits](#credits)
<a name="installation"></a>
## [⬆️](#quick-links) Installation
Clone this repository:
```bash
git clone [email protected]:eloialonso/diamond.git
cd diamond
```
We recommend using [miniconda](https://docs.anaconda.com/free/miniconda/miniconda-install/) to create a new environment:
```bash
conda create -n diamond python=3.10
conda activate diamond
```
Install dependencies listed in [requirements.txt](requirements.txt):
```bash
pip install -r requirements.txt
```
**Warning**: Atari ROMs will be downloaded with the dependencies, which means that you acknowledge that you have the license to use them.
<a name="try"></a>
## [⬆️](#quick-links) Try our playable diffusion world models
```bash
python src/play.py --pretrained
```
Then select a game, and world model and policy pretrained on Atari 100k will be downloaded from our [repository on Hugging Face Hub 🤗](https://huggingface.co/eloialonso/diamond) and cached on your machine.
First things you might want to try:
- Press `m` to take control (the policy is playing by default).
- Press `` to increase the imagination horizon (default is 15, which is frustrating when playing yourself).
To adjust the sampling parameters (number of denoising steps, stochasticity, order, etc) of the trained diffusion world model, for instance to trade off sampling speed and quality, edit the section `world_model_env.diffusion_sampler` in the file `config/trainer.yaml`.
See [Visualization](#visualization) for more details about the available commands and options.
<a name="launch"></a>
## [⬆️](#quick-links) Launch a training run
To train with the hyperparameters used in the paper, launch:
```bash
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0
```
This creates a new folder for your run, located in `outputs/YYYY-MM-DD/hh-mm-ss/`.
To resume a run that crashed, navigate to the fun folder and launch:
```bash
./scripts/resume.sh
```
<a name="configuration"></a>
## [⬆️](#quick-links) Configuration
We use [Hydra](https://github.com/facebookresearch/hydra) for configuration management.
All configuration files are located in the `config` folder:
- `config/trainer.yaml`: main configuration file.
- `config/agent/default.yaml`: architecture hyperparameters.
- `config/env/atari.yaml`: environment hyperparameters.
You can turn on logging to [weights & biases](https://wandb.ai) in the `wandb` section of `config/trainer.yaml`.
Set `training.model_free=true` in the file `config/trainer.yaml` to "unplug" the world model and perform standard model-free reinforcement learning.
<a name="visualization"></a>
## [⬆️](#quick-links) Visualization
<a name="play_mode"></a>
### [⬆️](#quick-links) Play mode (default)
To visualize your last checkpoint, launch **from the run folder**:
```bash
python src/play.py
```
By default, you visualize the policy playing in the world model. To play yourself, or switch to the real environment, use the controls described below.
```txt
Controls (play mode)
(Game-specific commands will be printed on start up)
: reset environment
m : switch controller (policy/human)
↑/↓ : imagination horizon (+1/-1)
←/→ : next environment [world model ←→ real env (test) ←→ real env (train)]
. : pause/unpause
e : step-by-step (when paused)
```
Add `-r` to toggle "recording mode" (works only in play mode). Every completed episode will be saved in `dataset/rec_<env_name>_<controller>`. For instance:
- `dataset/rec_wm_π`: Policy playing in world model.
- `dataset/rec_wm_H`: Human playing in world model.
- `dataset/rec_test_H`: Human playing in test real environment.
You can then use the "dataset mode" described in the next section to replay the stored episodes.
<a name="dataset_mode"></a>
### [⬆️](#quick-links) Dataset mode (add `-d`)
**In the run folder**, to visualize the datasets contained in the `dataset` subfolder, add `-d` to switch to "dataset mode":
```bash
python src/play.py -d
```
You can use the controls described below to navigate the datasets and episodes.
```txt
Controls (dataset mode)
m : next dataset (if multiple datasets, like recordings, etc)
↑/↓ : next/previous episode
←/→ : next/previous timestep in episodes
PgUp: +10 timesteps
PgDn: -10 timesteps
: back to first timestep
```
<a name="other_options"></a>
### [⬆️](#quick-links) Other options, common to play/dataset modes
```txt
--fps FPS Target frame rate (default 15).
--size SIZE Window size (default 800).
--no-header Remove header.
```
<a name="structure"></a>
## [⬆️](#quick-links) Run folder structure
Each new run is located at `outputs/YYYY-MM-DD/hh-mm-ss/`. This folder is structured as follows:
```txt
outputs/YYYY-MM-DD/hh-mm-ss/
└─── checkpoints
│ │ state.pt # full training state
│ │
│ └─── agent_versions
│ │ ...
│ │ agent_epoch_00999.pt
│ │ agent_epoch_01000.pt # agent weights only
└─── config
│ | trainer.yaml
|
└─── dataset
│ │
│ └─── train
│ | │ info.pt
│ | │ ...
| |
│ └─── test
│ │ info.pt
│ │ ...
└─── scripts
│ │ resume.sh
| | ...
|
└─── src
| | main.py
| | ...
|
└─── wandb
| ...
```
<a name="results"></a>
## [⬆️](#quick-links) Results
The file `./results/data/DIAMOND.json` contains the results for each game and seed used in the paper.
<a name="citation"></a>
## [⬆️](#quick-links) Citation
```text
TODO
```
<a name="credits"></a>
## [⬆️](#quick-links) Credits
- [https://github.com/crowsonkb/k-diffusion/](https://github.com/crowsonkb/k-diffusion/)
- [https://github.com/huggingface/huggingface_hub](https://github.com/huggingface/huggingface_hub)
- [https://github.com/google-research/rliable](https://github.com/google-research/rliable)
- [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch)
33 changes: 33 additions & 0 deletions config/agent/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
_target_: agent.AgentConfig

denoiser:
_target_: models.diffusion.DenoiserConfig
sigma_data: 0.5
sigma_offset_noise: 0.3
inner_model:
_target_: models.diffusion.InnerModelConfig
img_channels: 3
num_steps_conditioning: 4
cond_channels: 256
depths: [2,2,2,2]
channels: [64,64,64,64]
attn_depths: [0,0,0,0]

rew_end_model:
_target_: models.rew_end_model.RewEndModelConfig
lstm_dim: 512
img_channels: ${agent.denoiser.inner_model.img_channels}
img_size: ${env.train.size}
cond_channels: 128
depths: [2,2,2,2]
channels: [32,32,32,32]
attn_depths: [0,0,0,0]

actor_critic:
_target_: models.actor_critic.ActorCriticConfig
lstm_dim: 512
img_channels: ${agent.denoiser.inner_model.img_channels}
img_size: ${env.train.size}
channels: [32,32,64,64]
down: [1,1,1,1]

13 changes: 13 additions & 0 deletions config/env/atari.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
train:
id: BreakoutNoFrameskip-v4
done_on_life_loss: True
size: 64
max_episode_steps: null

test:
id: ${..train.id}
done_on_life_loss: False
size: ${..train.size}
max_episode_steps: null

keymap: atari/${.train.id}
Loading

0 comments on commit a1311e9

Please sign in to comment.