Official PyTorch implementation of TATS: A Long Video Generation Framework with Time-Agnostic VQGAN and Time-Sensitive Transformer (ECCV 2022)

songweige/TATS


Long Video Generation with Time-Agnostic VQGAN and Time-Sensitive Transformer

Project Website | Video | Paper

tl;dr We propose TATS, a long video generation framework that is trained on videos with tens of frames yet can generate videos with thousands of frames using a sliding window.

Setup

  conda create -n tats python=3.8
  conda activate tats
  conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  pip install pytorch-lightning==1.5.4
  pip install einops ftfy h5py imageio imageio-ffmpeg regex scikit-video tqdm
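
After installation, a quick import check can catch a missing dependency before a long run starts. The sketch below is not part of the repository: the helper name `missing_packages` is hypothetical, and the list uses the import names these packages normally install under, which sometimes differ from the pip names.

```python
import importlib.util

# Import names for the packages installed above. Some differ from the pip
# name: scikit-video -> skvideo, imageio-ffmpeg -> imageio_ffmpeg.
REQUIRED = [
    "torch", "torchvision", "torchaudio", "pytorch_lightning",
    "einops", "ftfy", "h5py", "imageio", "imageio_ffmpeg",
    "regex", "skvideo", "tqdm",
]


def missing_packages(names):
    """Return the names that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("all dependencies found" if not missing else f"missing: {missing}")
```

Run it inside the activated `tats` environment; an empty result means every package resolved.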

Datasets and trained models

UCF-101: official data, VQGAN, TATS-base
Sky-Timelapse: official data, VQGAN, TATS-base
Taichi-HD: official data, VQGAN, TATS-base
MUGEN: official data, VQGAN, TATS-base
AudioSet-Drums: official data, Video-VQGAN, STFT-VQGAN, TATS-base

Synthesis

  1. Short videos: To sample videos of the same length as the training data, use the code under scripts/ with the following flags:
  • gpt_ckpt: path to the trained transformer checkpoint.
  • vqgan_ckpt: path to the trained VQGAN checkpoint.
  • save: path to save the generation results.
  • save_videos: indicate that videos will be saved.
  • class_cond: indicate that class labels are used as conditional information.

To compute the FVD, these flags are required:

  • compute_fvd: indicate that FVD will be calculated.
  • data_path: path to the dataset folder.
  • dataset: dataset name.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. please set it to 4 when training on the Taichi-HD dataset.
python sample_vqgan_transformer_short_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} --class_cond \
    --save {SAVEPATH} --data_path {DATAPATH} --batch_size 16 \
    --top_k 2048 --top_p 0.8 --dataset {DATANAME} --compute_fvd --save_videos
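
The `--top_k` and `--top_p` flags in the command above restrict which tokens can be drawn at each sampling step. The sketch below illustrates the general top-k plus nucleus (top-p) idea in plain Python; it is an assumption about the standard technique, not the repository's implementation, and the function name `top_k_top_p_filter` is hypothetical.

```python
import math


def top_k_top_p_filter(logits, top_k=0, top_p=1.0):
    """Return {token_index: probability} after top-k, then top-p filtering.

    Illustrative only: keeps the top_k highest logits, renormalizes them,
    then keeps the smallest prefix whose cumulative probability reaches top_p.
    """
    # Sort token indices from most to least likely.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the surviving logits (shifted by the max for stability).
    peak = logits[order[0]]
    weights = [math.exp(logits[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Nucleus step: stop once the cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for index, prob in zip(order, probs):
        kept.append((index, prob))
        cumulative += prob
        if cumulative >= top_p:
            break

    # Renormalize so the kept probabilities sum to one.
    kept_mass = sum(prob for _, prob in kept)
    return {index: prob / kept_mass for index, prob in kept}


# Toy distribution over four tokens with probabilities 0.5, 0.3, 0.15, 0.05.
toy_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
print(top_k_top_p_filter(toy_logits, top_k=3, top_p=0.8))
```

Lower values of either flag make samples more conservative; higher values admit rarer tokens.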
  2. Long videos: To sample videos longer than the training length with a sliding window, use the following script with these additional flags:
  • sample_length: number of latent frames to be generated.
  • temporal_sample_pos: position, within the sliding window, of the latent frame that is generated at each step.
python sample_vqgan_transformer_long_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
    --dataset ucf101 --class_cond --sample_length 16 --temporal_sample_pos 1 --batch_size 5 --n_sample 5 --save_videos
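
To make the sliding-window idea concrete, here is a toy sketch of autoregressive generation in which the context is capped at a fixed number of latent frames. It is a simplified illustration under stated assumptions, not the code in `sample_vqgan_transformer_long_videos.py`: `next_token` stands in for the transformer, the names `sample_long`, `window`, and `tokens_per_frame` are hypothetical, and `temporal_sample_pos` is not modeled.

```python
def sample_long(next_token, n_frames, window=4, tokens_per_frame=3):
    """Generate n_frames latent frames using a sliding context window.

    next_token: callable mapping a list of context tokens to one new token
                (a stand-in for the trained transformer).
    window:     maximum number of latent frames visible at once, including
                the frame currently being generated.
    """
    tokens = []
    total = n_frames * tokens_per_frame
    while len(tokens) < total:
        # Index of the latent frame the next token belongs to.
        frame_idx = len(tokens) // tokens_per_frame
        # Keep at most (window - 1) complete past frames as context.
        first_frame = max(0, frame_idx - (window - 1))
        context = tokens[first_frame * tokens_per_frame:]
        tokens.append(next_token(context))
    # Group the flat token stream back into latent frames.
    return [tokens[i:i + tokens_per_frame] for i in range(0, total, tokens_per_frame)]


# Toy "model": the next token is simply the current context length.
frames = sample_long(lambda ctx: len(ctx), n_frames=6, window=2, tokens_per_frame=2)
print(len(frames), "latent frames generated")
```

Because the context never exceeds `window` latent frames, the sequence length seen by the model stays bounded no matter how many frames are generated.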
  3. Text to video: To sample MUGEN videos conditioned on the text, check this Colab notebook for an example.

  4. Audio to video: To sample drum videos conditioned on the audio, use the following script.

  • stft_vqgan_ckpt: path to the trained VQGAN checkpoint for STFT features.
python sample_vqgan_transformer_audio_cond.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} --stft_vqgan_ckpt {STFT-CKPT} \
    --dataset drum --n_sample 10

Training

Example commands for training the VQGAN and the transformers are shown below. The flags most likely to change between settings are:

  • data_path: path to the dataset folder.
  • default_root_dir: path to save the checkpoints and tensorboard logs.
  • vqvae: path to the trained VQGAN checkpoint.
  • resolution: resolution of the training video clips.
  • sequence_length: number of frames in each training video clip.
  • discriminator_iter_start: the training step at which the GAN losses are switched on.
  • image_folder: should be used when the dataset contains frames instead of videos, e.g. Sky Time-lapse.
  • unconditional: use this flag when no conditional information is available, e.g. Sky Time-lapse.
  • sample_every_n_frames: number of frames to skip in the real video data, e.g. please set it to 4 when training on the Taichi-HD dataset.
  • downsample: downsampling rate along the time, height, and width dimensions.
  • no_random_restart: set this flag to disable random re-initialization of the codebook tokens.
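
As an illustration of `sample_every_n_frames`, the sketch below assumes the flag means "keep one frame out of every n" when a clip is cut from a longer video. The helper name `subsample_indices` is hypothetical and the exact indexing in the repository's data loader may differ.

```python
def subsample_indices(start, sequence_length, sample_every_n_frames=1):
    """Return the source-frame indices for one training clip.

    Assumes the clip keeps every n-th frame beginning at `start`, so a clip of
    `sequence_length` frames spans sequence_length * n frames of the video.
    """
    stop = start + sequence_length * sample_every_n_frames
    return list(range(start, stop, sample_every_n_frames))


# A 16-frame clip with sample_every_n_frames=4 covers 64 source frames.
indices = subsample_indices(start=0, sequence_length=16, sample_every_n_frames=4)
print(indices[:4], "...", indices[-1])
```

Under this reading, a larger value makes each clip cover a longer time span without increasing the number of frames the model sees.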

VQGAN

python train_vqgan.py --embedding_dim 256 --n_codes 16384 --n_hiddens 16 --downsample 4 8 8 --no_random_restart \
                      --gpus 8 --sync_batchnorm --batch_size 2 --num_workers 32 --accumulate_grad_batches 6 \
                      --progress_bar_refresh_rate 500 --max_steps 2000000 --gradient_clip_val 1.0 --lr 3e-5 \
                      --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                      --resolution 128 --sequence_length 16 --discriminator_iter_start 10000 --norm_type batch \
                      --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1  --gan_feat_weight 4
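
When adapting the command above to a different number of GPUs, it helps to know the effective batch size. The sketch below assumes data-parallel training in which `--batch_size` is the per-GPU batch and gradients are accumulated across `--accumulate_grad_batches` steps; the function name is hypothetical.

```python
def effective_batch_size(gpus, batch_size, accumulate_grad_batches=1):
    """Clips contributing to one optimizer step under data-parallel training."""
    return gpus * batch_size * accumulate_grad_batches


# Values from the VQGAN command above: 8 GPUs x 2 clips x 6 accumulation steps.
print(effective_batch_size(gpus=8, batch_size=2, accumulate_grad_batches=6))  # 96
```

With 4 GPUs, for example, raising `--accumulate_grad_batches` to 12 would keep the same product under this assumption.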

Transformer

TATS-base Transformer

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 16 --max_steps 2000000
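
The `--block_size` value above lines up with the number of latent tokens per clip. The sketch below works through that arithmetic, assuming the VQGAN was trained with `--downsample 4 8 8` as in the VQGAN example and that the token count is the product of the downsampled time, height, and width. The helper names are hypothetical.

```python
def latent_shape(sequence_length, resolution, downsample):
    """Latent grid (frames, height, width) after VQGAN downsampling."""
    t, h, w = downsample
    if sequence_length % t or resolution % h or resolution % w:
        raise ValueError("clip dimensions must be divisible by the downsample rates")
    return (sequence_length // t, resolution // h, resolution // w)


def tokens_per_clip(sequence_length, resolution, downsample):
    """Number of discrete tokens the transformer models for one clip."""
    frames, height, width = latent_shape(sequence_length, resolution, downsample)
    return frames * height * width


# 16 frames at 128x128 with downsample (4, 8, 8) -> 4 x 16 x 16 latent grid.
print(latent_shape(16, 128, (4, 8, 8)))     # (4, 16, 16)
print(tokens_per_clip(16, 128, (4, 8, 8)))  # 1024
```

The same calculation with `--sequence_length 20` gives 1280 tokens.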

To train a conditional transformer, remove the --unconditional flag and use the following flags:

  • cond_stage_key: the kind of conditional information to use. It can be label, text, or stft.
  • stft_vqvae: path to the trained VQGAN checkpoint for STFT features.
  • text_cond: use this flag to indicate BPE-encoded text.

TATS-hierarchical Transformer

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 3 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1280 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 20 --spatial_length 128 --n_unmasked 256 --max_steps 2000000

python train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 8 --sync_batchnorm --batch_size 4 --unconditional \
                        --vqvae {VQGAN-CKPT} --data_path {DATAPATH} --default_root_dir {CKPTPATH} \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024  \
                        --resolution 128 --sequence_length 64 --sample_every_n_latent_frames 4 --spatial_length 128 --max_steps 2000000

Acknowledgments

Our code is partially built upon VQGAN and VideoGPT.

Citation

@article{ge2022long,
         title={Long Video Generation with Time-Agnostic VQGAN and Time-Sensitive Transformer},
         author={Ge, Songwei and Hayes, Thomas and Yang, Harry and Yin, Xi and Pang, Guan and Jacobs, David and Huang, Jia-Bin and Parikh, Devi},
         journal={arXiv preprint arXiv:2204.03638},
         year={2022}
}

License

TATS is licensed under the MIT license, as found in the LICENSE file.
