VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction
NeurIPS 2024 Oral
- 2024-12: We released our text-to-image research based on VAR; please check Infinity.
- 2024-09: VAR was accepted as a NeurIPS 2024 Oral Presentation.
- 2024-04: Visual AutoRegressive modeling is released.
We provide a demo website for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!
We also provide demo_sample.ipynb for you to see more technical details about VAR.
Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
For a deep dive into our analyses, discussions, and evaluations, check out our paper.
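To make the contrast between "next-scale" and "next-token" prediction concrete, the following is a minimal sketch that counts autoregressive steps under each scheme. The scale schedule used here is an illustrative assumption (it is not stated in this README), and the helper names are hypothetical.

```python
# Illustrative sketch only: the schedule below is an assumed coarse-to-fine
# list of token-map side lengths, not a value taken from this README.

def tokens_per_scale(side_lengths):
    """Number of tokens in each square token map of the schedule."""
    return [s * s for s in side_lengths]


def next_scale_steps(side_lengths):
    """Next-scale prediction emits one whole token map per step."""
    return len(side_lengths)


def raster_scan_steps(final_side):
    """Raster-scan next-token prediction emits one token per step."""
    return final_side * final_side


schedule = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # hypothetical schedule
print("tokens per scale:", tokens_per_scale(schedule))
print("next-scale steps:", next_scale_steps(schedule))
print("raster-scan steps:", raster_scan_steps(schedule[-1]))
```

Under this assumed schedule, the coarse-to-fine scheme needs one step per scale, while a raster scan over the final map needs one step per token.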
We provide VAR models for you to play with, which can be downloaded from the following links:
model | reso. | FID | rel. cost | #params | HF weights🤗 |
---|---|---|---|---|---|
VAR-d16 | 256 | 3.55 | 0.4 | 310M | var_d16.pth |
VAR-d20 | 256 | 2.95 | 0.5 | 600M | var_d20.pth |
VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | var_d24.pth |
VAR-d30 | 256 | 1.97 | 1 | 2.0B | var_d30.pth |
VAR-d30-re | 256 | 1.80 | 1 | 2.0B | var_d30.pth |
You can load these models to generate images via the code in demo_sample.ipynb. Note: you need to download `vae_ch160v4096z32.pth` first.
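Before opening the notebook, it may help to confirm that both required files are present. The sketch below is a hypothetical helper (not part of the repository) that assumes the checkpoints were saved under their original file names in a single directory.

```python
from pathlib import Path

# File names as listed in the table above; the directory layout is an assumption.
VAE_CKPT = "vae_ch160v4096z32.pth"
VAR_CKPTS = {16: "var_d16.pth", 20: "var_d20.pth", 24: "var_d24.pth", 30: "var_d30.pth"}


def missing_checkpoints(ckpt_dir, depth):
    """Return the required checkpoint file names that are not in ckpt_dir."""
    if depth not in VAR_CKPTS:
        raise ValueError(f"no released checkpoint listed for depth {depth}")
    required = [VAE_CKPT, VAR_CKPTS[depth]]
    return [name for name in required if not (Path(ckpt_dir) / name).is_file()]


if __name__ == "__main__":
    # An empty list means both the VAE and the VAR-d16 weights were found.
    print(missing_checkpoints(".", 16))
```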
- Install `torch>=2.0.0`.
- Install other pip packages via `pip3 install -r requirements.txt`.
- Prepare the ImageNet dataset. Assume the ImageNet is in `/path/to/imagenet`. It should be like this:

      /path/to/imagenet/:
          train/:
              n01440764:
                  many_images.JPEG ...
              n01443537:
                  many_images.JPEG ...
          val/:
              n01440764:
                  ILSVRC2012_val_00000293.JPEG ...
              n01443537:
                  ILSVRC2012_val_00000236.JPEG ...

  NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.
- (Optional) Install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See models/basic_var.py#L15-L30.
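The expected dataset layout can be checked before launching a long training run. The following is a minimal sketch of such a check; `check_imagenet_layout` is a hypothetical helper, not part of the repository, and it only verifies the folder structure described above (a `train/` and a `val/` split, each with class folders holding JPEG files).

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of layout problems; an empty list means none were found."""
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split}")
            continue
        class_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        if not class_dirs:
            problems.append(f"no class folders in: {split}")
        for class_dir in class_dirs:
            has_jpeg = any(
                p.is_file() and p.suffix.lower() in {".jpeg", ".jpg"}
                for p in class_dir.iterdir()
            )
            if not has_jpeg:
                problems.append(f"no JPEG images in: {split}/{class_dir.name}")
    return problems


if __name__ == "__main__":
    for problem in check_imagenet_layout("/path/to/imagenet"):
        print(problem)
```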
To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following commands:
# d16, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
# d20, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
# d24, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
# d30, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
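Since the commands above differ only in their hyperparameters, they can be generated from a small preset table. The sketch below is a hypothetical convenience wrapper, not part of the repository: the flag values are copied from the commands above, while the node settings that those commands elide (`--nnodes`, `--node_rank`, `--master_addr`, `--master_port`) must be supplied by the caller.

```python
import shlex

# Hyperparameters copied from the commands above; values kept as strings
# so they are emitted exactly as written there.
PRESETS = {
    "d16": {"depth": "16", "bs": "768", "ep": "200", "fp16": "1", "alng": "1e-3", "wpe": "0.1"},
    "d20": {"depth": "20", "bs": "768", "ep": "250", "fp16": "1", "alng": "1e-3", "wpe": "0.1"},
    "d24": {"depth": "24", "bs": "768", "ep": "350", "tblr": "8e-5", "fp16": "1",
            "alng": "1e-4", "wpe": "0.01"},
    "d30": {"depth": "30", "bs": "1024", "ep": "350", "tblr": "8e-5", "fp16": "1",
            "alng": "1e-5", "wpe": "0.01", "twde": "0.08"},
    "d36-s": {"depth": "36", "saln": "1", "pn": "512", "bs": "768", "ep": "350",
              "tblr": "8e-5", "fp16": "1", "alng": "5e-6", "wpe": "0.01", "twde": "0.08"},
}


def build_command(preset, data_path, nnodes, node_rank, master_addr, master_port,
                  nproc_per_node=8):
    """Build the torchrun argument list for one preset."""
    args = [
        "torchrun",
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "train.py",
    ]
    args += [f"--{name}={value}" for name, value in PRESETS[preset].items()]
    args.append(f"--data_path={data_path}")
    return args


if __name__ == "__main__":
    # Example node settings are placeholders for a single-node run.
    cmd = build_command("d16", "/path/to/imagenet", 1, 0, "127.0.0.1", 29500)
    print(" ".join(shlex.quote(a) for a in cmd))
```

The returned list can be passed to `subprocess.run` directly, which avoids shell-quoting issues with paths.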
A folder named `local_output` will be created to save the checkpoints and logs.
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
If your experiment is interrupted, just rerun the command, and the training will automatically resume from the last checkpoint in `local_output/ckpt*.pth` (see utils/misc.py#L344-L357).
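To inspect which checkpoint a rerun is likely to pick up, a lookup along these lines can be used. This is only a sketch under the assumption that "last" means the most recently modified file matching `ckpt*.pth`; the repository's own logic in utils/misc.py may select it differently, and `find_last_checkpoint` is a hypothetical name.

```python
import glob
import os


def find_last_checkpoint(output_dir="local_output"):
    """Return the most recently modified ckpt*.pth in output_dir, or None."""
    candidates = glob.glob(os.path.join(output_dir, "ckpt*.pth"))
    if not candidates:
        return None
    # Assumption: the newest file by modification time is the last checkpoint.
    return max(candidates, key=os.path.getmtime)


if __name__ == "__main__":
    print(find_last_checkpoint())
```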
For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in utils/misc.py#L344.
Then use OpenAI's FID evaluation toolkit and the reference ground-truth npz file of 256x256 or 512x512 to evaluate FID, IS, precision, and recall.
Note that a relatively small `cfg=1.5` is used as a trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for better visual quality.
We'll provide the sampling script later.
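As a starting point for the FID protocol described above, the class labels for the 50,000 samples can be scheduled as follows. This is a minimal sketch, not the official sampling script: `class_label_batches` and the batch size are hypothetical, and the sketch only prepares labels (1,000 ImageNet classes, 50 samples each) without calling the model.

```python
def class_label_batches(num_classes=1000, per_class=50, batch_size=25):
    """Yield lists of class labels so that every class appears per_class times."""
    labels = [c for c in range(num_classes) for _ in range(per_class)]
    for start in range(0, len(labels), batch_size):
        yield labels[start:start + batch_size]


if __name__ == "__main__":
    batches = list(class_label_batches())
    # Each batch would be passed as the class condition to the sampler,
    # and each generated image saved as a PNG file in one folder.
    print(len(batches), "batches,", sum(len(b) for b in batches), "labels in total")
```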
This project is licensed under the MIT License - see the LICENSE file for details.
If our work assists your research, feel free to give us a star ⭐ or cite us using:
@Article{VAR,
title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
year={2024},
eprint={2404.02905},
archivePrefix={arXiv},
primaryClass={cs.CV}
}