mnist_example.py
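"""Craft PGD adversarial examples against an MNIST classifier with secmlt,
comparing the native, Foolbox, and AdvLib attack backends."""
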
import torch
from loaders.get_loaders import get_mnist_loader
from models.mnist_net import get_mnist_model
from secmlt.adv.backends import Backends
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.adv.evasion.pgd import PGD
from secmlt.metrics.classification import Accuracy
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.trackers.trackers import (
    LossTracker,
    PerturbationNormTracker,
    PredictionTracker,
)
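
# Configuration: device and paths to the example model and MNIST dataset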
device = "cpu"
model_path = "example_data/models/mnist"
dataset_path = "example_data/datasets/"
net = get_mnist_model(model_path).to(device)
test_loader = get_mnist_loader(dataset_path)
# Wrap the PyTorch network in secmlt's BasePytorchClassifier
model = BasePytorchClassifier(net)
# Test accuracy on original data
accuracy = Accuracy()(model, test_loader)
print(f"test accuracy: {accuracy.item():.2f}")

# Configure an untargeted L-inf PGD attack (y_target=None means untargeted)
epsilon = 1
num_steps = 10
step_size = 0.05
perturbation_model = LpPerturbationModels.LINF
y_target = None
trackers = [
    LossTracker(),
    PredictionTracker(),
    PerturbationNormTracker(perturbation_model),
]
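
# Run PGD with the native secmlt backend; trackers record loss, predictions, and perturbation norm during the attack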
native_attack = PGD(
    perturbation_model=perturbation_model,
    epsilon=epsilon,
    num_steps=num_steps,
    step_size=step_size,
    random_start=False,
    y_target=y_target,
    backend=Backends.NATIVE,
    trackers=trackers,
)
native_adv_ds = native_attack(model, test_loader)
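
# Print the values collected by each tracker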
for tracker in trackers:
    print(tracker.name)
    print(tracker.get())

# Test accuracy on adversarial examples from the native backend
n_robust_accuracy = Accuracy()(model, native_adv_ds)
print("robust accuracy native: ", n_robust_accuracy)

# Create and run the same attack through the Foolbox backend
foolbox_attack = PGD(
    perturbation_model=perturbation_model,
    epsilon=epsilon,
    num_steps=num_steps,
    step_size=step_size,
    random_start=False,
    y_target=y_target,
    backend=Backends.FOOLBOX,
)
f_adv_ds = foolbox_attack(model, test_loader)
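
# Create and run the attack through the AdvLib backend, using the DLR loss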
advlib_attack = PGD(
    perturbation_model=perturbation_model,
    epsilon=epsilon,
    num_steps=num_steps,
    step_size=step_size,
    random_start=False,
    loss_function="dlr",
    y_target=y_target,
    backend=Backends.ADVLIB,
)
al_adv_ds = advlib_attack(model, test_loader)

# Test accuracy on Foolbox adversarial examples
f_robust_accuracy = Accuracy()(model, f_adv_ds)
print("robust accuracy foolbox: ", f_robust_accuracy)

# Test accuracy on AdvLib adversarial examples
al_robust_accuracy = Accuracy()(model, al_adv_ds)
print("robust accuracy AdvLib: ", al_robust_accuracy)
native_data, native_labels = next(iter(native_adv_ds))
f_data, f_labels = next(iter(f_adv_ds))
real_data, real_labels = next(iter(test_loader))
distance = torch.linalg.norm(
    native_data.detach().cpu().flatten(start_dim=1)
    - f_data.detach().cpu().flatten(start_dim=1),
    ord=LpPerturbationModels.pert_models[perturbation_model],
    dim=1,
)
print("Solutions are :", distance, f"{perturbation_model} distant")