Remix-DiT: Mixing Diffusion Transformers for Multi-Expert Denoising
Gongfan Fang, Xinyin Ma, Xinchao Wang
National University of Singapore
📄 [Arxiv]
The goal of Remix-DiT is to craft N diffusion experts for different denoising timesteps without the cost of training N independent models. Each expert handles only a subset of the denoising tasks (timesteps), so together the experts extend the total model capacity without substantially increasing the model size that is active at any single timestep. To achieve this, Remix-DiT employs K basis models (where K < N) and uses learnable mixing coefficients to adaptively craft the expert models from them.
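To make the mixing idea concrete, the sketch below builds one expert's weights as a weighted sum of K basis weight tensors. It is a toy illustration, not the repository's implementation. It assumes that each expert owns a row of learnable logits and that those logits are softmax-normalized into convex coefficients. That parameterization, together with the helper names `softmax` and `mix_expert_weights`, is hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax over the basis dimension
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mix_expert_weights(basis_weights, mix_logits, expert_idx):
    """Craft one expert's weight tensor from K basis weight tensors.

    basis_weights: array of shape (K, *param_shape), one slice per basis model
    mix_logits:    array of shape (N, K), learnable logits (hypothetical
                   softmax parameterization of the mixing coefficients)
    expert_idx:    which of the N experts to materialize
    """
    coeffs = softmax(mix_logits[expert_idx])  # shape (K,), sums to 1
    # Contract the basis axis: sum_k coeffs[k] * basis_weights[k]
    return np.tensordot(coeffs, basis_weights, axes=1)


K, N = 4, 20
rng = np.random.default_rng(0)
basis = rng.standard_normal((K, 8, 8))  # toy stand-in for one linear layer
logits = rng.standard_normal((N, K))
w5 = mix_expert_weights(basis, logits, expert_idx=5)
print(w5.shape)  # (8, 8)
```

Only the K basis tensors and the N x K coefficient table are stored, while each materialized expert has the size of a single basis model. This matches the README's point that capacity grows without enlarging the model active at each timestep.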
pip install -r requirements.txt
torchrun --nnodes=1 --nproc_per_node=1 --model DiT-XL/2 --data-path data/imagenet/train --features-path data/imagenet_encoded
mkdir -p pretrained && cd pretrained
torchrun --nnodes=1 --nproc_per_node=8 --master_port=22238 --model RemixDiT-S/2 --load-weight pretrained/ --data-path data/imagenet_encoded --epochs 20 --prefix RemixDiT-S-4-20-100K --ckpt-every 50000 --n-basis 4 --n-experts 20
torchrun --nnodes=1 --nproc_per_node=8 --master_port=22238 --model RemixDiT-S/2 --load-weight pretrained/ --data-path data/imagenet_encoded --epochs 20 --prefix RemixDiT-S-4-8-100K --ckpt-every 50000 --n-basis 4 --n-experts 8
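The two training runs differ only in how many experts share the 4 basis models. As a rough illustration of what `--n-experts` controls, the sketch below assigns each diffusion timestep to an expert. It assumes the 1000-step schedule inherited from DiT and uniform, contiguous timestep intervals. The actual Remix-DiT partition may differ, and the function names are hypothetical.

```python
def expert_for_timestep(t, num_timesteps=1000, n_experts=20):
    """Map a timestep t in [0, num_timesteps) to an expert index.

    Assumes uniform, contiguous intervals (a hypothetical partition).
    """
    if not 0 <= t < num_timesteps:
        raise ValueError(f"timestep {t} out of range [0, {num_timesteps})")
    return t * n_experts // num_timesteps


def interval_sizes(num_timesteps=1000, n_experts=20):
    # Count how many timesteps each expert is responsible for
    counts = [0] * n_experts
    for t in range(num_timesteps):
        counts[expert_for_timestep(t, num_timesteps, n_experts)] += 1
    return counts


print(interval_sizes(1000, 20)[:3])  # [50, 50, 50]
print(interval_sizes(1000, 8)[:3])   # [125, 125, 125]
```

Under these assumptions, `--n-experts 20` gives each expert 50 timesteps and `--n-experts 8` gives each expert 125. A larger expert count means narrower, more specialized intervals for the same number of basis models.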
torchrun --nnodes=1 --nproc_per_node=8 --model RemixDiT-S/2 --ckpt outputs/RemixDiT-S-4-20-100K/checkpoints/
Please refer to for the VIRTUAL_imagenet256_labeled.npz
python data/VIRTUAL_imagenet256_labeled.npz PATH_TO_YOUR.npz
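PATH_TO_YOUR.npz should contain the generated samples. The sketch below writes such a file with NumPy. It assumes the ADM-style convention of uint8 images in (N, H, W, 3) layout stored under NumPy's default key `arr_0`, and the helper name `save_sample_batch` is hypothetical. Check the evaluator you use for its exact expected format.

```python
import numpy as np


def save_sample_batch(images, path):
    """Pack generated images into an .npz file for evaluation.

    images: uint8 array of shape (num_samples, H, W, 3) with values in [0, 255].
    Assumption: the evaluator reads the batch from the key 'arr_0', which is
    the name np.savez gives to the first positional array.
    """
    images = np.asarray(images)
    if images.dtype != np.uint8 or images.ndim != 4 or images.shape[-1] != 3:
        raise ValueError("expected a uint8 array of shape (N, H, W, 3)")
    np.savez(path, images)  # stored under the key 'arr_0'
```

Converting samples to uint8 before saving keeps the file compact and matches the 0 to 255 pixel range that image-based metrics usually expect.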
This project was built on DiT and ADM.