Code for the paper Better Diffusion Models Further Improve Adversarial Training (ICML 2023).
This project has been tested in the following environment:
- OS: Ubuntu 20.04.3
- GPU: NVIDIA A100
- CUDA: 11.1, cuDNN: v8.2
- Python: 3.9.5
- PyTorch: 1.8.0
- Torchvision: 0.9.0
The adversarial training code is modified from the PyTorch implementation of Rebuffi et al., 2021. The generation code is modified from the official implementation of EDM. For data generation, please refer to `edm/README.md` for more details.
- Install or download AutoAttack:
  ```bash
  pip install git+https://github.com/fra31/auto-attack
  ```
- Install or download RandAugment:
  ```bash
  pip install git+https://github.com/ildoonet/pytorch-randaugment
  ```
- Download the EDM-generated data to `./edm_data`. For TinyImageNet, we provide data generated by the ImageNet EDM model. Since the 20M and 50M data files are too large, we split them into several parts:
dataset | size | link |
---|---|---|
CIFAR-10 | 1M | npz |
CIFAR-10 | 5M | npz |
CIFAR-10 | 10M | npz |
CIFAR-10 | 20M | part1 part2 |
CIFAR-10 | 50M | part1 part2 part3 part4 |
CIFAR-100 | 1M | npz |
CIFAR-100 | 50M | part1 part2 part3 part4 |
SVHN | 1M | npz |
SVHN | 50M | part1 part2 part3 part4 part5 |
TinyImageNet | 1M | npz |
- Merge the 20M and 50M generated data:
  ```bash
  python merge-data.py
  ```
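The merge step above combines the split parts into a single `.npz` file. The following is a minimal sketch of what such a merge could look like, not the contents of `merge-data.py`. It assumes that each part stores images under the key `image` and labels under the key `label`, and that parts are concatenated in order along the first axis; the function name `merge_npz_parts` is hypothetical.

```python
import numpy as np

# Hypothetical sketch, not merge-data.py. Assumes every part is an .npz archive
# with an 'image' array of shape (N, H, W, C) and a 'label' array of shape (N,),
# and that the merged file should keep the same two keys.
def merge_npz_parts(part_paths, out_path):
    images, labels = [], []
    for path in part_paths:
        # NpzFile supports the context-manager protocol, so each file is closed after reading.
        with np.load(path) as part:
            images.append(part['image'])
            labels.append(part['label'])
    # Concatenate the parts in the order given, along the sample axis.
    merged_images = np.concatenate(images, axis=0)
    merged_labels = np.concatenate(labels, axis=0)
    assert len(merged_images) == len(merged_labels), "image/label counts differ"
    np.savez(out_path, image=merged_images, label=merged_labels)
    return merged_images.shape[0]
```

Under these assumptions, a call such as `merge_npz_parts(['part1.npz', 'part2.npz'], 'merged.npz')` returns the total number of merged samples. Because every part is loaded before concatenation, this approach needs enough RAM to hold the full merged array.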
Run `train-wa.py` to reproduce the results reported in the paper. For example, to train a WideResNet-28-10 model via TRADES on CIFAR-10 with the 1M additional generated data provided by EDM (Karras et al., 2022):
```bash
python train-wa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1' \
    --data cifar10s \
    --batch-size 512 \
    --model wrn-28-10-swish \
    --num-adv-epochs 400 \
    --lr 0.2 \
    --beta 5.0 \
    --unsup-fraction 0.7 \
    --aux-data-filename 'edm_data/cifar10/1m.npz' \
    --ls 0.1
```
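The example command combines several settings: `--unsup-fraction 0.7` controls how much of each batch of 512 comes from the generated data, while `--beta 5.0` and `--ls 0.1` weight the TRADES regularizer and set the label smoothing. The sketch below illustrates these quantities in plain Python for a single example; it is an illustration under stated assumptions, not the training code. The batch split assumes the common semi-supervised sampler convention of `int(batch_size * (1 - unsup_fraction))` real examples per batch, and applying label smoothing only to the clean cross-entropy term is an assumption. All function names are hypothetical.

```python
import math

# Assumption: real examples per batch = int(batch_size * (1 - unsup_fraction)),
# with the remainder drawn from the generated data. train-wa.py may round differently.
def batch_split(batch_size, unsup_fraction):
    """Return (#real, #generated) examples in one training batch."""
    n_real = int(batch_size * (1 - unsup_fraction))
    return n_real, batch_size - n_real

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def trades_loss(clean_logits, adv_logits, label, beta, ls):
    """TRADES objective for one example:
    CE(p_clean, smoothed y) + beta * KL(p_clean || p_adv).
    Smoothing only the clean CE term is an assumption of this sketch."""
    k = len(clean_logits)
    p_clean = softmax(clean_logits)
    p_adv = softmax(adv_logits)
    # Smoothed target: (1 - ls) on the true class plus ls spread uniformly over all k classes.
    target = [ls / k + (1 - ls) * (i == label) for i in range(k)]
    ce = -sum(t * math.log(p) for t, p in zip(target, p_clean))
    # KL divergence from the adversarial prediction to the clean prediction.
    kl = sum(p * math.log(p / q) for p, q in zip(p_clean, p_adv))
    return ce + beta * kl
```

With the example settings, `batch_split(512, 0.7)` gives 153 real and 359 generated examples per batch under this convention. The KL term vanishes whenever the clean and adversarial predictions agree, so `--beta` only penalizes the gap between them.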
The trained models can be evaluated by running `eval-aa.py`, which uses AutoAttack to evaluate robust accuracy. Run the following command (taking the checkpoint above as an example):
```bash
python eval-aa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1'
```
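The standard version of AutoAttack runs an ensemble of attacks (APGD-CE, APGD-T, FAB-T and Square) and counts a test point as robust only if it is classified correctly and no attack in the ensemble finds a misclassifying perturbation. The sketch below shows only that bookkeeping, assuming the per-point outcomes are already known; it runs no attack, is not `eval-aa.py`, and its function name is hypothetical.

```python
# Hypothetical bookkeeping sketch. Assumes that for every test point we already
# know whether the clean prediction is correct and, for each attack, whether
# that attack found a misclassifying perturbation.
def clean_and_robust_accuracy(clean_correct, attack_success):
    """clean_correct: list[bool]; attack_success: dict mapping attack name -> list[bool]."""
    n = len(clean_correct)
    robust = list(clean_correct)  # a point misclassified on clean input is never robust
    for fooled in attack_success.values():
        # A point stays robust only if this attack (and every earlier one) failed on it,
        # which is why later attacks only need to run on points still marked robust.
        robust = [r and not f for r, f in zip(robust, fooled)]
    return sum(clean_correct) / n, sum(robust) / n
```

For instance, with four test points where one is misclassified on clean input and two different attacks each fool one further point, clean accuracy is 75% and robust accuracy is 25%.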
To evaluate the model from the last epoch under AutoAttack, run:
```bash
python eval-last-aa.py --data-dir 'dataset-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1'
```
We provide state-of-the-art pre-trained checkpoints of WRN-28-10 (Swish) and WRN-70-16 (Swish). Refer to the argtxt files below for the specific hyper-parameters. Clean and robust accuracies are measured on the full test set, and robust accuracy is measured using AutoAttack.
dataset | norm | radius | architecture | clean | robust | link |
---|---|---|---|---|---|---|
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | 92.44% | 67.31% | checkpoint argtxt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | 93.25% | 70.69% | checkpoint argtxt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-28-10 | 95.16% | 83.63% | checkpoint argtxt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | 95.54% | 84.86% | checkpoint argtxt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-28-10 | 72.58% | 38.83% | checkpoint argtxt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | 75.22% | 42.67% | checkpoint argtxt |
SVHN | ℓ∞ | 8 / 255 | WRN-28-10 | 95.56% | 64.01% | checkpoint argtxt |
TinyImageNet | ℓ∞ | 8 / 255 | WRN-28-10 | 65.19% | 31.30% | checkpoint argtxt |
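The radius column gives the perturbation budget ε. Assuming, as is standard for these benchmarks, that pixel values are scaled to [0, 1], 8 / 255 allows about 0.031 of change per pixel under ℓ∞, while 128 / 255 bounds the total Euclidean length of the perturbation by about 0.502 under ℓ2. The hypothetical helpers below show what staying inside each ball means by projecting a flattened perturbation back onto it, as PGD-style attacks do after every step.

```python
import math

# Threat models from the table, assuming pixel values in [0, 1].
EPS_LINF = 8 / 255    # ℓ∞ radius, about 0.0314
EPS_L2 = 128 / 255    # ℓ2 radius, about 0.502

def project_linf(delta, eps):
    """Clip every coordinate of the perturbation to [-eps, eps]."""
    return [max(-eps, min(eps, d)) for d in delta]

def project_l2(delta, eps):
    """Rescale the perturbation onto the ℓ2 ball of radius eps if it lies outside."""
    norm = math.sqrt(sum(d * d for d in delta))
    if norm <= eps:
        return list(delta)
    return [d * eps / norm for d in delta]
```

A real attack would additionally clip the perturbed input x + δ to the valid pixel range [0, 1]; that step is omitted here to keep the two projections in focus.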
For evaluation under AutoAttack:
- Download the checkpoint to `trained_models/mymodel/weights-best.pt`.
- Download the argtxt to `trained_models/mymodel/args.txt`.
- Run the command:
  ```bash
  python eval-aa.py --data-dir 'dataset-data' --log-dir 'trained_models' --desc 'mymodel'
  ```
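These steps rely on `eval-aa.py` finding both files under `<log-dir>/<desc>/`, as the paths `trained_models/mymodel/weights-best.pt` and `trained_models/mymodel/args.txt` indicate. A small pre-flight check such as the hypothetical helper below can catch a misplaced download before a long AutoAttack run; it only checks that the files exist, not that they are valid.

```python
from pathlib import Path

# Hypothetical helper: lists which of the files expected at
# <log_dir>/<desc>/weights-best.pt and <log_dir>/<desc>/args.txt are missing.
def missing_eval_files(log_dir, desc):
    run_dir = Path(log_dir) / desc
    required = [run_dir / 'weights-best.pt', run_dir / 'args.txt']
    return [str(p) for p in required if not p.is_file()]
```

For the setup above, `missing_eval_files('trained_models', 'mymodel')` returns an empty list once both downloads are in place.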
We have uploaded the CIFAR-10 and CIFAR-100 models to the RobustBench model zoo. See the RobustBench tour for how to evaluate their performance with RobustBench.
If you find the code useful for your research, please consider citing:
```bibtex
@inproceedings{wang2023better,
  title={Better Diffusion Models Further Improve Adversarial Training},
  author={Wang, Zekai and Pang, Tianyu and Du, Chao and Lin, Min and Liu, Weiwei and Yan, Shuicheng},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}
```
and/or our related works:
```bibtex
@inproceedings{pang2022robustness,
  title={Robustness and Accuracy Could be Reconcilable by (Proper) Definition},
  author={Pang, Tianyu and Lin, Min and Yang, Xiao and Zhu, Jun and Yan, Shuicheng},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2022}
}
@inproceedings{pang2021bag,
  title={Bag of Tricks for Adversarial Training},
  author={Pang, Tianyu and Yang, Xiao and Dong, Yinpeng and Su, Hang and Zhu, Jun},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}
```