Remix-DiT: Mixing Diffusion Transformers for Multi-Expert Denoising
Gongfan Fang, Xinyin Ma, Xinchao Wang
National University of Singapore
📄 [Arxiv]
The goal of Remix-DiT is to craft N diffusion experts for different denoising timesteps, without the expensive training of N independent models. Each expert handles only a subset of the denoising timesteps, which extends the total model capacity without heavily increasing the active model size at any single timestep. To achieve this, Remix-DiT maintains K basis models (where K < N) and uses learnable mixing coefficients to adaptively craft the expert models from them.
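As a minimal sketch of this idea (illustrative only, not the code in this repository): each layer keeps K basis weights, each expert owns a learnable K-dimensional coefficient vector, and each denoising timestep is routed to one expert, so only one mixed weight is active per step. The class `RemixLinear`, the helper `expert_for_timestep`, and the uniform timestep binning below are assumptions made for illustration.

```python
# Illustrative sketch of the Remix-DiT idea (not the repository implementation):
# K basis weight tensors are combined by learnable, per-expert mixing
# coefficients, and each denoising timestep is routed to one of N experts.

import torch
import torch.nn as nn
import torch.nn.functional as F


class RemixLinear(nn.Module):
    """A linear layer whose weight is a learned mixture of K basis weights."""

    def __init__(self, in_features, out_features, n_basis=4, n_experts=20):
        super().__init__()
        self.n_basis, self.n_experts = n_basis, n_experts
        # K basis weight matrices and biases (shared across all experts).
        self.weight = nn.Parameter(torch.randn(n_basis, out_features, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_basis, out_features))
        # Learnable mixing logits: one K-dim coefficient vector per expert.
        self.mix_logits = nn.Parameter(torch.zeros(n_experts, n_basis))

    def forward(self, x, expert_idx: int):
        # Softmax keeps the mixture a convex combination of the basis weights.
        coef = F.softmax(self.mix_logits[expert_idx], dim=-1)   # (K,)
        w = torch.einsum('k,koi->oi', coef, self.weight)        # (out, in)
        b = torch.einsum('k,ko->o', coef, self.bias)            # (out,)
        return F.linear(x, w, b)


def expert_for_timestep(t: int, n_experts: int = 20, num_train_steps: int = 1000) -> int:
    """Route a denoising timestep to one of N experts by uniform binning (assumed)."""
    return min(t * n_experts // num_train_steps, n_experts - 1)


if __name__ == "__main__":
    layer = RemixLinear(384, 384)
    x = torch.randn(8, 384)
    t = 730                               # an example diffusion timestep
    y = layer(x, expert_for_timestep(t))  # only one mixed weight is active per step
    print(y.shape)                        # torch.Size([8, 384])
```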
pip install -r requirements.txt
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path data/imagenet/train --features-path data/imagenet_encoded
mkdir -p pretrained && cd pretrained
wget <URL-to-DiT_S_2_2M.pt>  # pretrained DiT-S/2 checkpoint used by --load-weight below
torchrun --nnodes=1 --nproc_per_node=8 --master_port=22238 train_fast.py --model RemixDiT-S/2 --load-weight pretrained/DiT_S_2_2M.pt --data-path data/imagenet_encoded --epochs 20 --prefix RemixDiT-S-4-20-100K --ckpt-every 50000 --n-basis 4 --n-experts 20
torchrun --nnodes=1 --nproc_per_node=8 --master_port=22238 train_fast.py --model RemixDiT-S/2 --load-weight pretrained/DiT_S_2_2M.pt --data-path data/imagenet_encoded --epochs 20 --prefix RemixDiT-S-4-8-100K --ckpt-every 50000 --n-basis 4 --n-experts 8
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --model RemixDiT-S/2 --ckpt outputs/RemixDiT-S-4-20-100K/checkpoints/0100000.pt
Please refer to https://github.com/openai/guided-diffusion/tree/main/evaluations for the reference batch VIRTUAL_imagenet256_labeled.npz.
python evaluator.py data/VIRTUAL_imagenet256_labeled.npz PATH_TO_YOUR.npz
This project was built on DiT and ADM.