PyTorch code for our NeurIPS 2022 paper "Cross Aggregation Transformer for Image Restoration"

Cross Aggregation Transformer for Image Restoration

Zheng Chen, Yulun Zhang, Jinjin Gu, Yongbing Zhang, Linghe Kong, and Xin Yuan, "Cross Aggregation Transformer for Image Restoration", NeurIPS, 2022 (Spotlight)

[paper] [arXiv] [supplementary material] [visual results] [pretrained models]


Abstract: Recently, Transformer architecture has been introduced into image restoration to replace convolution neural network (CNN) with surprising results. Considering the high computational complexity of Transformer with global attention, some methods use the local square window to limit the scope of self-attention. However, these methods lack direct interaction among different windows, which limits the establishment of long-range dependencies. To address the above issue, we propose a new image restoration model, Cross Aggregation Transformer (CAT). The core of our CAT is the Rectangle-Window Self-Attention (Rwin-SA), which utilizes horizontal and vertical rectangle window attention in different heads parallelly to expand the attention area and aggregate the features cross different windows. We also introduce the Axial-Shift operation for different window interactions. Furthermore, we propose the Locality Complementary Module to complement the self-attention mechanism, which incorporates the inductive bias of CNN (e.g., translation invariance and locality) into Transformer, enabling global-local coupling. Extensive experiments demonstrate that our CAT outperforms recent state-of-the-art methods on several image restoration applications.
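The rectangle-window idea in the abstract can be illustrated with a small, framework-free sketch. It only shows how a feature grid might be split into horizontal and vertical rectangle windows compared with square windows of the same area. The function name `rect_windows` and all sizes are assumptions chosen for illustration, not taken from the CAT code or its configs, and no attention is computed.

```python
from typing import List, Tuple

Coord = Tuple[int, int]


def rect_windows(height: int, width: int, win_h: int, win_w: int) -> List[List[Coord]]:
    """Split an H x W grid into non-overlapping win_h x win_w windows.

    Each window is returned as the list of (row, col) positions it covers;
    self-attention would then be computed only among positions in one window.
    """
    if height % win_h or width % win_w:
        raise ValueError("grid size must be divisible by the window size")
    windows = []
    for top in range(0, height, win_h):
        for left in range(0, width, win_w):
            windows.append(
                [(r, c) for r in range(top, top + win_h) for c in range(left, left + win_w)]
            )
    return windows


# Illustrative sizes only (assumed, not taken from the paper's configs).
H, W = 8, 8
horizontal = rect_windows(H, W, win_h=2, win_w=8)  # wide, short windows
vertical = rect_windows(H, W, win_h=8, win_w=2)    # tall, narrow windows
square = rect_windows(H, W, win_h=4, win_w=4)      # same area, square

# All three window shapes hold 16 positions, but the rectangles span a full axis.
print(len(horizontal), len(vertical), len(square))  # 4 4 4
```

In this toy grid, position (0, 7) shares a window with (0, 0) under the horizontal split but not under the square split, which is the kind of wider reach that assigning horizontal and vertical windows to different heads is meant to provide.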


[Figure: visual comparison for SR (x4), with columns HQ, LQ, SwinIR, and CAT (ours)]

Dependencies

  • Python 3.8
  • PyTorch 1.8.0
  • NVIDIA GPU + CUDA
# Clone the GitHub repo and go to the default directory 'CAT'.
git clone https://github.com/zhengchen1999/CAT.git
cd CAT
conda create -n CAT python=3.8
conda activate CAT
pip install -r requirements.txt
python setup.py develop
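
After installation, a quick sanity check of the environment can save a failed training launch. The following is a minimal sketch, not part of the repository; the helper name `check_environment` is hypothetical. It reports the Python version and, only if PyTorch is importable, its version and whether CUDA is visible.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Collect basic facts about the Python / PyTorch / CUDA setup."""
    info = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "torch": None,
        "cuda_available": False,
    }
    # Only import torch if it is installed, so the check never crashes.
    if importlib.util.find_spec("torch") is not None:
        import torch

        info["torch"] = torch.__version__
        info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

With the dependencies listed above, one would expect a Python 3.8.x version, a PyTorch 1.8.0 build, and `cuda_available: True` on a machine with a working NVIDIA driver.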

TODO

  • Image SR
  • JPEG Compression Artifact Reduction
  • Image Denoising
  • Other tasks

Contents

  1. Datasets
  2. Models
  3. Training
  4. Testing
  5. Results
  6. Citation
  7. Acknowledgements

Datasets

The training and testing sets used can be downloaded as follows:

| Task | Training Set | Testing Set | Visual Results |
| --- | --- | --- | --- |
| image SR | DIV2K (800 training images, 100 validation images) + Flickr2K (2650 images) [complete training dataset DF2K] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset download] | here |
| grayscale JPEG compression artifact reduction | DIV2K (800 training images) + Flickr2K (2650 images) + WED (4744 images) + BSD500 (400 training & testing images) [complete training dataset DFWB] | Classic5 + LIVE1 + Urban100 [complete testing dataset download] | here |
| real image denoising | SIDD (320 training images) [complete training dataset SIDD] | SIDD + DND [complete testing dataset download] | here |

Here the visual results are generated under SR (x4), JPEG compression artifact reduction (q10), and real image denoising.

Download the training and testing datasets and put them into the corresponding folders under datasets/ and restormer/datasets/. See datasets for details of the directory structure.

Models

| Task | Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SR | CAT-R | 16.60 | 292.7 | Urban100 | 27.45 | 0.8254 | Google Drive | Google Drive |
| SR | CAT-A | 16.60 | 360.7 | Urban100 | 27.89 | 0.8339 | Google Drive | Google Drive |
| SR | CAT-R-2 | 11.93 | 216.3 | Urban100 | 27.59 | 0.8285 | Google Drive | Google Drive |
| SR | CAT-A-2 | 16.60 | 387.9 | Urban100 | 27.99 | 0.8357 | Google Drive | Google Drive |
| CAR | CAT | 16.20 | 346.4 | LIVE1 | 29.89 | 0.8295 | Google Drive | Google Drive |
| real-DN | CAT | 25.77 | 53.2 | SIDD | 40.01 | 0.9600 | Google Drive | Google Drive |

Performance is reported on Urban100 (SR, x4), LIVE1 (CAR, q=10), and SIDD (real-DN). FLOPs are measured with a 128 x 128 test input.
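
For readers unfamiliar with the PSNR column, the standard definition is PSNR = 10 * log10(MAX^2 / MSE). The sketch below is a plain-Python illustration of that formula only; the function name `psnr` is hypothetical, and it does not reproduce the repository's evaluation protocol (for example any colour-channel conversion or border cropping set in the test configs). SSIM is not covered.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], restored: Sequence[float], max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(reference) != len(restored) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy 4-pixel example: every pixel is off by 1, so MSE = 1.
print(round(psnr([10, 20, 30, 40], [11, 21, 31, 41]), 2))  # 48.13
```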

Training

Image SR

  • Cd to 'CAT' and run the setup script.

    # If already in CAT and set up, please ignore
    python setup.py develop
  • Download the training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets and place them in datasets/.

  • Run the following scripts. The training configuration is in options/train/.

    # CAT-R, SR, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x4.yml --launcher pytorch
    
    # CAT-A, SR, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x4.yml --launcher pytorch
    
    # CAT-R-2, SR, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x4.yml --launcher pytorch
    
    # CAT-A-2, SR, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x4.yml --launcher pytorch
  • The training experiment is in experiments/.
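
The twelve SR commands above differ only in the model variant and scale, so they can be generated rather than copied by hand. This is a convenience sketch, not a script shipped with the repository; the helper name `train_command` is hypothetical, while the config file names and launch flags are the ones listed above.

```python
from itertools import product

VARIANTS = ["R", "A", "R_2", "A_2"]  # CAT-R, CAT-A, CAT-R-2, CAT-A-2
SCALES = [2, 3, 4]


def train_command(variant: str, scale: int, gpus: int = 4, port: int = 4321) -> str:
    """Build the distributed launch command for one SR config."""
    config = f"options/train/train_CAT_{variant}_sr_x{scale}.yml"
    return (
        f"python -m torch.distributed.launch --nproc_per_node={gpus} "
        f"--master_port={port} basicsr/train.py -opt {config} --launcher pytorch"
    )


commands = [train_command(v, s) for v, s in product(VARIANTS, SCALES)]
for command in commands:
    print(command)
```

Printing the commands instead of executing them keeps the sketch side-effect free; each line can be pasted into a shell or written to a job script.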

JPEG Compression Artifact Reduction

  • Cd to 'CAT' and run the setup script

    # If already in CAT and set up, please ignore
    python setup.py develop
  • Download the training (DFWB, already processed) and testing (Classic5, LIVE1, Urban100, already processed) datasets and place them in datasets/.

  • Run the following scripts. The training configuration is in options/train/.

    # CAT, CAR, input=128x128, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q10.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q20.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q30.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q40.yml --launcher pytorch
  • The training experiment is in experiments/.

Real Image Denoising

  • Cd to 'CAT/restormer' and run the setup script

    # If already in restormer and set up, please ignore
    python setup.py develop --no_cuda_ext
  • Download training (SIDD-train, contains validation dataset, already processed) datasets, and place them in datasets/ (restormer/datasets/).

  • Run the following scripts. The training configuration is in options/ (restormer/options/).

    # CAT, Real DN, Progressive Learning, 8 GPUs
    python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train_RealDenoising_CAT.yml --launcher pytorch
  • The training experiment is in experiments/ (restormer/experiments/).

Testing

Image SR

  • Cd to 'CAT' and run the setup script

    # If already in CAT and set up, please ignore
    python setup.py develop
  • Download the pre-trained models and place them in experiments/pretrained_models/.

    We provide pre-trained models for image SR: CAT-R, CAT-A, CAT-R-2, and CAT-A-2 (x2, x3, x4).

  • Download the testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets and place them in datasets/.

  • Run the following scripts. The testing configuration is in options/test/ (e.g., test_CAT_R_sr_x2.yml).

    Note: You can set use_chop: True (default: False) in the YML file to chop the image into smaller pieces for testing.

    # No self-ensemble
    # CAT-R, SR, reproduces results in Table 2 of the main paper
    python basicsr/test.py -opt options/test/test_CAT_R_sr_x2.yml
    python basicsr/test.py -opt options/test/test_CAT_R_sr_x3.yml
    python basicsr/test.py -opt options/test/test_CAT_R_sr_x4.yml
    
    # CAT-A, SR, reproduces results in Table 2 of the main paper
    python basicsr/test.py -opt options/test/test_CAT_A_sr_x2.yml
    python basicsr/test.py -opt options/test/test_CAT_A_sr_x3.yml
    python basicsr/test.py -opt options/test/test_CAT_A_sr_x4.yml
    
    # CAT-R-2, SR, reproduces results in Table 1 of the supplementary material
    python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x2.yml
    python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x3.yml
    python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x4.yml
    
    # CAT-A-2, SR, reproduces results in Table 1 of the supplementary material
    python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x2.yml
    python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x3.yml
    python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x4.yml
  • The output is in results/.
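
The use_chop note above refers to processing an image piece by piece rather than in one pass. The sketch below shows the general tile-and-stitch idea in plain Python. It is not the BasicSR or CAT implementation: the names `process_in_tiles` and `double` are hypothetical, the stand-in "model" keeps the input size (a real SR model would upscale each tile), and tile overlap is ignored.

```python
from typing import Callable, List

Image = List[List[float]]


def process_in_tiles(image: Image, tile: int, fn: Callable[[Image], Image]) -> Image:
    """Apply `fn` to tile x tile crops of `image` and stitch the outputs.

    Assumes `fn` returns a crop of the same size as its input (no upscaling)
    and ignores tile overlap, which real implementations often add to hide seams.
    """
    height, width = len(image), len(image[0])
    output = [[0.0] * width for _ in range(height)]
    for top in range(0, height, tile):
        for left in range(0, width, tile):
            bottom, right = min(top + tile, height), min(left + tile, width)
            crop = [row[left:right] for row in image[top:bottom]]
            result = fn(crop)
            # Write the processed crop back to its original location.
            for dr, out_row in enumerate(result):
                output[top + dr][left:right] = out_row
    return output


def double(crop: Image) -> Image:
    # Stand-in for a model forward pass.
    return [[2 * v for v in row] for row in crop]


image = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
tiled = process_in_tiles(image, tile=2, fn=double)
assert tiled == double(image)  # a per-pixel op gives the same result tiled or whole
```

Tiling lowers peak memory because only one crop is processed at a time. For a real network the tiled output can differ slightly from whole-image inference near tile borders, which is why overlap is commonly used.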

JPEG Compression Artifact Reduction

  • Cd to 'CAT' and run the setup script

    # If already in CAT and set up, please ignore
    python setup.py develop
  • Download the pre-trained models and place them in experiments/pretrained_models/.

    We provide pre-trained models for JPEG compression artifact reduction: CAT (q10, q20, q30, q40).

  • Download the testing (Classic5, LIVE1, Urban100, already processed) datasets and place them in datasets/.

  • Run the following scripts. The testing configuration is in options/test/ (e.g., test_CAT_car_q10.yml).

    # No self-ensemble
    # CAT-A, CAR, reproduces results in Table 3 of the main paper
    python basicsr/test.py -opt options/test/test_CAT_car_q10.yml
    python basicsr/test.py -opt options/test/test_CAT_car_q20.yml
    python basicsr/test.py -opt options/test/test_CAT_car_q30.yml
    python basicsr/test.py -opt options/test/test_CAT_car_q40.yml
  • The output is in results/.

Real Image Denoising

  • Cd to 'CAT' and run the setup script

    # If already in CAT and set up, please ignore
    python setup.py develop
  • Download the pre-trained models and place them in experiments/pretrained_models/.

  • Download the testing (SIDD, DND) datasets and place them in datasets/.

  • Run the following scripts. The testing configuration is in options/test/.

    # No self-ensemble
    # CAT, real DN, reproduces results in Table 4 of the main paper
    # testing on SIDD
    python test_real_denoising_sidd.py --save_images
    # evaluate_sidd.m is a MATLAB script; run it in MATLAB after the step above
    evaluate_sidd.m
    
    # testing on DND
    python test_real_denoising_dnd.py --save_images
  • The output is in results/.

Results

We achieve state-of-the-art performance on image SR, JPEG compression artifact reduction and real image denoising. Detailed results can be found in the paper. All visual results of CAT can be downloaded here.

Image SR
  • results in Table 2 of the main paper

  • results in Table 1 of the supplementary material

  • visual comparison (x4) in the main paper

  • visual comparison (x4) in the supplementary material

JPEG Compression Artifact Reduction
  • results in Table 3 of the main paper

  • results in Table 3 of the supplementary material (test on Urban100)

  • visual comparison (q=10) in the main paper

  • visual comparison (q=10) in the supplementary material

Real Image Denoising
  • results in Table 4 of the main paper

*: We re-test the SIDD with all official pre-trained models.

Citation

If you find the code helpful in your research or work, please cite the following paper(s).

@inproceedings{chen2022cross,
    title={Cross Aggregation Transformer for Image Restoration},
    author={Chen, Zheng and Zhang, Yulun and Gu, Jinjin and Zhang, Yongbing and Kong, Linghe and Yuan, Xin},
    booktitle={NeurIPS},
    year={2022}
}

Acknowledgements

This code is built on BasicSR and Restormer.
