# Implementation of Diffusion Transformer (DiT) in JAX
First, clone the repo. Then you can install the conda environment with `conda env create -f environment.yml`. You will need to compile the TFDS datasets for `imagenet2012` or `celebahq`.
To run the training code, use `train_diffusion.py`. Evaluate FID on a trained model with `eval_fid.py`.
Here are some useful commands to replicate results. These use the DiT-B settings, with a patch size of 2 for latent diffusion and 8 for pixels.
```bash
# Diffusion on Imagenet256 (w/ Stable Diffusion VAE)
python train_diffusion.py --dataset_name imagenet256 --wandb.name DiT-B --model.depth 12 --model.hidden_size 768 --model.patch_size 2 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512

# Diffusion on CelebaHQ256 (w/ Stable Diffusion VAE)
python train_diffusion.py --dataset_name celebahq256 --wandb.name DiT-B-CelebA --model.depth 12 --model.hidden_size 768 --model.patch_size 2 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512 --save_dir gs://rll/checkpoints/diffusion/dit-celeba256

# Diffusion on CelebaHQ256 (Pixels)
python train_diffusion.py --dataset_name celebahq256 --wandb.name DiT-B-CelebAPixel --model.depth 12 --model.hidden_size 768 --model.patch_size 8 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512 --use_stable_vae 0 --save_dir gs://rll/checkpoints/diffusion/dit-celebahq256pixel
```
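As a sanity check on the patch-size settings above, here is a minimal numpy sketch of DiT-style patchification (a simplified illustration, not the repo's actual implementation). It shows how the patch size determines the transformer's sequence length: the Stable Diffusion VAE maps 256x256x3 images to 32x32x4 latents, so `patch_size=2` on latents and `patch_size=8` on pixels both keep the token count manageable.

```python
import numpy as np

def patchify(x, patch_size):
    """Split an (H, W, C) array into (num_patches, patch_size**2 * C) tokens,
    as a DiT-style patch embedding does before the linear projection.
    (Sketch only; the repo's implementation may differ.)"""
    h, w, c = x.shape
    assert h % patch_size == 0 and w % patch_size == 0
    x = x.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, patch_size * patch_size * c)

# Latent diffusion: 32x32x4 latents with patch_size=2 -> (32/2)**2 = 256 tokens.
latent = np.zeros((32, 32, 4))
print(patchify(latent, 2).shape)   # (256, 16)

# Pixel-space: 256x256x3 images with patch_size=8 -> (256/8)**2 = 1024 tokens.
pixels = np.zeros((256, 256, 3))
print(patchify(pixels, 8).shape)   # (1024, 192)
```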
| Model | FID 50K (ours) | FID 50K (reference paper) |
| --- | --- | --- |
| DiT-B Imagenet256, no CFG | 70.5 | 43.47 (DiT) |
| DiT-XL Imagenet256, no CFG | N/A | 9.62 (DiT) |
| DiT-B Imagenet256, CFG=4 | 17.7 | N/A |
| DiT-B CelebAHQ256 | 28.35 | 5.11 (LDM) |
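CFG=4 in the table refers to classifier-free guidance with weight 4. For reference, a sketch of the standard guidance combination (the repo's exact scale convention may differ): the model is evaluated with and without conditioning, and the conditional direction is amplified by the guidance weight.

```python
import numpy as np

def cfg_combine(eps_uncond, eps_cond, w):
    """Standard classifier-free guidance: move w times along the direction
    from the unconditional to the conditional prediction.
    (Sketch only; exact convention may differ per implementation.)"""
    return eps_uncond + w * (eps_cond - eps_uncond)

eps_u = np.array([0.0, 1.0])   # toy unconditional prediction
eps_c = np.array([1.0, 1.0])   # toy conditional prediction
print(cfg_combine(eps_u, eps_c, 4.0))  # [4. 1.]
```

Note that w=1 recovers the plain conditional prediction, and larger w trades diversity for fidelity, which is why FID improves markedly with guidance in the table above.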