Merge pull request #92 from pralab/91-full-support-for-backdoor-attacks
91 support for backdoor attacks
Showing 12 changed files with 366 additions and 5 deletions.
@@ -0,0 +1,21 @@
secmlt.adv.backdoor package
===========================

Submodules
----------

secmlt.adv.backdoor.base\_pytorch\_backdoor module
--------------------------------------------------

.. automodule:: secmlt.adv.backdoor.base_pytorch_backdoor
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: secmlt.adv.backdoor
   :members:
   :undoc-members:
   :show-inheritance:
@@ -0,0 +1,29 @@
secmlt.adv.poisoning package
============================

Submodules
----------

secmlt.adv.poisoning.backdoor module
------------------------------------

.. automodule:: secmlt.adv.poisoning.backdoor
   :members:
   :undoc-members:
   :show-inheritance:

secmlt.adv.poisoning.base\_data\_poisoning module
-------------------------------------------------

.. automodule:: secmlt.adv.poisoning.base_data_poisoning
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: secmlt.adv.poisoning
   :members:
   :undoc-members:
   :show-inheritance:
@@ -8,6 +8,7 @@ Subpackages
   :maxdepth: 4

   secmlt.adv.evasion
   secmlt.adv.poisoning

Submodules
----------
@@ -0,0 +1,60 @@
import torch
import torchvision.datasets
from models.mnist_net import MNISTNet
from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorch
from secmlt.metrics.classification import Accuracy, AttackSuccessRate
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim import Adam
from torch.utils.data import DataLoader


def apply_patch(x: torch.Tensor) -> torch.Tensor:
    # stamp a white 4x4 square in the bottom-right corner: the backdoor trigger
    x[:, 0, 24:28, 24:28] = 1.0
    return x


dataset_path = "example_data/datasets/"
device = "cpu"
net = MNISTNet()
net.to(device)
optimizer = Adam(lr=1e-3, params=net.parameters())
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)

# poison 10% of the training set: patched samples are relabeled as the target class
target_label = 1
backdoored_mnist = BackdoorDatasetPyTorch(
    training_dataset,
    data_manipulation_func=apply_patch,
    trigger_label=target_label,
    portion=0.1,
)

training_data_loader = DataLoader(backdoored_mnist, batch_size=20, shuffle=False)
test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

trainer = BasePyTorchTrainer(optimizer, epochs=5)
model = BasePytorchClassifier(net, trainer=trainer)
model.train(training_data_loader)

# test accuracy on the clean test set
accuracy = Accuracy()(model, test_data_loader)
print("test accuracy: ", accuracy)

# attack success rate: patch every test sample and measure how often
# the model predicts the target label
backdoored_test_set = BackdoorDatasetPyTorch(
    test_dataset, data_manipulation_func=apply_patch
)
backdoored_loader = DataLoader(backdoored_test_set, batch_size=20, shuffle=False)

asr = AttackSuccessRate(y_target=target_label)(model, backdoored_loader)
print(f"asr: {asr}")
@@ -0,0 +1,53 @@
import torchvision.datasets
from models.mnist_net import MNISTNet
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch
from secmlt.metrics.classification import Accuracy
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim import Adam
from torch.utils.data import DataLoader


def flip_label(label: int) -> int:
    # label-flipping poisoning: every class becomes 0, and class 0 becomes 1
    return 0 if label != 0 else 1


dataset_path = "example_data/datasets/"
device = "cpu"
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)
# flip the labels of 40% of the training samples
poisoned_mnist = PoisoningDatasetPyTorch(
    training_dataset,
    label_manipulation_func=flip_label,
    portion=0.4,
)

training_data_loader = DataLoader(training_dataset, batch_size=20, shuffle=False)
poisoned_data_loader = DataLoader(poisoned_mnist, batch_size=20, shuffle=False)

test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

for k, data_loader in {
    "normal": training_data_loader,
    "poisoned": poisoned_data_loader,
}.items():
    # train a fresh model on each dataset so the two runs are comparable
    net = MNISTNet()
    net.to(device)
    optimizer = Adam(lr=1e-3, params=net.parameters())
    trainer = BasePyTorchTrainer(optimizer, epochs=3)
    model = BasePytorchClassifier(net, trainer=trainer)
    model.train(data_loader)
    # accuracy on the clean test set
    accuracy = Accuracy()(model, test_data_loader)
    print(f"test accuracy on {k} data: {accuracy.item():.3f}")
@@ -0,0 +1 @@
"""Backdoor attacks.""" |
@@ -0,0 +1,41 @@
"""Simple backdoor attack in PyTorch.""" | ||
|
||
import torch | ||
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class BackdoorDatasetPyTorch(PoisoningDatasetPyTorch): | ||
"""Dataset class for adding triggers for backdoor attacks.""" | ||
|
||
def __init__( | ||
self, | ||
dataset: Dataset, | ||
data_manipulation_func: callable, | ||
trigger_label: int = 0, | ||
portion: float | None = None, | ||
poisoned_indexes: list[int] | torch.Tensor = None, | ||
) -> None: | ||
""" | ||
Create the backdoored dataset. | ||
Parameters | ||
---------- | ||
dataset : torch.utils.data.Dataset | ||
PyTorch dataset. | ||
data_manipulation_func: callable | ||
Function to manipulate the data and add the backdoor. | ||
trigger_label : int, optional | ||
Label to associate with the backdoored data (default 0). | ||
portion : float, optional | ||
Percentage of samples on which the backdoor will be injected (default 0.1). | ||
poisoned_indexes: list[int] | torch.Tensor | ||
Specific indexes of samples to perturb. Alternative to portion. | ||
""" | ||
super().__init__( | ||
dataset=dataset, | ||
data_manipulation_func=data_manipulation_func, | ||
label_manipulation_func=lambda _: trigger_label, | ||
portion=portion, | ||
poisoned_indexes=poisoned_indexes, | ||
) |
@@ -0,0 +1,88 @@
"""Base class for data poisoning.""" | ||
|
||
import random | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class PoisoningDatasetPyTorch(Dataset): | ||
"""Dataset class for adding poisoning samples.""" | ||
|
||
def __init__( | ||
self, | ||
dataset: Dataset, | ||
data_manipulation_func: callable = lambda x: x, | ||
label_manipulation_func: callable = lambda x: x, | ||
portion: float | None = None, | ||
poisoned_indexes: list[int] | torch.Tensor = None, | ||
) -> None: | ||
""" | ||
Create the poisoned dataset. | ||
Parameters | ||
---------- | ||
dataset : torch.utils.data.Dataset | ||
PyTorch dataset. | ||
data_manipulation_func : callable | ||
Function that manipulates the data. | ||
label_manipulation_func: callable | ||
Function that returns the label to associate with the poisoned data. | ||
portion : float, optional | ||
Percentage of samples on which the poisoning will be injected (default 0.1). | ||
poisoned_indexes: list[int] | torch.Tensor | ||
Specific indexes of samples to perturb. Alternative to portion. | ||
""" | ||
self.dataset = dataset | ||
self.data_len = len(dataset) | ||
if portion is not None: | ||
if poisoned_indexes is not None: | ||
msg = "Specify either portion or poisoned_indexes, not both." | ||
raise ValueError(msg) | ||
if portion < 0.0 or portion > 1.0: | ||
msg = f"Poison ratio should be between 0.0 and 1.0. Passed {portion}." | ||
raise ValueError(msg) | ||
# calculate number of samples to poison | ||
num_poisoned_samples = int(portion * self.data_len) | ||
|
||
# randomly select indices to poison | ||
self.poisoned_indexes = set( | ||
random.sample(range(self.data_len), num_poisoned_samples) | ||
) | ||
elif poisoned_indexes is not None: | ||
self.poisoned_indexes = poisoned_indexes | ||
else: | ||
self.poisoned_indexes = range(self.data_len) | ||
|
||
self.data_manipulation_func = data_manipulation_func | ||
self.label_manipulation_func = label_manipulation_func | ||
|
||
def __len__(self) -> int: | ||
"""Get number of samples.""" | ||
return self.data_len | ||
|
||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]: | ||
""" | ||
Get item from the dataset. | ||
Parameters | ||
---------- | ||
idx : int | ||
Index of the item to return | ||
Returns | ||
------- | ||
tuple[torch.Tensor, int] | ||
Item at position specified by idx. | ||
""" | ||
x, label = self.dataset[idx] | ||
# poison portion of the data | ||
if idx in self.poisoned_indexes: | ||
x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0) | ||
target_label = self.label_manipulation_func(label) | ||
label = ( | ||
target_label | ||
if isinstance(label, int) | ||
else torch.Tensor(target_label).type(label.dtype) | ||
) | ||
return x, label |
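The base class also covers poisoning beyond backdoors when used directly. A hedged sketch, assuming a plain list of (tensor, int) pairs (which duck-types as a dataset): perturbing the features of roughly 30% of the samples while leaving labels untouched:

import torch
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch

# toy samples: (feature tensor, int label) pairs; names are illustrative
samples = [(torch.randn(3), i % 2) for i in range(100)]

# shift the features of ~30% of the samples; label_manipulation_func
# defaults to the identity, so labels stay clean
poisoned = PoisoningDatasetPyTorch(
    samples,
    data_manipulation_func=lambda x: x + 0.5,
    portion=0.3,
)
x0, y0 = poisoned[0]  # features possibly shifted, label unchanged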