Commit PiperOrigin-RevId: 348193373. Showing 7 changed files with 686 additions and 0 deletions.
@@ -0,0 +1,63 @@
# Adversarial Robustness

This repository contains the code needed to evaluate models trained in
[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593).

## Contents

We have released our top-performing models in two formats compatible with
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
This repository also contains our model definitions.

## Running the example code

### Downloading a model

Download a model from the links listed in the following table.
Clean and robust accuracies are measured on the full test set.
The robust accuracy is measured using
[AutoAttack](https://github.com/fra31/auto-attack).
| dataset | norm | radius | architecture | extra data | clean | robust | link |
|---|:---:|:---:|:---:|:---:|---:|---:|:---:|
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✓ | 91.10% | 65.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt) |
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-28-10 | ✓ | 89.48% | 62.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 85.29% | 57.20% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.pt) |
| CIFAR-10 | ℓ<sub>∞</sub> | 8 / 255 | WRN-34-20 | ✗ | 85.64% | 56.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.pt) |
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-70-16 | ✓ | 94.74% | 80.53% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt) |
| CIFAR-10 | ℓ<sub>2</sub> | 128 / 255 | WRN-70-16 | ✗ | 90.90% | 74.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt) |
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✓ | 69.15% | 36.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt) |
| CIFAR-100 | ℓ<sub>∞</sub> | 8 / 255 | WRN-70-16 | ✗ | 60.86% | 30.03% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt) |
| MNIST | ℓ<sub>∞</sub> | 0.3 | WRN-28-10 | ✗ | 99.26% | 96.34% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.pt) |
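The checkpoints can also be fetched programmatically. The sketch below uses Python's standard library and one of the JAX links from the table above; the local filename is only an illustration:

```python
import urllib.request

# CIFAR-10 l-inf WRN-70-16 (trained with extra data), JAX format, from the table above.
URL = ('https://storage.googleapis.com/dm-adversarial-robustness/'
       'cifar10_linf_wrn70-16_with.npy')
urllib.request.urlretrieve(URL, 'cifar10_linf_wrn70-16_with.npy')
```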
### Using the model

Once downloaded, a model can be evaluated (clean accuracy) by running the
`eval.py` script in either the `jax` or `pytorch` folders. E.g.:

```
cd jax
python3 eval.py \
  --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
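Each JAX checkpoint (`.npy` file) stores a pickled `(params, state)` pair, which is how `jax/eval.py` restores it. A minimal sketch, assuming the file downloaded in the example above:

```python
import numpy as np

# A JAX checkpoint unpacks into Haiku parameters and state (see jax/eval.py).
params, state = np.load('cifar10_linf_wrn70-16_with.npy', allow_pickle=True)
```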
## Citing this work

If you use this code or these models in your work, please cite the accompanying
paper:

```
@article{gowal2020uncovering,
  title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
  author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
  journal={arXiv preprint arXiv:2010.03593},
  year={2020},
  url={https://arxiv.org/pdf/2010.03593}
}
```

## Disclaimer

This is not an official Google product.
@@ -0,0 +1,104 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""

from absl import app
from absl import flags
import haiku as hk
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tqdm

from adversarial_robustness.jax import model_zoo

_CKPT = flags.DEFINE_string(
    'ckpt', None, 'Path to checkpoint.')
_DATASET = flags.DEFINE_enum(
    'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
    'Dataset on which the checkpoint is evaluated.')
_WIDTH = flags.DEFINE_integer(
    'width', 16, 'Width of WideResNet.')
_DEPTH = flags.DEFINE_integer(
    'depth', 70, 'Depth of WideResNet.')
_BATCH_SIZE = flags.DEFINE_integer(
    'batch_size', 100, 'Batch size.')
_NUM_BATCHES = flags.DEFINE_integer(
    'num_batches', 0,
    'Number of batches to evaluate (zero means the whole dataset).')

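# Example invocation (hypothetical checkpoint path), matching the README:
#   python3 eval.py --ckpt=/tmp/cifar10_linf_wrn70-16_with.npy \
#       --depth=70 --width=16 --dataset=cifar10
# Passing --ckpt=dummy evaluates randomly initialised parameters (smoke test).
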
def main(unused_argv):
  print(f'Loading "{_CKPT.value}"')
  print(f'Using a WideResNet with depth {_DEPTH.value} and width '
        f'{_WIDTH.value}.')

  # Create dataset.
  if _DATASET.value == 'mnist':
    _, data_test = tf.keras.datasets.mnist.load_data()
    normalize_fn = model_zoo.mnist_normalize
  elif _DATASET.value == 'cifar10':
    _, data_test = tf.keras.datasets.cifar10.load_data()
    normalize_fn = model_zoo.cifar10_normalize
  else:
    assert _DATASET.value == 'cifar100'
    _, data_test = tf.keras.datasets.cifar100.load_data()
    normalize_fn = model_zoo.cifar100_normalize

  # Create model. The number of output classes must match the checkpoint.
  num_classes = 100 if _DATASET.value == 'cifar100' else 10

  @hk.transform_with_state
  def model_fn(x, is_training=False):
    model = model_zoo.WideResNet(
        num_classes=num_classes, depth=_DEPTH.value, width=_WIDTH.value,
        activation='swish')
    return model(normalize_fn(x), is_training=is_training)

  # Build dataset.
  images, labels = data_test
  if images.ndim == 3:
    # MNIST images come without a channel dimension; add one.
    images = images[..., None]
  samples = (images.astype(np.float32) / 255.,
             labels.reshape(-1).astype(np.int64))
  data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value)
  test_loader = tfds.as_numpy(data)

  # Load model parameters.
  rng_seq = hk.PRNGSequence(0)
  if _CKPT.value == 'dummy':
    for images, _ in test_loader:
      break
    params, state = model_fn.init(next(rng_seq), images, is_training=True)
    # Reset iterator.
    test_loader = tfds.as_numpy(data)
  else:
    params, state = np.load(_CKPT.value, allow_pickle=True)

  # Evaluation.
  correct = 0
  total = 0
  batch_count = 0
  total_batches = (10_000 - 1) // _BATCH_SIZE.value + 1
  if _NUM_BATCHES.value > 0:
    total_batches = min(total_batches, _NUM_BATCHES.value)
  for images, labels in tqdm.tqdm(test_loader, total=total_batches):
    outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
    predicted = np.argmax(outputs, 1)
    total += labels.shape[0]
    correct += (predicted == labels).sum().item()
    batch_count += 1
    if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
      break
  print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')


if __name__ == '__main__':
  flags.mark_flag_as_required('ckpt')
  app.run(main)
@@ -0,0 +1,165 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WideResNet implementation in JAX using Haiku."""

from typing import Any, Mapping, Optional, Text

import haiku as hk
import jax
import jax.numpy as jnp


CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)


class _WideResNetBlock(hk.Module):
  """Block of a WideResNet."""

  def __init__(self, num_filters, stride=1, projection_shortcut=False,
               activation=jax.nn.relu, norm_args=None, name=None):
    super().__init__(name=name)
    num_bottleneck_layers = 1
    self._activation = activation
    if norm_args is None:
      norm_args = {
          'create_offset': False,
          'create_scale': True,
          'decay_rate': .99,
      }
    self._bn_modules = []
    self._conv_modules = []
    for i in range(num_bottleneck_layers + 1):
      s = stride if i == 0 else 1
      self._bn_modules.append(hk.BatchNorm(
          name='batchnorm_{}'.format(i),
          **norm_args))
      self._conv_modules.append(hk.Conv2D(
          output_channels=num_filters,
          padding='SAME',
          kernel_shape=(3, 3),
          stride=s,
          with_bias=False,
          name='conv_{}'.format(i)))  # pytype: disable=not-callable
    if projection_shortcut:
      self._shortcut = hk.Conv2D(
          output_channels=num_filters,
          kernel_shape=(1, 1),
          stride=stride,
          with_bias=False,
          name='shortcut')  # pytype: disable=not-callable
    else:
      self._shortcut = None

  def __call__(self, inputs, **norm_kwargs):
    x = inputs
    orig_x = inputs
    for i, (bn, conv) in enumerate(zip(self._bn_modules, self._conv_modules)):
      x = bn(x, **norm_kwargs)
      x = self._activation(x)
      if self._shortcut is not None and i == 0:
        orig_x = x
      x = conv(x)
    if self._shortcut is not None:
      shortcut_x = self._shortcut(orig_x)
      x += shortcut_x
    else:
      x += orig_x
    return x


class WideResNet(hk.Module):
  """WideResNet designed for CIFAR-10."""

  def __init__(self,
               num_classes: int = 10,
               depth: int = 28,
               width: int = 10,
               activation: Text = 'relu',
               norm_args: Optional[Mapping[Text, Any]] = None,
               name: Optional[Text] = None):
    super(WideResNet, self).__init__(name=name)
    if (depth - 4) % 6 != 0:
      raise ValueError('depth should be 6n+4.')
    self._activation = getattr(jax.nn, activation)
    if norm_args is None:
      norm_args = {
          'create_offset': False,
          'create_scale': True,
          'decay_rate': .99,
      }
    self._conv = hk.Conv2D(
        output_channels=16,
        kernel_shape=(3, 3),
        stride=1,
        with_bias=False,
        name='init_conv')  # pytype: disable=not-callable
    self._bn = hk.BatchNorm(
        name='batchnorm',
        **norm_args)
    self._linear = hk.Linear(
        num_classes,
        name='logits')

    blocks_per_layer = (depth - 4) // 6
    filter_sizes = [width * n for n in [16, 32, 64]]
    self._blocks = []
    for layer_num, filter_size in enumerate(filter_sizes):
      blocks_of_layer = []
      for i in range(blocks_per_layer):
        stride = 2 if (layer_num != 0 and i == 0) else 1
        projection_shortcut = (i == 0)
        blocks_of_layer.append(_WideResNetBlock(
            num_filters=filter_size,
            stride=stride,
            projection_shortcut=projection_shortcut,
            activation=self._activation,
            norm_args=norm_args,
            name='resnet_lay_{}_block_{}'.format(layer_num, i)))
      self._blocks.append(blocks_of_layer)

  def __call__(self, inputs, **norm_kwargs):
    net = inputs
    net = self._conv(net)

    # Blocks.
    for blocks_of_layer in self._blocks:
      for block in blocks_of_layer:
        net = block(net, **norm_kwargs)
    net = self._bn(net, **norm_kwargs)
    net = self._activation(net)

    net = jnp.mean(net, axis=[1, 2])
    return self._linear(net)


def mnist_normalize(image: jnp.ndarray) -> jnp.ndarray:
  image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
                  constant_values=0)
  return (image - .5) * 2.


def cifar10_normalize(image: jnp.ndarray) -> jnp.ndarray:
  means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
  stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
  return (image - means) / stds


def cifar100_normalize(image: jnp.ndarray) -> jnp.ndarray:
  means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
  stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
  return (image - means) / stds
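For reference, here is a minimal sketch of how this model definition is used with Haiku, mirroring `eval.py`. The depth/width values and the dummy input are only an example, and it assumes the `adversarial_robustness` package is importable:

```python
import haiku as hk
import jax
import jax.numpy as jnp

from adversarial_robustness.jax import model_zoo


@hk.transform_with_state
def forward(x, is_training=False):
  # WRN-28-10 with the swish activation used for the released checkpoints.
  model = model_zoo.WideResNet(
      num_classes=10, depth=28, width=10, activation='swish')
  return model(model_zoo.cifar10_normalize(x), is_training=is_training)


rng = jax.random.PRNGKey(0)
dummy_batch = jnp.zeros((2, 32, 32, 3), jnp.float32)
params, state = forward.init(rng, dummy_batch, is_training=True)
logits, _ = forward.apply(params, state, rng, dummy_batch)
print(logits.shape)  # (2, 10)
```

To evaluate a released checkpoint instead of random weights, the `forward.init` call would be replaced by loading the pickled `(params, state)` pair from a downloaded `.npy` file, as `eval.py` does.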