kvfrans/jax-diffusion-transformer

jax-diffusion-transformer

Implementation of Diffusion Transformer (DiT) in JAX

Installation

First, clone the repo. Then install the conda environment with `conda env create -f environment.yml`. You will also need to build the TFDS datasets for `imagenet2012` or `celebahq`.

Usage

To train a model, run `train_diffusion.py`. To evaluate FID on a trained model, run `eval_fid.py`.
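
The README does not describe how `eval_fid.py` computes its score. For reference, here is a minimal sketch assuming the standard FID definition: fit a Gaussian to Inception-v3 pool features of real images and of generated samples, then compute the Fréchet distance ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The helpers `feature_stats` and `frechet_distance` are hypothetical names, not functions from this repo, and the Inception feature extraction step is omitted.

```python
import numpy as np
from scipy import linalg


def feature_stats(features):
    """Mean and covariance of an [N, D] array of (assumed precomputed) Inception features."""
    features = np.asarray(features, dtype=np.float64)
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Principal matrix square root of the covariance product.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Near-singular covariances: add a small ridge to both and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # sqrtm can return tiny imaginary parts from numerical error; keep the real part.
    covmean = np.real(covmean)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))


# Usage with random stand-in features (hypothetical data, not Inception activations).
rng = np.random.default_rng(0)
mu, sigma = feature_stats(rng.normal(size=(1000, 8)))
print(frechet_distance(mu, sigma, mu, sigma))  # close to 0 for identical statistics
```

Identical statistics give a distance near zero, and with equal covariances the score reduces to the squared distance between the means.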

Here are some useful commands to replicate the results. These use the DiT-B settings, with a patch size of 2 for latent diffusion and 8 for pixel-space diffusion.

```bash
# Diffusion on Imagenet256 (w/ Stable Diffusion VAE)
python train_diffusion.py --dataset_name imagenet256 --wandb.name DiT-B --model.depth 12 --model.hidden_size 768 --model.patch_size 2 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512

# Diffusion on CelebaHQ256 (w/ Stable Diffusion VAE)
python train_diffusion.py --dataset_name celebahq256 --wandb.name DiT-B-CelebA --model.depth 12 --model.hidden_size 768 --model.patch_size 2 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512 --save_dir gs://rll/checkpoints/diffusion/dit-celeba256

# Diffusion on CelebaHQ256 (Pixels)
python train_diffusion.py --dataset_name celebahq256 --wandb.name DiT-B-CelebAPixel --model.depth 12 --model.hidden_size 768 --model.patch_size 8 --model.num_heads 16 --model.mlp_ratio 4 --batch_size 512 --use_stable_vae 0 --save_dir gs://rll/checkpoints/diffusion/dit-celebahq256pixel
```
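
The patch size sets the transformer's sequence length. Below is a minimal JAX sketch of DiT-style patchification, assuming NHWC inputs and assuming the Stable Diffusion VAE maps a 256x256x3 image to a 32x32x4 latent (its usual 8x spatial downsampling with 4 channels). Under those assumptions, `--model.patch_size 2` turns the latent into 256 tokens, while raw pixels with `--model.patch_size 8` become 1024 tokens. `patchify` is an illustrative helper, not this repo's code; DiT implementations often use a strided convolution here instead, which is equivalent to this reshape followed by a linear projection to `hidden_size`.

```python
import jax.numpy as jnp


def patchify(x, patch_size):
    """Split [B, H, W, C] inputs into [B, N, patch_size * patch_size * C] tokens."""
    b, h, w, c = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be divisible by patch_size"
    gh, gw = h // patch_size, w // patch_size
    # [B, gh, p, gw, p, C] -> [B, gh, gw, p, p, C] so each token is one contiguous patch.
    x = x.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


# Latent diffusion: assumed 32x32x4 VAE latent of a 256x256 image, patch size 2.
latent_tokens = patchify(jnp.zeros((1, 32, 32, 4)), patch_size=2)
# Pixel diffusion: the raw 256x256x3 image, patch size 8.
pixel_tokens = patchify(jnp.zeros((1, 256, 256, 3)), patch_size=8)
print(latent_tokens.shape)  # (1, 256, 16)
print(pixel_tokens.shape)   # (1, 1024, 192)
```

Even with the larger patch, the pixel model attends over 4x as many tokens as the latent model, so its attention-score computation costs roughly 16x as much per layer. With `--model.hidden_size 768 --model.num_heads 16`, each attention head works in 768 / 16 = 48 dimensions.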

| Model | FID 50K (ours) | FID 50K (reference paper) |
| --- | --- | --- |
| DiT-B Imagenet256, no CFG | 70.5 | 43.47 (DiT) |
| DiT-XL Imagenet256, no CFG | N/A | 9.62 (DiT) |
| DiT-B Imagenet256, CFG=4 | 17.7 | N/A |
| DiT-B CelebAHQ256 | 28.35 | 5.11 (LDM) |
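
The CFG=4 row uses classifier-free guidance at sampling time. The README does not spell out the exact formulation, so the following is a minimal JAX sketch assuming the convention from the DiT paper: guided output = out(x, ∅) + s·(out(x, c) − out(x, ∅)), where s = 1 reproduces the plain conditional model. The names `cfg_combine` and `guided_prediction`, the `model_fn(x, t, y)` signature, and the `null_label` index used for the dropped class are all hypothetical.

```python
import jax.numpy as jnp


def cfg_combine(pred_cond, pred_uncond, cfg_scale):
    """Extrapolate from the unconditional output toward the conditional one.
    cfg_scale=1 gives the conditional output; cfg_scale=0 gives the unconditional one."""
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)


def guided_prediction(model_fn, x, t, labels, cfg_scale, null_label):
    """Run the conditional and unconditional passes as one doubled batch (hypothetical API)."""
    x2 = jnp.concatenate([x, x], axis=0)
    t2 = jnp.concatenate([t, t], axis=0)
    # Second half of the batch uses the assumed "no class" label index.
    y2 = jnp.concatenate([labels, jnp.full_like(labels, null_label)], axis=0)
    out = model_fn(x2, t2, y2)
    pred_cond, pred_uncond = jnp.split(out, 2, axis=0)
    return cfg_combine(pred_cond, pred_uncond, cfg_scale)


# Toy stand-in model (hypothetical): outputs 1 when conditioned on a class, 0 otherwise.
def toy_model(x, t, y):
    return x + (y != 1000).astype(x.dtype)[:, None]


x = jnp.zeros((2, 3))
t = jnp.zeros((2,))
labels = jnp.array([3, 7])
print(guided_prediction(toy_model, x, t, labels, cfg_scale=4.0, null_label=1000))  # every entry is 4.0
```

Batching both passes together keeps guided sampling to one model call per step, at the cost of doubling the effective batch size.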

Examples

DiT-B Imagenet, CFG=4

DiT-B CelebAHQ256
