Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

chenaoxuan/vq-vae-pytorch

vq-vae-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE in PyTorch

ChenAoxuan:

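|---sample Generated results from testing the model during training.

|---checkpoint Saved model files.

|---test_sample Image data loaded by sample.py.

|---dataset.py Dataset class for reading the Laion dataset.

|---sample.py / sample.sh Loads a model and tests encoding and reconstruction of a single image. Contains the img2code and code2img APIs. Works with the vqvae1 model only.

|---train_vqvae.py / train.sh Trains the vqvae1 or vqvae2 model.

|---vqvae1.py The vqvae1 model. The code is identical to the VQ-VAE code in CogView, but the model initialization parameters must be changed to align with CogView (the parameters are listed in the code comments).

|---vqvae2.py The vqvae2 model, for reference only. It can be trained, but several current generative models use vqvae1.

|---scheduler.py Learning rate scheduler.

|---tmp.py Temporary file, not meaningful.

|---prepare_data.py Temporary file, not meaningful.

|---Files with the _old suffix are the original files from before the fork and are unused.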
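
The img2code and code2img helpers in sample.py are not reproduced here, so the following is only a toy sketch of the idea behind them: a VQ-VAE represents an image as indices into a learned codebook and reconstructs by looking those indices up again. The names toy_img2code and toy_code2img, the plain-list data layout, and the three-entry codebook are assumptions made for illustration and are not the repository's API. The real model also runs a convolutional encoder before quantizing and a decoder after the lookup.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def toy_img2code(features, codebook):
    """Map each feature vector to the index of its nearest codebook entry.

    Hypothetical stand-in: `features` plays the role of encoder outputs.
    """
    codes = []
    for vec in features:
        distances = [squared_distance(vec, entry) for entry in codebook]
        codes.append(distances.index(min(distances)))
    return codes


def toy_code2img(codes, codebook):
    """Look up the codebook vector for each index (the input a decoder would see)."""
    return [list(codebook[i]) for i in codes]


# Made-up data: a 3-entry codebook and three 2-dimensional feature vectors.
codebook = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
features = [[0.1, -0.1], [0.9, 0.8], [0.2, 0.9]]

codes = toy_img2code(features, codebook)
quantized = toy_code2img(codes, codebook)
print(codes)
print(quantized)
```

Each feature vector is replaced by its closest codebook entry, so only the integer codes need to be stored or modeled by a later stage.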

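A CogView-aligned model was trained on 5k Laion images using two TITAN Xp GPUs (12 GB each). The largest batch size that fits on a single GPU is 16 (over 10 GB of GPU memory in use), and one epoch takes about 4.1 min.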
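
As a rough sanity check on these numbers, the sketch below estimates the optimizer steps per epoch and the time per step. It assumes that "5k" means exactly 5000 images, that the data is split evenly across the two GPUs, and that the last partial batch is kept; none of these details are stated above, and the helper name steps_per_epoch is hypothetical.

```python
import math


def steps_per_epoch(num_images, n_gpu, batch_per_gpu):
    """Steps each GPU runs per epoch, assuming an even split across GPUs."""
    images_per_gpu = math.ceil(num_images / n_gpu)
    return math.ceil(images_per_gpu / batch_per_gpu)


# Figures quoted above: 5k images, 2 GPUs, batch size 16 per GPU, 4.1 min per epoch.
steps = steps_per_epoch(num_images=5000, n_gpu=2, batch_per_gpu=16)
seconds_per_step = 4.1 * 60 / steps
print(steps, round(seconds_per_step, 2))
```

Under these assumptions an epoch is about 157 steps, or roughly 1.6 seconds per step.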

Checkpoint of CogView - WuDao WenHui - vqvae_hard_biggerset_011.pt

Before the fork:

Update

  • 2020-06-01

train_vqvae.py and vqvae.py now support distributed training. Pass the --n_gpu [NUM_GPUS] argument to train_vqvae.py to train on [NUM_GPUS] GPUs.
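
To make the flag concrete, here is a minimal argparse sketch of a command line that accepts --n_gpu together with a positional dataset path. It is a stand-in and not the actual parser in train_vqvae.py: the default of 1, the help strings, and the example path data/ffhq are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the two arguments mentioned in this README."""
    parser = argparse.ArgumentParser(
        description="Illustrative stand-in for the train_vqvae.py command line"
    )
    # Assumed default: single-GPU training when the flag is omitted.
    parser.add_argument("--n_gpu", type=int, default=1,
                        help="number of GPUs to train on")
    parser.add_argument("path", type=str, help="dataset path")
    return parser


# Equivalent to running: python train_vqvae.py --n_gpu 2 data/ffhq
# (data/ffhq is a placeholder path, not one from the repository)
args = build_parser().parse_args(["--n_gpu", "2", "data/ffhq"])
print(args.n_gpu, args.path)
```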

Requisite

  • Python >= 3.6
  • PyTorch >= 1.1
  • lmdb (for storing extracted codes)

Checkpoint of VQ-VAE pretrained on FFHQ

Usage

Currently supports 256px (top/bottom hierarchical prior)
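
For intuition about the top/bottom hierarchy at 256px, the sketch below computes the latent grid sizes. It assumes the bottom level downsamples by a factor of 4 and the top level by a factor of 8, as in the 256px setup of the VQ-VAE-2 paper; the factors used by this code base may differ, and latent_grid is a hypothetical helper.

```python
def latent_grid(image_size, downsample_factor):
    """Side lengths of the code map for a square image at a given stride."""
    if image_size % downsample_factor != 0:
        raise ValueError("image size must be divisible by the downsample factor")
    side = image_size // downsample_factor
    return side, side


# Assumed factors: 4 for the bottom level, 8 for the top level.
bottom = latent_grid(256, 4)
top = latent_grid(256, 8)
codes_per_image = bottom[0] * bottom[1] + top[0] * top[1]
print(bottom, top, codes_per_image)
```

With these factors the prior in stage 2 models a 32x32 top code map and a 64x64 bottom code map per image.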

  1. Stage 1 (VQ-VAE)

python train_vqvae.py [DATASET PATH]

If you use FFHQ, I highly recommend preprocessing the images (resize and convert to JPEG).

  2. Extract codes for stage 2 training

python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]

  3. Stage 2 (PixelSNAIL)

python train_pixelsnail.py [LMDB NAME]

A larger PixelSNAIL model would probably work better. The model size is currently reduced due to GPU constraints.
