Implementation of Generating Diverse High-Fidelity Images with VQ-VAE in PyTorch
ChenAoxuan:
|---sample 训练过程中测试模型的生成结果。
|---checkpoint 保存模型文件
|---test_sample sample.py加载的图片数据
|---dataset.py 读取Laion数据集的Dataset
|---sample.py / sample.sh 加载模型,测试对单张图片的编码和重建效果,内含img2code和code2img api,仅对vqvae1模型使用。
|---train_vqvae.py / train.sh 训练vqvae1或vqvae2模型
|---vqvae1.py vqvae1模型,与CogView中的VQ-VAE代码完全一致,但需要修改模型初始化参数才能与CogView对齐(参数在代码备注中)。
|---vqvae2.py vqvae2模型,仅供参考,能训练,但目前的若干生成模型中都在用vqvae1。
|---scheduler.py lr scheduler。
|---tmp.py 临时的文件,无意义。
|---prepare_data.py 临时的文件,无意义。
|---文件带有_old后缀的,表示Fork前的原始文件,未使用。
在5k张Laion数据上训练与CogView对齐的模型,使用两张TITAN XP(12G),单卡batch size最多开到16(显存占用10G+),训练一个epoch约4.1min。
Checkpoint of CogView - WuDao WenHui - vqvae_hard_biggerset_011.pt
Fork前:
- 2020-06-01
train_vqvae.py and vqvae.py now supports distributed training. You can use --n_gpu [NUM_GPUS] arguments for train_vqvae.py to use [NUM_GPUS] during training.
- Python >= 3.6
- PyTorch >= 1.1
- lmdb (for storing extracted codes)
Checkpoint of VQ-VAE pretrained on FFHQ
Currently supports 256px (top/bottom hierarchical prior)
- Stage 1 (VQ-VAE)
python train_vqvae.py [DATASET PATH]
If you use FFHQ, I highly recommends to preprocess images. (resize and convert to jpeg)
- Extract codes for stage 2 training
python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]
- Stage 2 (PixelSNAIL)
python train_pixelsnail.py [LMDB NAME]
Maybe it is better to use larger PixelSNAIL model. Currently model size is reduced due to GPU constraints.