This is an implementation of TD-MPC and TD-MPC2 using JAX/Flax. While the official implementation of TD-MPC2 supports learning multiple tasks, this implementation focuses on simplicity and only supports learning a single task.
Inspired by CleanRL, this implementation keeps each algorithm in a single file instead of spreading the processing across multiple files, so the algorithmic flow is easy to follow.
Here are the test results from training in DM-Control environments.
- Learning results by TD-MPC
- Learning results by TD-MPC2
Follow the steps below to set up the execution environment.
# Build the image
docker build -t simple_tdmpc .
# Start the container
docker run \
--gpus all \
-it \
--rm \
-w $HOME/work \
-v $(pwd):$HOME/work \
simple_tdmpc:latest bash
Install dependencies using Poetry.
poetry install
poetry run python src/tdmpc.py
poetry run python src/tdmpc2.py
- Adding the --capture_video option saves videos of the training process in the /videos folder.
  poetry run python src/tdmpc2.py --capture_video
- The --track option enables recording of experiment logs through wandb.
  poetry run wandb login
  poetry run python src/tdmpc2.py --track --capture_video
  If a 'transport failed error' occurs, execute the git config command displayed in the error message.
- The --task option allows switching the task for training.
  poetry run python src/tdmpc2.py --task 'dm_control/quadruped-run-v0' --total_timesteps 1000000
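To illustrate how the command-line options above fit together, here is a minimal sketch of a parser that accepts the same four flags. It assumes Python's standard argparse module. The actual scripts may use a different parser, and the function name build_parser and the default values are hypothetical placeholders, not taken from this repository.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in this README."""
    parser = argparse.ArgumentParser(description="Illustrative CLI sketch")
    # Boolean switches: False when absent, True when present.
    parser.add_argument("--capture_video", action="store_true")
    parser.add_argument("--track", action="store_true")
    # The defaults below are placeholders, not the repository's real defaults.
    parser.add_argument("--task", type=str, default="dm_control/quadruped-run-v0")
    parser.add_argument("--total_timesteps", type=int, default=1_000_000)
    return parser


# Parse the same arguments as the last example command, plus --track.
args = build_parser().parse_args(
    ["--task", "dm_control/quadruped-run-v0", "--total_timesteps", "1000000", "--track"]
)
print(args.task, args.total_timesteps, args.track, args.capture_video)
```

Because the boolean flags use store_true, they can be combined freely, as in the --track --capture_video example above.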