Initial commit

willi-menapace committed Jan 28, 2021
1 parent caf1eb3 commit da5942c
Showing 103 changed files with 11,489 additions and 1 deletion.
36 changes: 36 additions & 0 deletions Dockerfile
@@ -0,0 +1,36 @@
# We will use Ubuntu for our image
FROM nvidia/cuda:10.1-base-ubuntu18.04

# Updating Ubuntu packages
RUN apt-get update && \
apt-get install -y wget ffmpeg build-essential bzip2

# Installing Miniconda
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
RUN bash Miniconda3-latest-Linux-x86_64.sh -b && \
rm Miniconda3-latest-Linux-x86_64.sh

# Set path to conda
ENV PATH /root/miniconda3/bin:$PATH

# Creates the conda environment
COPY env.yml .
RUN conda env create -f env.yml

# Initializes .bashrc with conda startup instructions
RUN /root/miniconda3/condabin/conda init bash && \
/bin/bash -c "source ~/.bashrc"

RUN mkdir video-generation

# Configures bash to automatically start the environment
RUN echo "conda activate video-generation" >> ~/.bashrc

# Set the api key for wandb
ENV WANDB_API_KEY <YOUR WANDB API KEY>

WORKDIR video-generation
RUN /bin/bash -c "source ~/.bashrc && conda run -n video-generation wandb on"


# Run with docker run -it --gpus all --ipc=host -v /path/to/directory/video-generation:/video-generation video-generation:1.0 /bin/bash
101 changes: 100 additions & 1 deletion README.md
@@ -1 +1,100 @@
# PlayableVideoGeneration
# Playable Video Generation
<br>
<p align="center">
<img src="./resources/architecture.png"/> <br />
<em>
Figure 1. Illustration of the proposed CADDY model for playable video generation.
</em>
</p>
<br>

> **Playable Video Generation**<br>
> [Willi Menapace](https://www.willimenapace.com/), [Stéphane Lathuilière](https://stelat.eu/), [Sergey Tulyakov](http://www.stulyakov.com/), [Aliaksandr Siarohin](https://github.com/AliaksandrSiarohin), [Elisa Ricci](http://elisaricci.eu/)<br>
> ArXiv<br>
> Paper: [arXiv: Coming soon]()<br>
> [Website](https://willi-menapace.github.io/playable-video-generation-website/)<br>
> [Live Demo](https://willi-menapace.github.io/playable-video-generation-website/play.html)<br>
> **Abstract:** *This paper introduces the unsupervised learning problem of playable video generation (PVG). In PVG, we aim at allowing a user to control the generated video by selecting a discrete action at every time step as when playing a video game. The difficulty of the task lies both in learning semantically consistent actions and in generating realistic videos conditioned on the user input. We propose a novel framework for PVG that is trained in a self-supervised manner on a large dataset of unlabelled videos. We employ an encoder-decoder architecture where the predicted action labels act as bottleneck. The network is constrained to learn a rich action space using, as main driving loss, a reconstruction loss on the generated video. We demonstrate the effectiveness of the proposed approach on several datasets with wide environment variety.*
# Overview

Given a set of completely unlabeled videos, we jointly learn a set of discrete actions and a video generation model conditioned on the learned actions. At test time, the user can control the generated video on the fly by providing action labels, as if they were playing a video game. We name our method CADDY. Our architecture for unsupervised playable video generation is composed of several components. An encoder E extracts frame representations from the input sequence. A temporal model estimates the successive states using a recurrent dynamics network R and an action network A, which predicts the action label corresponding to the action performed in the input sequence. Finally, a decoder D reconstructs the input frames. The model is trained using reconstruction as the main driving loss.
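
For intuition only, the sketch below shows one way the E / A / R / D components described above could be wired together in PyTorch. It is a heavily simplified illustration under assumed layer choices and shapes (single conv encoder, GRU cell dynamics, linear decoder, toy reconstruction loss), not the repository's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CADDYSketch(nn.Module):
    """Toy stand-in for the encoder E, action network A, dynamics R and decoder D."""

    def __init__(self, state_features=64, hidden_size=128, actions_count=7, frame_size=64):
        super().__init__()
        self.frame_size = frame_size
        # E: frame encoder (conv + global pooling to a flat state vector)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, state_features, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        # A: action network, infers a discrete action from two consecutive states
        self.action_net = nn.Linear(2 * state_features, actions_count)
        # R: recurrent dynamics network driven by the current state and action
        self.dynamics = nn.GRUCell(state_features + actions_count, hidden_size)
        # D: decoder, reconstructs the next frame from the hidden state
        self.decoder = nn.Linear(hidden_size, 3 * frame_size * frame_size)

    def forward(self, frames):
        # frames: (batch, time, 3, frame_size, frame_size)
        batch, time = frames.shape[:2]
        states = [self.encoder(frames[:, t]) for t in range(time)]
        hidden = frames.new_zeros(batch, self.dynamics.hidden_size)
        reconstructions = []
        for t in range(time - 1):
            # The predicted discrete action acts as the information bottleneck
            logits = self.action_net(torch.cat([states[t], states[t + 1]], dim=1))
            action = F.gumbel_softmax(logits, tau=1.0, hard=True)
            hidden = self.dynamics(torch.cat([states[t], action], dim=1), hidden)
            frame = self.decoder(hidden).view(batch, 3, self.frame_size, self.frame_size)
            reconstructions.append(frame)
        return torch.stack(reconstructions, dim=1)

# Reconstruction of the input sequence is the main driving loss
model = CADDYSketch()
video = torch.rand(2, 12, 3, 64, 64)
loss = F.mse_loss(model(video), video[:, 1:])
```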

# Installation

## Conda

The complete environment for execution can be installed with:

`conda env create -f env.yml`

`conda activate video-generation`

## Docker

Build the Docker image:
`docker build -t video-generation:1.0 .`

Run the Docker image, mounting the repository root to `/video-generation` inside the container:
`docker run -it --gpus all --ipc=host -v /path/to/directory/video-generation:/video-generation video-generation:1.0 /bin/bash`

# Directory structure

Please create the following directories in the root of the project:

- `results`
- `checkpoints`
- `data`
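
Equivalently, a small Python snippet run from the project root creates them:

```python
import os

# Creates the directories expected by the training and evaluation scripts
for directory in ("results", "checkpoints", "data"):
    os.makedirs(directory, exist_ok=True)
```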

# Datasets
Datasets can be downloaded at the following link:
[Google Drive](https://drive.google.com/drive/folders/1CuHK_-cFWih0F8AxB4b76FoBQ9RjWMww?usp=sharing)

- Breakout: breakout_v2_160_ours.tar.gz
- BAIR: bair_256_ours.tar.gz
- Tennis: tennis_v4_256_ours.tar.gz

Please extract them under the `data` folder
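
As an example, a minimal sketch for extracting one of the archives into `data` (assuming the Breakout archive was downloaded to the project root; adapt the filename for the other datasets):

```python
import tarfile

# Extracts the downloaded archive under the data folder
with tarfile.open("breakout_v2_160_ours.tar.gz", "r:gz") as archive:
    archive.extractall("data")
```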

# Pretrained Models

Pretrained models can be downloaded at the following link:
[Google Drive](https://drive.google.com/drive/folders/1xLlJ8Xh6_wOEEARwBcoeVng2Bbi-wAah?usp=sharing)

Please place the directories under the `checkpoints` folder

# Playing

After downloading the checkpoints, the models can be played with the following commands:

- BAIR:
`python play.py --config configs/01_bair.yaml`

- Breakout:
`python play.py --config configs/breakout/02_breakout.yaml`

- Tennis:
`python play.py --config configs/03_tennis.yaml`

# Training

The models can be trained with the following commands:

`python train.py --config configs/<config_file>`

Multi-GPU support is active by default. Runs can be logged through Weights and Biases by running the following command before starting training:
`wandb init`

# Evaluation

Evaluation requires two steps. First, an evaluation dataset must be built. Second, evaluation is carried out on the evaluation dataset. To build the evaluation dataset, please issue:

`python build_evaluation_dataset.py --config configs/<config_file>`

To run evaluation issue:

`python evaluate_dataset.py --config configs/evaluation/configs/<config_file>`

Evaluation results are saved under the `evaluation_results` directory.
68 changes: 68 additions & 0 deletions build_evaluation_dataset.py
@@ -0,0 +1,68 @@
import argparse
import importlib
import os

import torch
import torchvision
import numpy as np

from dataset.dataset_splitter import DatasetSplitter
from dataset.transforms import TransformsGenerator
from dataset.video_dataset import VideoDataset
from evaluation.action_sampler import OneHotActionSampler, GroundTruthActionSampler
from evaluation.evaluator import Evaluator
from training.trainer import Trainer
from utils.configuration import Configuration
from utils.logger import Logger

torch.backends.cudnn.benchmark = True

if __name__ == "__main__":

    # Loads configuration file
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    arguments = parser.parse_args()

    config_path = arguments.config

    configuration = Configuration(config_path)
    configuration.check_config()
    configuration.create_directory_structure()

    config = configuration.get_config()

    logger = Logger(config)
    search_name = config["model"]["architecture"]
    model = getattr(importlib.import_module(search_name), 'model')(config)
    model.cuda()

    logger.get_wandb().watch(model, log='all')

    datasets = {}

    dataset_splits = DatasetSplitter.generate_splits(config)
    transformations = TransformsGenerator.get_final_transforms(config)

    for key in dataset_splits:
        path, batching_config, split = dataset_splits[key]
        transform = transformations[key]

        datasets[key] = VideoDataset(path, batching_config, transform, split)

    # Creates trainer and evaluator
    trainer = getattr(importlib.import_module(config["training"]["trainer"]), 'trainer')(config, model, datasets["train"], logger)
    # Creates evaluation dataset builder
    evaluation_dataset_builder = getattr(importlib.import_module(config["evaluation_dataset"]["builder"]), 'builder')(
        config, datasets["test"], logger)

    # Resume training
    try:
        trainer.load_checkpoint(model)
    except Exception as e:
        logger.print(e)
        #raise Exception("Cannot find checkpoint to load")

    model.eval()
    evaluation_dataset_builder.build(model)
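
A note on the pattern used by the script above: each component is resolved dynamically from a dotted module name stored in the configuration (for example `model.main_model.model` for the model, and the `trainer` and `builder` entries for the other components), and the imported module is expected to expose a fixed symbol. A minimal, hypothetical module satisfying that contract could look like this (`DummyModel` and the module path are illustrative, not part of the repository):

```python
# hypothetical file: model/dummy_model.py
class DummyModel:
    """Placeholder standing in for a real model class."""
    def __init__(self, config):
        self.config = config

# The loader calls getattr(importlib.import_module(search_name), 'model')(config),
# so the module must expose a module-level symbol named `model`
model = DummyModel
```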
188 changes: 188 additions & 0 deletions configs/01_bair.yaml
@@ -0,0 +1,188 @@
# Logging parameters
logging:
  # Name with which the run will be logged
  run_name: "01_bair"

  # Directory where main results are saved
  output_root: "results"
  # Checkpoint directory
  save_root: "checkpoints"

# Dataset parameters
data:
  # Dataset path
  data_root: "data/bair_256_ours"
  # Crop to apply to each frame [left_index, upper_index, right_index, lower_index]
  crop: [0, 0, 256, 256]
  # Number of distinct actions present in the dataset
  actions_count: 7
  # True if ground truth annotations are available
  ground_truth_available: False

# Model parameters
model:
  # Class to use as model
  architecture: "model.main_model.model"

  representation_network:
    # Desired (width, height) of the input images
    target_input_size: [256, 256]
    # Features of the tensor output by the representation network
    state_features: 64
    # Height and width output by the representation network
    state_resolution: [32, 32]

  # Dynamics network parameters
  dynamics_network:
    # Size of the hidden state
    hidden_state_size: 128
    # Output units in the MLP for input vector embedding
    embedding_mlp_size: 128
    # Elements in the noise vector
    random_noise_size: 32

  rendering_network:
    # Shape of the input tensor [features, height, width]
    input_shape: [64, 32, 32]

  # Action recognition network parameters
  action_network:
    use_gumbel: True
    hard_gumbel: False
    ensamble_size: 1
    # Temperature to use in Gumbel-Softmax for action sampling
    gumbel_temperature: 1.0
    # Number of the spatial dimensions of the embedding space
    action_space_dimension: 2

  # Centroid estimator parameters
  centroid_estimator:
    # Alpha value to use for computing the moving average
    alpha: 0.1

# Training parameters
training:

  trainer: "training.smooth_mi_trainer"

  use_ground_truth_actions: False

  learning_rate: 0.0004
  weight_decay: 0.000001

  # Number of steps to use for pretraining
  pretraining_steps: 1000
  # Whether to avoid backpropagation of gradients into the representation network during pretraining
  pretraining_detach: False
  # Steps at which to switch learning rate
  lr_schedule: [300000, 10000000000]
  # Gamma parameter for lr scheduling
  lr_gamma: 0.3333
  # Maximum number of steps for which to train the model
  max_steps: 300000
  # Interval in training steps at which to save the model
  save_freq: 3000

  # Number of ground truth observations in each sequence to use at the beginning of training
  ground_truth_observations_start: 6
  # Number of real observations in each sequence to use at the end of the annealing period
  ground_truth_observations_end: 6
  # Length in steps of the annealing period
  ground_truth_observations_steps: 16000

  # Gumbel temperature to use at the beginning of training
  gumbel_temperature_start: 1.0
  # Gumbel temperature to use at the end of the annealing period
  gumbel_temperature_end: 0.4
  # Length in steps of the annealing period
  gumbel_temperature_steps: 20000

  # The alpha value to use for mutual information estimation smoothing
  mutual_information_estimation_alpha: 0.2

  # Parameters for batch building
  batching:
    batch_size: 8

    # Number of observations that each batch sample possesses
    observations_count: 12
    # Number of observations that each batch sample possesses at the beginning of the annealing period
    observations_count_start: 7
    # Length in steps of the annealing period
    observations_count_steps: 25000

    # Number of frames to skip between each observation
    skip_frames: 0
    # Total number of frames that compose an observation
    observation_stacking: 1
    # Number of threads to use for dataloading
    num_workers: 16

  # Weights to use for the loss
  loss_weights:
    # Weight for the reconstruction loss
    reconstruction_loss_lambda: 1.0
    # Weight for the reconstruction loss during pretraining
    reconstruction_loss_lambda_pretraining: 1.0
    # Weight for the perceptual loss
    perceptual_loss_lambda: 1.0
    # Weight for the perceptual loss during pretraining
    perceptual_loss_lambda_pretraining: 1.0
    # Weight for the action divergence between plain and transformed sequences
    action_divergence_lambda: 0.0
    # Weight for the action divergence between plain and transformed sequences during pretraining
    action_divergence_lambda_pretraining: 0.0
    # Weight for the state reconstruction loss
    states_rec_lambda: 0.2
    # Weight for the state reconstruction loss during pretraining
    states_rec_lambda_pretraining: 0.2
    # Weight for the hidden state reconstruction loss during pretraining
    hidden_states_rec_lambda_pretraining: 1.0
    # Weight for the action distribution entropy loss
    entropy_lambda: 0.0
    # Weight for the action distribution entropy loss during pretraining
    entropy_lambda_pretraining: 0.0
    # Weight for the action directions KL divergence loss
    action_directions_kl_lambda: 0.0001
    # Weight for the action directions KL divergence loss during pretraining
    action_directions_kl_lambda_pretraining: 0.0001
    # Weight for the action mutual information loss
    action_mutual_information_lambda: 0.15
    # Weight for the action mutual information loss during pretraining
    action_mutual_information_lambda_pretraining: 0.15
    # Weight for the KL distance loss between action states and reconstructed action states
    action_state_distribution_kl_lambda: 0.0
    # Weight for the KL distance loss between action states and reconstructed action states during pretraining
    action_state_distribution_kl_lambda_pretraining: 0.0

  # Number of steps between each plotting of the action space
  action_direction_plotting_freq: 1000

# Parameters for evaluation
evaluation:

  evaluator: "evaluation.evaluator"

  max_evaluation_batches: 20
  # Minimum number of steps between two successive evaluations
  eval_freq: 8000
  # Parameters for batch building
  batching:
    batch_size: 8
    # Number of observations that each batch sample possesses
    observations_count: 30
    # Number of frames to skip between each observation
    skip_frames: 0
    # Total number of frames that compose an observation
    observation_stacking: 1
    # Number of threads to use for dataloading
    num_workers: 16

# Parameters for final evaluation dataset computation
evaluation_dataset:

  # The number of ground truth context frames to use to produce each sequence
  ground_truth_observations_init: 4
  builder: "evaluation.evaluation_dataset_builder"

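As a rough illustration of how the `gumbel_temperature_start`, `gumbel_temperature_end` and `gumbel_temperature_steps` entries above could drive action sampling, the sketch below applies a linear temperature annealing to `torch.nn.functional.gumbel_softmax`; this is an assumed schedule for illustration, not the repository's exact trainer logic.

```python
import torch
import torch.nn.functional as F

def gumbel_temperature(step, start=1.0, end=0.4, anneal_steps=20000):
    # Linearly anneals the temperature, mirroring the
    # gumbel_temperature_start/end/steps values in the config above
    progress = min(step / anneal_steps, 1.0)
    return start + (end - start) * progress

# Example: sample a soft one-hot action at training step 10000
# (hard=False matches hard_gumbel: False in the action_network section)
logits = torch.randn(8, 7)  # batch_size x actions_count, as configured above
action = F.gumbel_softmax(logits, tau=gumbel_temperature(10000), hard=False)
```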
