Commit a6aeb26: Internal change.

PiperOrigin-RevId: 348193373
aslanides authored and derpson committed Jan 5, 2021 (1 parent: 17700a6)

Showing 7 changed files with 686 additions and 0 deletions.
63 changes: 63 additions & 0 deletions adversarial_robustness/README.md
@@ -0,0 +1,63 @@
# Adversarial Robustness

This repository contains the code needed to evaluate models trained in
[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593).


## Contents

We have released our top-performing models in two formats, compatible with
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/)
respectively. This repository also contains our model definitions.
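
As a quick orientation, here is an editor's sketch (not part of this commit) of
loading each format. The JAX `.npy` files store a pickled `(params, state)`
pair, as `eval.py` below shows; treating the `.pt` files as a plain PyTorch
`state_dict` is our assumption:

```
# Loading sketch. The .npy layout matches eval.py below; reading the .pt file
# as a plain state_dict is an assumption, not confirmed by this commit.
import numpy as np
import torch

params, state = np.load('cifar10_linf_wrn70-16_with.npy', allow_pickle=True)
state_dict = torch.load('cifar10_linf_wrn70-16_with.pt')
```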

## Running the example code

### Downloading a model

Download a model from the links listed in the following table.
Clean and robust accuracies are measured on the full test set.
The robust accuracy is measured using
[AutoAttack](https://github.com/fra31/auto-attack).

| dataset | norm | radius | architecture | extra data | clean | robust | link |
|---|:---:|:---:|:---:|:---:|---:|---:|:---:|
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2713; | 91.10% | 65.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt)
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2713; | 89.48% | 62.80% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt)
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 85.29% | 57.20% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_without.pt)
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-34-20 | &#x2717; | 85.64% | 56.86% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn34-20_without.pt)
| CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-70-16 | &#x2713; | 94.74% | 80.53% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt)
| CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-70-16 | &#x2717; | 90.90% | 74.50% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt)
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2713; | 69.15% | 36.88% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt)
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 60.86% | 30.03% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt)
| MNIST | &#8467;<sub>&infin;</sub> | 0.3 | WRN-28-10 | &#x2717; | 99.26% | 96.34% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/mnist_linf_wrn28-10_without.pt)
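
Any HTTP client works for the download; as a minimal Python sketch, using a URL
taken from the table above:

```
# Fetch one of the released JAX checkpoints.
import urllib.request

URL = ('https://storage.googleapis.com/dm-adversarial-robustness/'
       'cifar10_linf_wrn70-16_with.npy')
urllib.request.urlretrieve(URL, 'cifar10_linf_wrn70-16_with.npy')
```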

### Using the model

Once downloaded, a model can be evaluated for clean accuracy by running the
`eval.py` script in either the `jax` or `pytorch` folder. Passing
`--ckpt=dummy` evaluates a randomly initialized network instead, which serves
as a quick smoke test. For example:

```
cd jax
python3 eval.py \
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```


## Citing this work

If you use this code or these models in your work, please cite the accompanying
paper:

```
@article{gowal2020uncovering,
title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
journal={arXiv preprint arXiv:2010.03593},
year={2020},
url={https://arxiv.org/pdf/2010.03593}
}
```

## Disclaimer

This is not an official Google product.
104 changes: 104 additions & 0 deletions adversarial_robustness/jax/eval.py
@@ -0,0 +1,104 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST."""

from absl import app
from absl import flags
import haiku as hk
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tqdm

from adversarial_robustness.jax import model_zoo

_CKPT = flags.DEFINE_string(
'ckpt', None, 'Path to checkpoint.')
_DATASET = flags.DEFINE_enum(
'dataset', 'cifar10', ['cifar10', 'cifar100', 'mnist'],
'Dataset on which the checkpoint is evaluated.')
_WIDTH = flags.DEFINE_integer(
'width', 16, 'Width of WideResNet.')
_DEPTH = flags.DEFINE_integer(
'depth', 70, 'Depth of WideResNet.')
_BATCH_SIZE = flags.DEFINE_integer(
'batch_size', 100, 'Batch size.')
_NUM_BATCHES = flags.DEFINE_integer(
'num_batches', 0,
'Number of batches to evaluate (zero means the whole dataset).')


def main(unused_argv):
print(f'Loading "{_CKPT.value}"')
print(f'Using a WideResNet with depth {_DEPTH.value} and width '
f'{_WIDTH.value}.')

# Create dataset.
if _DATASET.value == 'mnist':
_, data_test = tf.keras.datasets.mnist.load_data()
normalize_fn = model_zoo.mnist_normalize
elif _DATASET.value == 'cifar10':
_, data_test = tf.keras.datasets.cifar10.load_data()
normalize_fn = model_zoo.cifar10_normalize
else:
assert _DATASET.value == 'cifar100'
_, data_test = tf.keras.datasets.cifar100.load_data()
normalize_fn = model_zoo.cifar100_normalize

  # Create model. CIFAR-100 has 100 classes; CIFAR-10 and MNIST have 10.
  num_classes = 100 if _DATASET.value == 'cifar100' else 10

  @hk.transform_with_state
  def model_fn(x, is_training=False):
    model = model_zoo.WideResNet(
        num_classes=num_classes, depth=_DEPTH.value, width=_WIDTH.value,
        activation='swish')
    return model(normalize_fn(x), is_training=is_training)
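  # `hk.transform_with_state` wraps `model_fn` into a pair of pure functions,
  # `init` and `apply`; the extra "state" carries the BatchNorm moving
  # statistics, which the released checkpoints store alongside the parameters.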

  # Build dataset. Keras MNIST images lack a channel axis and their labels are
  # already 1-D, so normalize both layouts before batching.
  images, labels = data_test
  if images.ndim == 3:
    images = images[..., np.newaxis]
  samples = (images.astype(np.float32) / 255.,
             labels.reshape(-1).astype(np.int64))
  data = tf.data.Dataset.from_tensor_slices(samples).batch(_BATCH_SIZE.value)
  test_loader = tfds.as_numpy(data)

# Load model parameters.
rng_seq = hk.PRNGSequence(0)
if _CKPT.value == 'dummy':
for images, _ in test_loader:
break
params, state = model_fn.init(next(rng_seq), images, is_training=True)
# Reset iterator.
test_loader = tfds.as_numpy(data)
else:
params, state = np.load(_CKPT.value, allow_pickle=True)
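    # The released `.npy` checkpoints hold a pickled (params, state) pair, so a
    # single np.load(..., allow_pickle=True) restores both.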

# Evaluation.
correct = 0
total = 0
batch_count = 0
  # All three test sets contain 10,000 images; a num_batches of zero means the
  # whole dataset, so only cap the progress-bar total when the flag is set.
  total_batches = (10_000 - 1) // _BATCH_SIZE.value + 1
  if _NUM_BATCHES.value > 0:
    total_batches = min(total_batches, _NUM_BATCHES.value)
for images, labels in tqdm.tqdm(test_loader, total=total_batches):
outputs = model_fn.apply(params, state, next(rng_seq), images)[0]
predicted = np.argmax(outputs, 1)
total += labels.shape[0]
correct += (predicted == labels).sum().item()
batch_count += 1
if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value:
break
print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%')


if __name__ == '__main__':
flags.mark_flag_as_required('ckpt')
app.run(main)
165 changes: 165 additions & 0 deletions adversarial_robustness/jax/model_zoo.py
@@ -0,0 +1,165 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WideResNet implementation in JAX using Haiku."""

from typing import Any, Mapping, Optional, Text

import haiku as hk
import jax
import jax.numpy as jnp


CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)


class _WideResNetBlock(hk.Module):
"""Block of a WideResNet."""

def __init__(self, num_filters, stride=1, projection_shortcut=False,
activation=jax.nn.relu, norm_args=None, name=None):
super().__init__(name=name)
num_bottleneck_layers = 1
self._activation = activation
if norm_args is None:
norm_args = {
'create_offset': False,
'create_scale': True,
'decay_rate': .99,
}
self._bn_modules = []
self._conv_modules = []
for i in range(num_bottleneck_layers + 1):
s = stride if i == 0 else 1
self._bn_modules.append(hk.BatchNorm(
name='batchnorm_{}'.format(i),
**norm_args))
self._conv_modules.append(hk.Conv2D(
output_channels=num_filters,
padding='SAME',
kernel_shape=(3, 3),
stride=s,
with_bias=False,
name='conv_{}'.format(i))) # pytype: disable=not-callable
if projection_shortcut:
self._shortcut = hk.Conv2D(
output_channels=num_filters,
kernel_shape=(1, 1),
stride=stride,
with_bias=False,
name='shortcut') # pytype: disable=not-callable
else:
self._shortcut = None

def __call__(self, inputs, **norm_kwargs):
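    # Pre-activation ordering: BatchNorm and the nonlinearity run before each
    # convolution, and when a projection shortcut is present it branches off
    # after the first activation.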
x = inputs
orig_x = inputs
for i, (bn, conv) in enumerate(zip(self._bn_modules, self._conv_modules)):
x = bn(x, **norm_kwargs)
x = self._activation(x)
if self._shortcut is not None and i == 0:
orig_x = x
x = conv(x)
if self._shortcut is not None:
shortcut_x = self._shortcut(orig_x)
x += shortcut_x
else:
x += orig_x
return x


class WideResNet(hk.Module):
"""WideResNet designed for CIFAR-10."""

def __init__(self,
num_classes: int = 10,
depth: int = 28,
width: int = 10,
activation: Text = 'relu',
norm_args: Optional[Mapping[Text, Any]] = None,
name: Optional[Text] = None):
    super().__init__(name=name)
if (depth - 4) % 6 != 0:
raise ValueError('depth should be 6n+4.')
self._activation = getattr(jax.nn, activation)
if norm_args is None:
norm_args = {
'create_offset': False,
'create_scale': True,
'decay_rate': .99,
}
self._conv = hk.Conv2D(
output_channels=16,
kernel_shape=(3, 3),
stride=1,
with_bias=False,
name='init_conv') # pytype: disable=not-callable
self._bn = hk.BatchNorm(
name='batchnorm',
**norm_args)
self._linear = hk.Linear(
num_classes,
name='logits')

blocks_per_layer = (depth - 4) // 6
filter_sizes = [width * n for n in [16, 32, 64]]
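    # A depth of 6n+4 yields n blocks in each of three groups, with 16*width,
    # 32*width and 64*width filters; the first block of groups 1 and 2 uses
    # stride 2 to halve the spatial resolution.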
self._blocks = []
for layer_num, filter_size in enumerate(filter_sizes):
blocks_of_layer = []
for i in range(blocks_per_layer):
stride = 2 if (layer_num != 0 and i == 0) else 1
projection_shortcut = (i == 0)
blocks_of_layer.append(_WideResNetBlock(
num_filters=filter_size,
stride=stride,
projection_shortcut=projection_shortcut,
activation=self._activation,
norm_args=norm_args,
name='resnet_lay_{}_block_{}'.format(layer_num, i)))
self._blocks.append(blocks_of_layer)

def __call__(self, inputs, **norm_kwargs):
net = inputs
net = self._conv(net)

# Blocks.
for blocks_of_layer in self._blocks:
for block in blocks_of_layer:
net = block(net, **norm_kwargs)
net = self._bn(net, **norm_kwargs)
net = self._activation(net)

net = jnp.mean(net, axis=[1, 2])
return self._linear(net)


def mnist_normalize(image: jnp.ndarray) -> jnp.ndarray:
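  # Pad the 28x28 MNIST digits to 32x32 so the CIFAR-sized strides divide
  # evenly, then rescale pixels from [0, 1] to [-1, 1].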
image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant',
constant_values=0)
return (image - .5) * 2.


def cifar10_normalize(image: jnp.ndarray) -> jnp.ndarray:
means = jnp.array(CIFAR10_MEAN, dtype=image.dtype)
stds = jnp.array(CIFAR10_STD, dtype=image.dtype)
return (image - means) / stds


def cifar100_normalize(image: jnp.ndarray) -> jnp.ndarray:
means = jnp.array(CIFAR100_MEAN, dtype=image.dtype)
stds = jnp.array(CIFAR100_STD, dtype=image.dtype)
return (image - means) / stds
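
Putting the pieces together, here is a hedged end-to-end sketch (an editor's
example mirroring `eval.py` above): build the Haiku model, restore a released
checkpoint, and classify a batch.

```
# Sketch: forward pass with a released checkpoint (mirrors eval.py above).
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from adversarial_robustness.jax import model_zoo


@hk.transform_with_state
def model_fn(x, is_training=False):
  model = model_zoo.WideResNet(
      num_classes=10, depth=70, width=16, activation='swish')
  return model(model_zoo.cifar10_normalize(x), is_training=is_training)


params, state = np.load('cifar10_linf_wrn70-16_with.npy', allow_pickle=True)
batch = jnp.zeros((1, 32, 32, 3))  # placeholder for a CIFAR-10 batch in [0, 1]
logits, _ = model_fn.apply(params, state, jax.random.PRNGKey(0), batch)
print(jnp.argmax(logits, axis=-1))
```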