This is the PyTorch implementation of SISTA: Single-Shot Domain Adaptation via Target-Aware Generative Augmentations.
The CelebA-HQ dataset was used for all our experiments.
We tested our code with the following package versions:
pytorch 1.10.2
cudatoolkit 10.2.89
ninja 1.10.2.3
The checkpoints for the StyleGANs can be downloaded from here.
├── SISTA_DA
│ ├── celeba_dataloader.py
│ ├── celebahq_dataloader.py
│ ├── data_list.py
│ ├── image_NRC_target.py
│ ├── image_source.py
│ ├── image_target_memo.py
│ ├── image_target.py
│ ├── loss.py
│ ├── network.py
│ ├── randconv.py
│ ├── DATA/
│ ├── utils_memo.py
│ └── run.sh
│
├── GenerativeAugmentations
│ ├── data_augmentation.ipynb
│ ├── e4e
│ ├── e4e_projection.py
│ ├── model.py
│ ├── models
│ │ ├── dlibshape_predictor_68_face_landmarks.dat
│ │ ├── e4e_ffhq_encode.pt
│ │ ├── psp_ffhq_toonify.pt
│ │ ├── stylegan2-color_sketch.pt
│ │ ├── stylegan2-ffhq-config-f.pt
│ │ ├── stylegan2-pencil_sketch.pt
│ │ ├── stylegan2-sketch.pt
│ │ ├── stylegan2-toon.pt
│ │ └── stylegan2-water_color.pt
│ ├── op
│ ├── README.MD
│ ├── transformations.py
│ └── util.py
└── README.MD
The Jupyter notebook data_augmentation.ipynb walks through generating the augmented images. It illustrates three major steps:
- StyleGAN fine-tuning to the target domain
- Target domain image generation
- Target-Aware augmentation
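A common way to realize the fine-tuning and generation steps (used, e.g., in toonification-style work) is to blend the original FFHQ generator with its target-fine-tuned copy, layer by layer. The sketch below is a framework-free toy illustration of that idea; the layer names and flat per-layer weight lists are illustrative, not the repo's actual state-dict format:

```python
def interpolate_weights(src, tgt, alpha):
    """Layer-wise linear blend of two generator checkpoints.

    alpha = 0.0 keeps the source (FFHQ) weights, alpha = 1.0 keeps the
    fine-tuned target weights; intermediate values trade off source
    realism against target style. Weights are flat lists per layer
    purely for brevity.
    """
    return {name: [(1 - alpha) * s + alpha * t
                   for s, t in zip(src[name], tgt[name])]
            for name in src}

# Toy example: blend halfway between two one-layer "checkpoints".
blended = interpolate_weights({"conv1": [0.0, 2.0]},
                              {"conv1": [1.0, 4.0]},
                              alpha=0.5)
```

Sweeping alpha gives a family of generators between the source and target styles, which is one way to obtain diverse target-aware augmentations.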
To train the source model on a desired attribute:
python image_source.py --attribute 'Smiling'
python SISTA_DA/image_NRC_target.py --variant 'interp_concat' --attribute 'Smiling' --t 1
This command adapts a source-trained model to a target domain (selected by the --t flag) for the attribute 'Smiling' using the SISTA_DA protocol.
Similarly, to use the images generated by pruning-zero or pruning-rewind, run:
Pruning Zero:
python SISTA_DA/image_NRC_target.py --variant '' --prune True --attribute 'Smiling' --t 1
Pruning Rewind:
python SISTA_DA/image_NRC_target.py --variant 'prune_rewind' --prune True --attribute 'Smiling' --t 1
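The two pruning variants differ in what replaces the pruned weights: pruning-zero sets the smallest-magnitude weights to zero, while pruning-rewind resets them to their earlier (initial) values, in the spirit of lottery-ticket rewinding. A toy 1-D sketch of that distinction (the magnitude criterion and pruning fraction here are illustrative, not the repo's exact procedure):

```python
def prune(weights, init_weights, frac, rewind=False):
    """Prune the smallest-magnitude fraction of weights.

    rewind=False zeroes them out (pruning-zero); rewind=True resets
    them to their initial values (pruning-rewind).
    """
    k = int(len(weights) * frac)
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:k]
    out = list(weights)
    for i in smallest:
        out[i] = init_weights[i] if rewind else 0.0
    return out

# Prune half of a 4-weight "layer" both ways.
zeroed = prune([0.1, -2.0, 0.05, 3.0], [1.0, 1.0, 1.0, 1.0], 0.5)
rewound = prune([0.1, -2.0, 0.05, 3.0], [1.0, 1.0, 1.0, 1.0], 0.5, rewind=True)
```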
Adaptation using only the unlabeled target data:
python SISTA_DA/image_NRC_target.py --variant 'direct_target' --attribute 'Smiling' --t 1
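image_NRC_target.py follows NRC-style source-free adaptation, which supervises each target sample with the predictions of its nearest neighbors in feature space. A simplified, dependency-free sketch of that neighborhood averaging (real NRC additionally weights reciprocal neighbors and expanded neighborhoods):

```python
import math

def cosine(u, v):
    """Cosine similarity between two feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) *
                  math.sqrt(sum(b * b for b in v)))

def neighbor_soft_labels(feats, probs, k=1):
    """Soft pseudo-label per sample: mean prediction of its k nearest
    neighbors (by cosine similarity), excluding the sample itself."""
    labels = []
    for i, f in enumerate(feats):
        sims = sorted(((cosine(f, g), j)
                       for j, g in enumerate(feats) if j != i),
                      reverse=True)
        nbrs = [j for _, j in sims[:k]]
        n_cls = len(probs[0])
        labels.append([sum(probs[j][c] for j in nbrs) / k
                       for c in range(n_cls)])
    return labels

# Three samples: the first two are close in feature space, so each
# borrows the other's prediction as its pseudo-label.
feats = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]]
probs = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]]
labels = neighbor_soft_labels(feats, probs, k=1)
```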
To generate baseline results for MEMO with its two variants, MEMO (AugMix) and MEMO (RandConv):
MEMO (AugMix):
python SISTA_DA/image_target_memo.py --augmix True --attribute 'Smiling' --t 1
and
MEMO (RandConv):
python SISTA_DA/image_target_memo.py --augmix False --attribute 'Smiling' --t 1
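For context, MEMO adapts the model at test time by minimizing the entropy of the prediction averaged over several augmented views of each test image (AugMix or RandConv views in the commands above). A minimal numeric sketch of that marginal-entropy objective, not the repo's implementation:

```python
import math

def entropy(p):
    """Shannon entropy of a probability vector (in nats)."""
    return -sum(x * math.log(x) for x in p if x > 0)

def memo_marginal_entropy(aug_probs):
    """MEMO objective: entropy of the marginal prediction, i.e. the
    class probabilities averaged over all augmented views."""
    n = len(aug_probs)
    marginal = [sum(view[c] for view in aug_probs) / n
                for c in range(len(aug_probs[0]))]
    return entropy(marginal)

# Confident, agreeing views incur zero loss; disagreeing views are
# penalized, so minimizing this loss encourages consistent predictions.
low = memo_marginal_entropy([[1.0, 0.0], [1.0, 0.0]])
high = memo_marginal_entropy([[1.0, 0.0], [0.0, 1.0]])
```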
If you use this code or ideas from our paper, please cite:
@article{subramanyam2022SISTA,
  title={Single-Shot Domain Adaptation via Target-Aware Generative Augmentations},
  author={Subramanyam, Rakshith and Thopalli, Kowshik and Berman, Spring and Turaga, Pavan and Thiagarajan, Jayaraman J.},
  journal={arXiv preprint arXiv:2210.16692},
  year={2022}
}
This code builds upon the following codebases: StyleGAN2 by rosinality, e4e, StyleGAN-NADA, NRC, MEMO, and RandConv. We thank the authors of the respective works for publicly sharing their code; please cite them when appropriate.