forked from willi-menapace/PlayableVideoGeneration
Commit da5942c (1 parent: caf1eb3)
Showing 103 changed files with 11,489 additions and 1 deletion.
Dockerfile
@@ -0,0 +1,36 @@
# We will use Ubuntu for our image
FROM nvidia/cuda:10.1-base-ubuntu18.04

# Update Ubuntu packages
RUN apt-get update && \
    apt-get install -y wget ffmpeg build-essential bzip2

# Install Miniconda
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
RUN bash Miniconda3-latest-Linux-x86_64.sh -b && \
    rm Miniconda3-latest-Linux-x86_64.sh

# Set path to conda
ENV PATH /root/miniconda3/bin:$PATH

# Create the conda environment
COPY env.yml .
RUN conda env create -f env.yml

# Initialize .bashrc with conda startup instructions
RUN /root/miniconda3/condabin/conda init bash && \
    /bin/bash -c "source ~/.bashrc"

# Create the project directory (mount point for the repository)
RUN mkdir video-generation

# Configure bash to automatically activate the environment
RUN echo "conda activate video-generation" >> ~/.bashrc

# Set the API key for wandb
ENV WANDB_API_KEY <YOUR WANDB API KEY>

WORKDIR video-generation
# Enable wandb logging inside the environment
RUN /bin/bash -c "source ~/.bashrc && conda run -n video-generation wandb on"

# Run with: docker run -it --gpus all --ipc=host -v /path/to/directory/video-generation:/video-generation video-generation:1.0 /bin/bash
README.md
@@ -1 +1,100 @@
# Playable Video Generation
<br>
<p align="center">
  <img src="./resources/architecture.png"/> <br />
  <em>
    Figure 1. Illustration of the proposed CADDY model for playable video generation.
  </em>
</p>
<br>
> **Playable Video Generation**<br>
> [Willi Menapace](https://www.willimenapace.com/), [Stéphane Lathuilière](https://stelat.eu/), [Sergey Tulyakov](http://www.stulyakov.com/), [Aliaksandr Siarohin](https://github.com/AliaksandrSiarohin), [Elisa Ricci](http://elisaricci.eu/)<br>
> arXiv<br>
> Paper: [arXiv: Coming soon]()<br>
> [Website](https://willi-menapace.github.io/playable-video-generation-website/)<br>
> [Live Demo](https://willi-menapace.github.io/playable-video-generation-website/play.html)<br>
> **Abstract:** *This paper introduces the unsupervised learning problem of playable video generation (PVG). In PVG, we aim at allowing a user to control the generated video by selecting a discrete action at every time step as when playing a video game. The difficulty of the task lies both in learning semantically consistent actions and in generating realistic videos conditioned on the user input. We propose a novel framework for PVG that is trained in a self-supervised manner on a large dataset of unlabelled videos. We employ an encoder-decoder architecture where the predicted action labels act as bottleneck. The network is constrained to learn a rich action space using, as main driving loss, a reconstruction loss on the generated video. We demonstrate the effectiveness of the proposed approach on several datasets with wide environment variety.*
# Overview

Given a set of completely unlabeled videos, we jointly learn a set of discrete actions and a video generation model conditioned on the learned actions. At test time, the user can control the generated video on the fly by providing action labels, as if playing a video game. We name our method CADDY. Our architecture for unsupervised playable video generation is composed of several components. An encoder E extracts frame representations from the input sequence. A temporal model estimates the successive states using a recurrent dynamics network R and an action network A, which predicts the action label corresponding to the current action performed in the input sequence. Finally, a decoder D reconstructs the input frames. The model is trained using reconstruction as the main driving loss.
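To make the data flow concrete, the snippet below is a minimal, PyTorch-style sketch of the E → A → R → D pipeline described above. It is only an illustration: the module definitions, layer sizes, and names are simplified assumptions and do not correspond to the actual CADDY implementation in this repository.

```python
# Illustrative sketch of the CADDY data flow (not the repository's implementation).
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """E: maps a frame to a feature-map state representation."""
    def __init__(self, state_features=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, state_features, 4, stride=2, padding=1),
        )

    def forward(self, frame):
        return self.net(frame)


class ActionNetwork(nn.Module):
    """A: predicts a discrete action from two consecutive states (Gumbel-Softmax bottleneck)."""
    def __init__(self, state_features=64, actions_count=7):
        super().__init__()
        self.fc = nn.Linear(2 * state_features, actions_count)

    def forward(self, state_t, state_tp1):
        pooled = torch.cat([state_t.mean(dim=(2, 3)), state_tp1.mean(dim=(2, 3))], dim=1)
        return F.gumbel_softmax(self.fc(pooled), tau=1.0, hard=False)


class DynamicsNetwork(nn.Module):
    """R: recurrently predicts the next hidden state given the current state and action."""
    def __init__(self, state_features=64, actions_count=7, hidden_size=128):
        super().__init__()
        self.rnn = nn.GRUCell(state_features + actions_count, hidden_size)

    def forward(self, state, action, hidden):
        return self.rnn(torch.cat([state.mean(dim=(2, 3)), action], dim=1), hidden)


class Decoder(nn.Module):
    """D: reconstructs a frame from the predicted hidden state."""
    def __init__(self, hidden_size=128, resolution=64):
        super().__init__()
        self.resolution = resolution
        self.fc = nn.Linear(hidden_size, 3 * resolution * resolution)

    def forward(self, hidden):
        frame = self.fc(hidden).view(-1, 3, self.resolution, self.resolution)
        return torch.sigmoid(frame)


def reconstruct(frames, E, A, R, D, hidden_size=128):
    """Encode a sequence, infer actions, roll the dynamics forward and reconstruct frames.

    E, A, R, D follow the notation of Figure 1.
    """
    batch, length = frames.shape[:2]
    hidden = frames.new_zeros(batch, hidden_size)
    states = [E(frames[:, t]) for t in range(length)]
    reconstructions = []
    for t in range(length - 1):
        action = A(states[t], states[t + 1])      # discrete action bottleneck
        hidden = R(states[t], action, hidden)     # recurrent dynamics
        reconstructions.append(D(hidden))         # frame reconstruction
    return torch.stack(reconstructions, dim=1)


if __name__ == "__main__":
    frames = torch.rand(2, 4, 3, 64, 64)          # (batch, time, channels, H, W)
    out = reconstruct(frames, Encoder(), ActionNetwork(), DynamicsNetwork(), Decoder())
    loss = F.mse_loss(out, frames[:, 1:])         # reconstruction as the main driving loss
    print(out.shape, loss.item())
```

The actual model additionally uses Gumbel-Softmax temperature annealing, perceptual and mutual-information losses, and a learned noise input, as configured in the files under `configs/`.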
# Installation

## Conda

The complete environment for execution can be installed with:

`conda env create -f env.yml`

`conda activate video-generation`

## Docker

Build the docker image:
`docker build -t video-generation:1.0 .`

Run the docker image, mounting the root directory of the project to `/video-generation` in the container:
`docker run -it --gpus all --ipc=host -v /path/to/directory/video-generation:/video-generation video-generation:1.0 /bin/bash`
# Directory structure

Please create the following directories in the root of the project:

- `results`
- `checkpoints`
- `data`
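For example, on Linux they can all be created from the project root with:

`mkdir -p results checkpoints data`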
# Datasets

Datasets can be downloaded at the following link:
[Google Drive](https://drive.google.com/drive/folders/1CuHK_-cFWih0F8AxB4b76FoBQ9RjWMww?usp=sharing)

- Breakout: breakout_v2_160_ours.tar.gz
- BAIR: bair_256_ours.tar.gz
- Tennis: tennis_v4_256_ours.tar.gz

Please extract them under the `data` folder.
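For example, the Breakout archive can be extracted with:

`tar -xzvf breakout_v2_160_ours.tar.gz -C data/`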
# Pretrained Models

Pretrained models can be downloaded at the following link:
[Google Drive](https://drive.google.com/drive/folders/1xLlJ8Xh6_wOEEARwBcoeVng2Bbi-wAah?usp=sharing)

Please place the directories under the `checkpoints` folder.
# Playing

After downloading the checkpoints, the models can be played with the following commands:

- Bair:
`python play.py --config configs/01_bair.yaml`

- Breakout:
`python play.py --config configs/breakout/02_breakout.yaml`

- Tennis:
`python play.py --config configs/03_tennis.yaml`
# Training

The models can be trained with the following command:

`python train.py --config configs/<config_file>`

Multi-GPU support is active by default. Runs can be logged to Weights and Biases by running the following before the training command:
`wandb init`
# Evaluation

Evaluation requires two steps. First, an evaluation dataset must be built. Second, evaluation is carried out on the evaluation dataset. To build the evaluation dataset, please issue:

`python build_evaluation_dataset.py --config configs/<config_file>`

To run evaluation, issue:

`python evaluate_dataset.py --config configs/evaluation/configs/<config_file>`

Evaluation results are saved under the `evaluation_results` directory.
build_evaluation_dataset.py
@@ -0,0 +1,68 @@
import argparse
import importlib
import os

import torch
import torchvision
import numpy as np

from dataset.dataset_splitter import DatasetSplitter
from dataset.transforms import TransformsGenerator
from dataset.video_dataset import VideoDataset
from evaluation.action_sampler import OneHotActionSampler, GroundTruthActionSampler
from evaluation.evaluator import Evaluator
from training.trainer import Trainer
from utils.configuration import Configuration
from utils.logger import Logger

torch.backends.cudnn.benchmark = True

if __name__ == "__main__":

    # Loads configuration file
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    arguments = parser.parse_args()

    config_path = arguments.config

    configuration = Configuration(config_path)
    configuration.check_config()
    configuration.create_directory_structure()

    config = configuration.get_config()

    logger = Logger(config)
    search_name = config["model"]["architecture"]
    model = getattr(importlib.import_module(search_name), 'model')(config)
    model.cuda()

    logger.get_wandb().watch(model, log='all')

    datasets = {}

    dataset_splits = DatasetSplitter.generate_splits(config)
    transformations = TransformsGenerator.get_final_transforms(config)

    for key in dataset_splits:
        path, batching_config, split = dataset_splits[key]
        transform = transformations[key]

        datasets[key] = VideoDataset(path, batching_config, transform, split)

    # Creates trainer and evaluator
    trainer = getattr(importlib.import_module(config["training"]["trainer"]), 'trainer')(config, model, datasets["train"], logger)
    # Creates evaluation dataset builder
    evaluation_dataset_builder = getattr(importlib.import_module(config["evaluation_dataset"]["builder"]), 'builder')(
        config, datasets["test"], logger)

    # Resume training
    try:
        trainer.load_checkpoint(model)
    except Exception as e:
        logger.print(e)
        #raise Exception("Cannot find checkpoint to load")

    model.eval()
    evaluation_dataset_builder.build(model)
configs/01_bair.yaml
@@ -0,0 +1,188 @@
# Logging parameters
logging:
  # Name with which the run will be logged
  run_name: "01_bair"

  # Directory where main results are saved
  output_root: "results"
  # Checkpoint directory
  save_root: "checkpoints"

# Dataset parameters
data:
  # Dataset path
  data_root: "data/bair_256_ours"
  # Crop to apply to each frame [left_index, upper_index, right_index, lower_index]
  crop: [0, 0, 256, 256]
  # Number of distinct actions present in the dataset
  actions_count: 7
  # True if ground truth annotations are available
  ground_truth_available: False

# Model parameters
model:
  # Class to use as model
  architecture: "model.main_model.model"

  representation_network:
    # Desired (width, height) of the input images
    target_input_size: [256, 256]
    # Features of the tensor output by the representation network
    state_features: 64
    # Height and width output by the representation network
    state_resolution: [32, 32]

  # Dynamics network parameters
  dynamics_network:
    # Size of the hidden state
    hidden_state_size: 128
    # Output units in the MLP for input vector embedding
    embedding_mlp_size: 128
    # Elements in the noise vector
    random_noise_size: 32

  rendering_network:
    # Shape of the input tensor [features, height, width]
    input_shape: [64, 32, 32]

  # Action recognition network parameters
  action_network:
    use_gumbel: True
    hard_gumbel: False
    ensamble_size: 1
    # Temperature to use in Gumbel-Softmax for action sampling
    gumbel_temperature: 1.0
    # Number of the spatial dimensions of the embedding space
    action_space_dimension: 2

  # Centroid estimator parameters
  centroid_estimator:
    # Alpha value to use for computing the moving average
    alpha: 0.1
# Training parameters
training:

  trainer: "training.smooth_mi_trainer"

  use_ground_truth_actions: False

  learning_rate: 0.0004
  weight_decay: 0.000001

  # Number of steps to use for pretraining
  pretraining_steps: 1000
  # Whether to avoid backpropagation of gradients into the representation network during pretraining
  pretraining_detach: False
  # Steps at which to switch learning rate
  lr_schedule: [300000, 10000000000]
  # Gamma parameter for lr scheduling
  lr_gamma: 0.3333
  # Maximum number of steps for which to train the model
  max_steps: 300000
  # Interval in training steps at which to save the model
  save_freq: 3000

  # Number of ground truth observations in each sequence to use at the beginning of training
  ground_truth_observations_start: 6
  # Number of real observations in each sequence to use at the end of the annealing period
  ground_truth_observations_end: 6
  # Length in steps of the annealing period
  ground_truth_observations_steps: 16000

  # Gumbel-Softmax temperature at the beginning of training
  gumbel_temperature_start: 1.0
  # Gumbel-Softmax temperature at the end of the annealing period
  gumbel_temperature_end: 0.4
  # Length in steps of the annealing period
  gumbel_temperature_steps: 20000

  # The alpha value to use for mutual information estimation smoothing
  mutual_information_estimation_alpha: 0.2
  # Parameters for batch building
  batching:
    batch_size: 8

    # Number of observations that each batch sample possesses
    observations_count: 12
    # Number of observations per sample at the start of the annealing period
    observations_count_start: 7
    # Length in steps of the annealing period
    observations_count_steps: 25000

    # Number of frames to skip between each observation
    skip_frames: 0
    # Total number of frames that compose an observation
    observation_stacking: 1
    # Number of threads to use for dataloading
    num_workers: 16
  # Weights to use for the loss
  loss_weights:
    # Weight for the reconstruction loss
    reconstruction_loss_lambda: 1.0
    # Weight for the reconstruction loss during pretraining
    reconstruction_loss_lambda_pretraining: 1.0
    # Weight for the perceptual loss
    perceptual_loss_lambda: 1.0
    # Weight for the perceptual loss during pretraining
    perceptual_loss_lambda_pretraining: 1.0
    # Weight for the action divergence between plain and transformed sequences
    action_divergence_lambda: 0.0
    # Weight for the action divergence between plain and transformed sequences during pretraining
    action_divergence_lambda_pretraining: 0.0
    # Weight for the state reconstruction loss
    states_rec_lambda: 0.2
    # Weight for the state reconstruction loss during pretraining
    states_rec_lambda_pretraining: 0.2
    # Weight for the hidden state reconstruction loss during pretraining
    hidden_states_rec_lambda_pretraining: 1.0
    # Weight for the action distribution entropy loss
    entropy_lambda: 0.0
    # Weight for the action distribution entropy loss during pretraining
    entropy_lambda_pretraining: 0.0
    # Weight for the action directions KL divergence loss
    action_directions_kl_lambda: 0.0001
    # Weight for the action directions KL divergence loss during pretraining
    action_directions_kl_lambda_pretraining: 0.0001
    # Weight for the action mutual information loss
    action_mutual_information_lambda: 0.15
    # Weight for the action mutual information loss during pretraining
    action_mutual_information_lambda_pretraining: 0.15
    # Weight for the KL distance loss between action states and reconstructed action states
    action_state_distribution_kl_lambda: 0.0
    # Weight for the KL distance loss between action states and reconstructed action states during pretraining
    action_state_distribution_kl_lambda_pretraining: 0.0

  # Number of steps between each plotting of the action space
  action_direction_plotting_freq: 1000
# Parameters for evaluation
evaluation:

  evaluator: "evaluation.evaluator"

  max_evaluation_batches: 20
  # Minimum number of steps between two successive evaluations
  eval_freq: 8000
  # Parameters for batch building
  batching:
    batch_size: 8
    # Number of observations that each batch sample possesses
    observations_count: 30
    # Number of frames to skip between each observation
    skip_frames: 0
    # Total number of frames that compose an observation
    observation_stacking: 1
    # Number of threads to use for dataloading
    num_workers: 16

# Parameters for final evaluation dataset computation
evaluation_dataset:

  # The number of ground truth context frames to use to produce each sequence
  ground_truth_observations_init: 4
  builder: "evaluation.evaluation_dataset_builder"