Topaz1618/CycleganSA

💻 🐈 Added a self-attention layer to the CycleGAN implementation (PyTorch).

Introduction

CycleGANSA is an improved version of the CycleGAN model. It adds a self-attention mechanism and semantic segmentation to speed up generation and fix blurred backgrounds in transformed images. It has been tested on the MNIST and horse2zebra (h2z) datasets. This work has been published by IEEE and can be accessed here.
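
A self-attention layer over convolutional feature maps is commonly written in the SAGAN style; the exact layer in this codebase may differ. The module below is a minimal illustrative sketch (class and parameter names are ours, not taken from the repository):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SelfAttention(nn.Module):
        """SAGAN-style self-attention over conv feature maps (illustrative sketch)."""
        def __init__(self, in_channels):
            super().__init__()
            self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
            self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
            self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
            self.gamma = nn.Parameter(torch.zeros(1))  # learned weight of the attention branch

        def forward(self, x):
            b, c, h, w = x.size()
            q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)   # B x HW x C//8
            k = self.key(x).view(b, -1, h * w)                      # B x C//8 x HW
            attn = F.softmax(torch.bmm(q, k), dim=-1)               # B x HW x HW
            v = self.value(x).view(b, -1, h * w)                    # B x C x HW
            out = torch.bmm(v, attn.permute(0, 2, 1)).view(b, c, h, w)
            return self.gamma * out + x                             # residual connection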

Results

With self-attention and semantic segmentation added, both generation speed and output quality improve significantly.

CycleGAN with self-attention

The generation process is significantly faster compared to the original CycleGAN.

[Left to Right: Original Image, CycleGAN Output, CycleganSA Output]

CycleGAN with Semantic Segmentation

CycleGAN Result

The original CycleGAN has an issue with a blurred background after style transformation.

[Left to Right: Original Image, CycleGAN Output]

CycleganSA Result

The blurred-background issue is resolved after adding semantic segmentation.

[Left to Right: Original Image, CycleGAN Output, CycleganSA Output]
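
One plausible way semantic segmentation can fix this (the repository's actual integration may differ) is to use a segmentation mask to keep the original background pixels and apply the translated pixels only on the foreground object, for example:

    import torch

    def composite_with_mask(original, generated, mask):
        """Blend the translated foreground with the untouched original background.
        mask is 1.0 on the foreground object (e.g. the horse) and 0.0 elsewhere.
        Illustrative sketch only; not necessarily how this repository combines them."""
        return mask * generated + (1.0 - mask) * original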

Environment

  • Ubuntu 18.04
  • Python 3.8
  • CUDA 11.3
  • cuDNN 8, NVCC
  • PyTorch 1.11.0
  • torchvision 0.12.0
  • torchaudio 0.11.0

Usage

Dataset Structure

-- datasets
    └── dataset_name
        β”œβ”€β”€ domain_A
        └── domain_B
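
For example, a horse2zebra layout following this structure could be prepared like this (a small helper sketch; the image files themselves still need to be copied in):

    from pathlib import Path

    # Create the expected folder layout for a dataset named "horse2zebra".
    root = Path("datasets") / "horse2zebra"
    for domain in ("domain_A", "domain_B"):
        (root / domain).mkdir(parents=True, exist_ok=True)
    # Put horse images under datasets/horse2zebra/domain_A
    # and zebra images under datasets/horse2zebra/domain_B.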

Training

To train with a shallow self-attention mechanism:

python train.py --model_type shallow --dataroot datasets/horse2zebra/ --name h2z_sa_shallow

python train.py --model_type shallow --dataroot datasets/mnist/ --name mnist_sa_shallow

To train with a deep self-attention mechanism:

python train.py --model_type deep --dataroot datasets/horse2zebra/ --name h2z_sa_deep --gpu 0

python train.py --model_type deep --dataroot datasets/mnist/ --name mnist_sa_deep

To train without the self-attention mechanism:

python train.py --dataroot datasets/horse2zebra/ --name h2z_sa_shallow

python train.py --dataroot datasets/mnist/ --name mnist_sa_shallow

Training Parameters

  • --dataroot: Path to the dataset
  • --name: Name of the checkpoint directory (checkpoints will be saved under checkpoint/{name})
  • --results_dir: Directory to save the results
  • --gpu-ids: IDs of the GPUs to use; default is 0 (GPU). Use -1 for CPU.
  • --model_type: shallow | deep | default (see the sketch after this list)
    • shallow: Use self-attention in shallow CycleGAN
    • deep: Use self-attention in deep CycleGAN
    • default: Use the original CycleGAN
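
As a rough mental model of what --model_type toggles (this is a toy sketch, not the repository's actual generator builder), "shallow" vs. "deep" can be thought of as inserting the self-attention block over early, high-resolution features or over late, low-resolution features of the generator, reusing the SelfAttention sketch from the Introduction:

    import torch.nn as nn

    def build_generator(model_type="default"):
        """Toy sketch of where a self-attention block could sit in the encoder."""
        layers = [
            nn.Conv2d(3, 64, 7, padding=3), nn.ReLU(inplace=True),               # shallow features
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(inplace=True),  # deep features
        ]
        if model_type == "shallow":
            layers.insert(2, SelfAttention(64))   # attention over early, high-resolution features
        elif model_type == "deep":
            layers.append(SelfAttention(256))     # attention over late, low-resolution features
        # "default" keeps the plain CycleGAN generator, with no attention block
        return nn.Sequential(*layers)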

Testing

 python test.py --dataroot datasets/mnist/testA --name mnist_sa_deep --results_dir output/mnist/sa_shallow --direction AtoB

Testing Parameters

  • --dataroot: Path to the test images; make sure to point at images from domain A or domain B
  • --name: Name of the checkpoint to load (read from checkpoint/{name})
  • --results_dir: Directory to save the results
  • --gpu-ids: IDs of the GPUs to use; default is 0 (GPU). Use -1 for CPU.
  • --direction: Specify the direction of the transformation, AtoB or BtoA
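
For example, to translate in the opposite direction with the deep horse2zebra checkpoint (the paths and checkpoint name here are illustrative):

    python test.py --dataroot datasets/horse2zebra/testB --name h2z_sa_deep --results_dir output/h2z/sa_deep --direction BtoA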

Acknowledgements & References

This project is based on the CycleGAN model. We acknowledge the authors of the original CycleGAN paper:

  • Zhu, Jun-Yan, Taesung Park, Phillip Isola, and Alexei A. Efros. "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks." In Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2017.

Link to the original paper: CycleGAN Paper

Original CycleGAN code: pytorch-CycleGAN-and-pix2pix

Copyright and License

CycleganSA is licensed under the BSD License. Refer to the LICENSE file for more information.

Please feel free to ask any questions or provide suggestions. Thank you for using and contributing to CycleganSA!
