vision_aided_loss as independent module
nupurkmr9 committed May 4, 2022
1 parent f55a5a3 commit f200140
Showing 188 changed files with 25,600 additions and 1,099 deletions.
161 changes: 46 additions & 115 deletions README.md

### [video](https://youtu.be/oHdyJNdQ9E4) | [website](https://www.cs.cmu.edu/~vision-aided-gan/) | [paper](https://arxiv.org/abs/2112.09130)



**[NEW!]** Vision-aided GAN training with BigGAN and StyleGAN3

**[NEW!]** Using vision-aided Discriminator in your own GAN training.
<img src='docs/code.gif' align="center" width=800>


<br>

<div class="gif">
Ensembling Off-the-shelf Models for GAN Training <br>
[Nupur Kumari](https://nupurkmr9.github.io/), [Richard Zhang](https://richzhang.github.io/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)<br>
In CVPR 2022



## Quantitative Comparison

<p align="center">
Our method outperforms recent GAN training methods by a large margin, especially in the limited-data setting.

## Example Results
Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the
same randomly sampled latent code on the 100-shot [Bridge-of-Sighs](https://data-efficient-gans.mit.edu/datasets/100-shot-bridge_of_sighs.zip) and [AnimalFace Dog](https://data-efficient-gans.mit.edu/datasets/AnimalFace-dog.zip) datasets.

<img src="docs/bridge.gif" width="400px"/><img src="docs/animalface_dog.gif" width="400px"/>

## Interpolation Videos
Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).
<img src="docs/interp.gif" width="800px"/>
</p>


## Worst sample visualization
We randomly sample 5k images and sort them according to Mahalanobis distance, using the mean and variance of real samples computed in the Inception feature space. The visualization below shows the bottom 30 images according to this distance for StyleGAN2-ADA (left) and our model (right).

</p>
</details>

Example command to create a similar visualization. For the command below, the output image is saved in the `out` directory.

```.bash
python calc_metrics.py --metrics sort_likelihood --name afhq_dog --split train --network https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table3_afhq/vision-aided-gan-afhqdog-ada-3.pkl --data afhqdog
```
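
For intuition, the ranking behind this visualization can be sketched as follows. This is a minimal illustration assuming Inception features have already been extracted; the array names are placeholders, not the repo's API.

```python
# Sketch of the Mahalanobis-distance ranking (illustrative only).
# real_feats: [N_real, D] Inception features of real images.
# fake_feats: [N_fake, D] Inception features of generated images.
import numpy as np

mu = real_feats.mean(axis=0)                                 # mean of real features
cov = np.cov(real_feats, rowvar=False)                       # covariance of real features
cov_inv = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))   # regularized inverse

diff = fake_feats - mu                                        # [N_fake, D]
dist = np.einsum('nd,dk,nk->n', diff, cov_inv, diff)          # squared Mahalanobis distance
worst = np.argsort(dist)[-30:]                                # 30 samples farthest from the real distribution
```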

## Requirements
### Vision-aided StyleGAN2 training
Please see the [vision-aided-stylegan2](https://github.com/nupurkmr9/vision_aided_module/tree/main/stylegan2) README for training StyleGAN2 models with our method. This code reproduces all StyleGAN2-based results from our paper.

* 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See [https://pytorch.org/](https://pytorch.org/) for PyTorch install instructions.
* CUDA toolkit 11.0 or later.
* Python libraries: see `scripts/requirements.txt`.
* The StyleGAN2 code relies heavily on custom PyTorch extensions. For details, please refer to the [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch) repo.

### Vision-aided Discriminator in a custom GAN model

To set up a conda env with all requirements and pretrained networks, run the following commands:
```.bash
conda create -n vgan python=3.8
conda activate vgan
git clone https://github.com/nupurkmr9/vision-aided-gan.git
cd vision-aided-gan
bash docs/setup.sh
```

To install the `vision_aided_loss` library:
```.bash
pip install .
```

For details on the off-the-shelf models, please see [MODELS.md](docs/MODELS.md). An example of using the vision-aided discriminator in your own GAN training loop is given under [Training new networks](#training-new-networks) below.

## Using Pretrained Models
Our final trained models can be downloaded at this [link](https://www.cs.cmu.edu/~vision-aided-gan/models/)

**To generate images**:


```.bash
# random image generation from LSUN Church model

python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 --network=https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table2_fulldataset/vision-aided-gan-lsunchurch-ada-3.pkl
```
The above command generates 4 images using the provided seed values and saves them in the `out` directory specified by `--outdir`. Our generator architecture is the same as StyleGAN2's and can be used from Python code as described in [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/README.md#using-networks-from-python).
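
A minimal sketch of that Python usage, following the stylegan2-ada-pytorch conventions (run from the StyleGAN2 code directory, which provides `dnnlib` and `legacy`; the URL is the LSUN Church model above):

```python
# Load our generator pickle and sample one image via the StyleGAN2-ADA interface.
import torch
import dnnlib
import legacy

url = 'https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table2_fulldataset/vision-aided-gan-lsunchurch-ada-3.pkl'
device = torch.device('cuda')
with dnnlib.util.open_url(url) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)    # generator, same API as StyleGAN2-ADA

z = torch.randn([1, G.z_dim], device=device)               # random latent code
label = torch.zeros([1, G.c_dim], device=device)           # class labels (empty for unconditional models)
img = G(z, label, truncation_psi=1.0)                      # NCHW float image with values in [-1, 1]
```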

**model evaluation**:
```.bash
python calc_metrics.py --network https://www.cs.cmu.edu/~vision-aided-gan/models/main_paper_table2_fulldataset/vision-aided-gan-lsunchurch-ada-3.pkl --metrics fid50k_full --data lsunchurch --clean 1
```
We use the [clean-fid](https://github.com/GaParmar/clean-fid) library to calculate the FID metric. We compute the real-distribution statistics over the full dataset. For details on calculating the statistics, please refer to [clean-fid](https://github.com/GaParmar/clean-fid).
For the default FID evaluation of StyleGAN2-ADA, use `clean=0`. The above command should return an FID of `~1.72`.
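
If you want to pre-compute these reference statistics yourself, clean-fid supports custom dataset statistics. A rough sketch (the statistics name and paths are placeholders; see the clean-fid README for details):

```python
# Pre-compute full real-distribution statistics once, then reuse them for FID.
from cleanfid import fid

# One-time: compute statistics over the folder of real images.
fid.make_custom_stats('lsunchurch_full', '/path/to/lsunchurch/images', mode='clean')

# Later: score a folder of generated images against those statistics.
score = fid.compute_fid('/path/to/generated/images',
                        dataset_name='lsunchurch_full',
                        dataset_split='custom',
                        mode='clean')
print(f'FID: {score:.2f}')
```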


## Datasets
Dataset preparation is the same as given in [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/README.md#preparing-datasets).
Example setup for 100-shot AnimalFace Dog and LSUN Church:

**AnimalFace Dog**
```.bash
mkdir datasets
wget https://data-efficient-gans.mit.edu/datasets/AnimalFace-dog.zip -P datasets
```

**LSUN Church**
```.bash
cd ..
git clone https://github.com/fyu/lsun.git
cd lsun
python3 download.py -c church_outdoor
unzip church_outdoor_train_lmdb.zip
cd ../vision-aided-gan
mkdir datasets
python dataset_tool.py --source ../lsun/church_outdoor_train_lmdb/ --dest datasets/church1k.zip --max-images 1000 --transform=center-crop --width=256 --height=256
```

All other datasets can be downloaded from their respective websites:

[FFHQ](https://github.com/NVlabs/ffhq-dataset), [LSUN Categories](http://dl.yf.io/lsun/objects/), [AFHQ](https://github.com/clovaai/stargan-v2), [AnimalFace Dog](https://data-efficient-gans.mit.edu/datasets/AnimalFace-dog.zip), [AnimalFace Cat](https://data-efficient-gans.mit.edu/datasets/AnimalFace-cat.zip), [100-shot Bridge-of-Sighs](https://data-efficient-gans.mit.edu/datasets/100-shot-bridge_of_sighs.zip)


## Training new networks

**Vision-aided GAN training with multiple pretrained networks**:
```.bash
python vision-aided-gan.py --outdir models/ --data datasets/AnimalFace-dog.zip --cfg paper256_2fmap --mirror 1 \
--aug ada --augpipe bgc --augcv ada --batch 16 --gpus 2 --kimgs-list '1000,1000,1000' --num 3
```

The network, sample generated images, and logs are saved at regular intervals (controlled by the `--snap` flag) in the `models/<exp-folder>` dir, where the `<exp-folder>` name is based on the input args. The network with each progressive addition of a pretrained model is saved in a different directory. Logs are saved as TFevents by default. Wandb logging can be enabled by the `--wandb-log` flag and by setting the wandb `entity` in `training.training_loop`.

If fine-tuning a baseline trained model with the vision-aided adversarial loss, include `--resume <network.pkl>` in the above command.

`--kimgs-list` controls the number of iterations (in kimg) after which the next off-the-shelf model is added. It is a comma-separated list. For datasets with 1k training samples we initialize `--kimgs-list` to '4000,1000,1000', and for datasets with more than 1k training samples to '8000,2000,2000'.

**Vision-aided GAN training with a specific pretrained network**:

```.bash
python train.py --outdir models/ --data datasets/AnimalFace-dog.zip --kimg 10000 --cfg paper256_2fmap --gpus 2 \
--cv input-clip-output-conv_multi_level --cv-loss multilevel_s --augcv ada --mirror 1 --aug ada --warmup 1
```

**Model selection**: returns the computer vision model with the highest linear-probe accuracy, evaluated for the best-FID model in a folder or for the given network file.

```.bash
python model_selection.py --data mydataset.zip --network <mynetworkfolder or mynetworkpklfile>
```

**Using the vision-aided discriminator in your own GAN training**:

```python
import vision_aided_loss

device='cuda'
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)
discr.cv_ensemble.requires_grad_(False) # Freeze feature extractor

# Sample images
real = sample_real_image()
fake = G.forward(z)

# Update discriminator discr
lossD = discr(real, for_real=True) + discr(fake, for_real=False)
lossD.backward()

# Update generator G
lossG = discr(fake, for_G=True)
lossG.backward()

# We recommend adding vision-aided adversarial loss after training GAN with standard loss till few warmup_iter.
```

Arg details:

* `cv_type`: name of the off-the-shelf model from `[clip, dino, swin, vgg, det_coco, seg_ade, face_seg, face_normals]`. Multiple models can be used with '+'-separated model names.
* `output_type`: output feature type from the off-the-shelf models; should be one of `[conv, conv_multi_level]`. `conv_multi_level` is supported only for clip and dino. For multiple models, `output_type` should be a '+'-separated list with one output type per model.
* `diffaug`: if True, performs DiffAugment on the vision-aided discriminator with policy `color,translation,cutout`. Recommended to keep this as True.
* `num_classes`: for conditional training, use `num_classes>0`. A projection discriminator is used, similar to [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch).
* `loss_type`: should be one of `[sigmoid, multilevel_sigmoid, sigmoid_s, multilevel_sigmoid_s, hinge, multilevel_hinge]`. Appending `_s` enables [label smoothing](https://arxiv.org/abs/1606.03498). If `loss_type` is None, the output is a list of logits corresponding to each vision-aided discriminator.
* `device`: device for the off-the-shelf model weights.
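
For example, an ensemble of several off-the-shelf models can be configured along these lines (a sketch based on the arguments above; the particular model combination is illustrative):

```python
# Sketch: CLIP + DINO ensemble of vision-aided discriminators with multi-level
# features and DiffAugment; argument values follow the list above.
import vision_aided_loss

device = 'cuda'
discr = vision_aided_loss.Discriminator(
    cv_type='clip+dino',                               # '+'-separated off-the-shelf models
    output_type='conv_multi_level+conv_multi_level',   # one output type per model
    loss_type='multilevel_sigmoid_s',                  # multi-level loss with label smoothing
    diffaug=True,                                      # DiffAugment with policy color,translation,cutout
    device=device,
).to(device)
discr.cv_ensemble.requires_grad_(False)                # keep the pretrained backbones frozen
```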

### Vision-aided StyleGAN3 training
Please see [vision-aided-stylegan3](https://github.com/nupurkmr9/vision_aided_module/tree/main/stylegan3) README for training StyleGAN3 models with our method.

### Vision-aided BigGAN training
Please see [vision-aided-biggan](https://github.com/nupurkmr9/vision_aided_module/tree/main/biggan) README for training BigGAN models with our method.

### To add your own pretrained Model

Create the class file to extract pretrained features as `vision_aided_module/<custom_model>.py`. Add the class path to the `class_name_dict` in the `vision_aided_module.cvmodel.CVBackbone` class. Update the architecture of the trainable classifier head over the pretrained features in `vision_aided_module.cv_discriminator`. Reinstall the library via `pip install .`.

<details ><summary> <b>Training configuration details</b> </summary>

Training configuration corresponding to training with our loss:

* `--cv=input-dino-output-conv_multi_level` pretrained network and its configuration.
* `--warmup=1` should be enabled when training from scratch; it introduces our loss only after training with 500k images.
* `--cv-loss=multilevel` which loss to use for the pretrained-model-based discriminator.
* `--augcv=ada` performs ADA augmentation on the pretrained-model-based discriminator.
* `--augcv=diffaugment-<policy>` performs DiffAugment on the pretrained-model-based discriminator with the given policy, e.g. `color,translation,cutout`.
* `--augpipecv=bgc` ADA augmentation strategy. Note: cutout is always enabled.
* `--ada-target-cv=0.3` adjusts ADA target value for pretrained model based discriminator.
* `--exact-resume=1` enables resume along with optimizer and augmentation state. default is 0.

StyleGAN2 configurations:
* `--outdir='models/'` directory to save training runs.
* `--data` data directory created after running `dataset_tool.py`.
* `--metrics=fid50k_full` computes FID during training at every `snap` interval.
* `--cfg=paper256` architecture and hyperparameter configuration for G and D.
* `--mirror=1` enables horizontal flipping
* `--aug=ada` enables ADA augmentation in trainable D.
* `--diffaugment=color,translation,cutout` enables DiffAugment in trainable D.
* `--augpipe=bgc` ADA augmentation strategy in trainable D.
* `--snap=25` evaluation and model saving interval

Miscellaneous configurations:
* `--wandb-log=1` enables wandb logging.
* `--clean=1` enables FID calculation using [clean-fid](https://github.com/GaParmar/clean-fid) if the real distribution statistics are pre-calculated. Default is 0.

Run `python train.py --help` for more details and the full list of args.
</details>


## References
