Merge pull request nicklashansen#10 from nicklashansen/experimental
[Feature] Faster replay buffer + support pixel observations
nicklashansen authored Dec 28, 2023
2 parents bfb1971 + 6cb779a commit 1f6c777
Showing 13 changed files with 494 additions and 106 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.

<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>

This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.

----

@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t <user>/tdmpc2:0.1.0
```

If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
If you prefer to install dependencies manually, start by installing them via `conda` by running one of the following commands:

```
conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml
```

The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.

If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running

```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall
| maniskill | pick-cube
| maniskill | pick-ycb
| myosuite | myo-hand-key-turn
| myosuite | myo-hand-key-turn-hard
| myosuite | myo-key-turn
| myosuite | myo-key-turn-hard

which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation are specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.

**As of Dec 27, 2023, the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use the argument `obs=rgb` if you wish to train visual policies.


## Example usage

@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb
```

We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
1 change: 1 addition & 0 deletions docker/environment.yaml
@@ -26,6 +26,7 @@ dependencies:
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
39 changes: 39 additions & 0 deletions docker/environment_minimal.yaml
@@ -0,0 +1,39 @@
name: tdmpc2
channels:
  - pytorch-nightly
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.9.0
  - pytorch
  - torchvision
  - cudatoolkit=11.7
  - glew
  - glib
  - pip==21
  - pip:
    - absl-py
    - glfw
    - kornia
    - termcolor
    - gym==0.21.0
    - moviepy
    - ffmpeg
    - imageio
    - imageio-ffmpeg
    - omegaconf
    - hydra-core
    - hydra-submitit-launcher
    - submitit
    - pandas
    - patchelf
    - protobuf
    - tqdm
    - setuptools==65.5.0
    - "cython<3"
    - dm-control
    - pillow
    - tensordict-nightly
    - torchrl-nightly
    - wandb
113 changes: 47 additions & 66 deletions tdmpc2/common/buffer.py
@@ -1,43 +1,27 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose

from common.logger import make_dir


class DataPrepTransform(Transform):
    """
    Preprocesses data for TD-MPC2 training.
    Replay data is expected to be a TensorDict with the following keys:
        obs: observations
        action: actions
        reward: rewards
        task: task IDs (optional)
    A TensorDict with T time steps has T+1 observations and T actions and rewards.
    The first actions and rewards in each TensorDict are dummies and should be ignored.
    """

    def __init__(self):
        super().__init__([])

    def forward(self, td):
        td = td.permute(1,0)
        return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
from common.samplers import SliceSampler


class Buffer():
    """
    Create a replay buffer for TD-MPC2 training.
    Replay buffer for TD-MPC2 training. Based on torchrl.
    Uses CUDA memory if available, and CPU memory otherwise.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self._device = torch.device('cuda')
        self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
        self._capacity = min(cfg.buffer_size, cfg.steps)
        self._sampler = SliceSampler(
            num_slices=self.cfg.batch_size,
            end_key=None,
            traj_key='episode',
            truncated_key=None,
        )
        self._batch_size = cfg.batch_size * (cfg.horizon+1)
        self._num_eps = 0

    @property
@@ -53,63 +37,60 @@ def num_eps(self):
    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.
        Uses the RandomSampler to sample trajectories,
        and the RandomCropTensorDict transform to crop trajectories to the desired length.
        DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=RandomSampler(),
            sampler=self._sampler,
            pin_memory=True,
            prefetch=1,
            transform=Compose(
                RandomCropTensorDict(self.cfg.horizon+1, -1),
                DataPrepTransform(),
            ),
            batch_size=self.cfg.batch_size,
            batch_size=self._batch_size,
        )

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        print(f'Buffer capacity: {self._capacity:,}')
        mem_free, _ = torch.cuda.mem_get_info()
        bytes_per_ep = sum([
        bytes_per_step = sum([
            (v.numel()*v.element_size() if not isinstance(v, TensorDict) \
            else sum([x.numel()*x.element_size() for x in v.values()])) \
            for k,v in tds.items()
        ])
        print(f'Bytes per episode: {bytes_per_ep:,}')
        total_bytes = bytes_per_ep*self._capacity
            for v in tds.values()
        ]) / len(tds)
        total_bytes = bytes_per_step*self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: decide whether to use CUDA or CPU memory
        if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
            print('Using CPU memory for storage.')
            return self._reserve_buffer(
                LazyTensorStorage(self._capacity, device=torch.device('cpu'))
            )
        else: # Sufficient CUDA memory
            print('Using CUDA memory for storage.')
            return self._reserve_buffer(
                LazyTensorStorage(self._capacity, device=torch.device('cuda'))
            )
        storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
        print(f'Using {storage_device.upper()} memory for storage.')
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=torch.device(storage_device))
        )

    def _to_device(self, *args, device=None):
        if device is None:
            device = self._device
        return (arg.to(device, non_blocking=True) \
            if arg is not None else None for arg in args)

    def _prepare_batch(self, td):
        """
        Prepare a sampled batch for training (post-processing).
        Expects `td` to be a TensorDict with batch size TxB.
        """
        obs = td['obs']
        action = td['action'][1:]
        reward = td['reward'][1:].unsqueeze(-1)
        task = td['task'][0] if 'task' in td.keys() else None
        return self._to_device(obs, action, reward, task)

    def add(self, tds):
        """Add an episode to the buffer. All episodes are expected to have the same length."""
    def add(self, td):
        """Add an episode to the buffer."""
        td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
        if self._num_eps == 0:
            self._buffer = self._init(tds)
        self._buffer.add(tds)
            self._buffer = self._init(td)
        self._buffer.extend(td)
        self._num_eps += 1
        return self._num_eps

    def sample(self):
        """Sample a batch of sub-trajectories from the buffer."""
        obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
        return obs.to(self._device, non_blocking=True), \
            action.to(self._device, non_blocking=True), \
            reward.to(self._device, non_blocking=True), \
            task.to(self._device, non_blocking=True) if task is not None else None

    def save(self):
        """Save the buffer to disk. Useful for storing offline datasets."""
        td = self._buffer._storage._storage.cpu()
        fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
        torch.save(td, fp)
        """Sample a batch of subsequences from the buffer."""
        td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
        return self._prepare_batch(td)
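To make the new sampling path easier to follow, here is a minimal sketch (not part of this commit) of the reshape that the reworked `sample()` performs. It assumes the underlying torchrl `ReplayBuffer` returns `batch_size` contiguous slices of `horizon+1` steps flattened along one leading dimension, as requested via `self._batch_size = cfg.batch_size * (cfg.horizon+1)`; the sizes and keys below are illustrative, and the `SliceSampler` itself comes from `common.samplers`, which is not shown in this excerpt.

```python
import torch
from tensordict.tensordict import TensorDict

# Illustrative sizes only (not taken from config.yaml).
horizon, batch_size, obs_dim, act_dim = 3, 5, 24, 6
num_steps = batch_size * (horizon + 1)

# Stand-in for what the underlying ReplayBuffer might return: batch_size
# contiguous slices of horizon+1 steps, stacked along a single leading dim.
flat = TensorDict({
    'obs': torch.randn(num_steps, obs_dim),
    'action': torch.randn(num_steps, act_dim),
    'reward': torch.randn(num_steps),
    'episode': torch.zeros(num_steps, dtype=torch.int64),
}, batch_size=[num_steps])

# The reshape from Buffer.sample(): view(-1, horizon+1) groups consecutive
# steps back into slices, and permute(1, 0) yields the (T, B) layout that
# _prepare_batch() expects.
td = flat.view(-1, horizon + 1).permute(1, 0)
assert td['obs'].shape == (horizon + 1, batch_size, obs_dim)

# _prepare_batch() then drops the dummy first action/reward of each slice.
obs = td['obs']                          # (horizon+1, batch_size, obs_dim)
action = td['action'][1:]                # (horizon,   batch_size, act_dim)
reward = td['reward'][1:].unsqueeze(-1)  # (horizon,   batch_size, 1)
```

Sampling flat slices and reshaping once per batch replaces the per-sample `RandomCropTensorDict` transform of the old buffer, which is presumably where the speedup advertised in this commit comes from.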