
Commit

Merge pull request #92 from pralab/91-full-support-for-backdoor-attacks
Full support for backdoor attacks (#91)
zangobot authored Oct 4, 2024
2 parents 1958f73 + a9bf8b9 commit 0fa051e
Showing 12 changed files with 366 additions and 5 deletions.
21 changes: 21 additions & 0 deletions docs/source/secmlt.adv.backdoor.rst
@@ -0,0 +1,21 @@
secmlt.adv.backdoor package
===========================

Submodules
----------

secmlt.adv.backdoor.base\_pytorch\_backdoor module
--------------------------------------------------

.. automodule:: secmlt.adv.backdoor.base_pytorch_backdoor
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: secmlt.adv.backdoor
   :members:
   :undoc-members:
   :show-inheritance:
29 changes: 29 additions & 0 deletions docs/source/secmlt.adv.poisoning.rst
@@ -0,0 +1,29 @@
secmlt.adv.poisoning package
============================

Submodules
----------

secmlt.adv.poisoning.backdoor module
------------------------------------

.. automodule:: secmlt.adv.poisoning.backdoor
   :members:
   :undoc-members:
   :show-inheritance:

secmlt.adv.poisoning.base\_data\_poisoning module
-------------------------------------------------

.. automodule:: secmlt.adv.poisoning.base_data_poisoning
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: secmlt.adv.poisoning
   :members:
   :undoc-members:
   :show-inheritance:
1 change: 1 addition & 0 deletions docs/source/secmlt.adv.rst
@@ -8,6 +8,7 @@ Subpackages
   :maxdepth: 4

   secmlt.adv.evasion
   secmlt.adv.poisoning

Submodules
----------
8 changes: 8 additions & 0 deletions docs/source/secmlt.tests.rst
@@ -36,6 +36,14 @@ secmlt.tests.test\_attacks module
   :undoc-members:
   :show-inheritance:

secmlt.tests.test\_backdoors module
-----------------------------------

.. automodule:: secmlt.tests.test_backdoors
   :members:
   :undoc-members:
   :show-inheritance:

secmlt.tests.test\_constants module
-----------------------------------

60 changes: 60 additions & 0 deletions examples/backdoor_example.py
@@ -0,0 +1,60 @@
import torch
import torchvision.datasets
from models.mnist_net import MNISTNet
from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorch
from secmlt.metrics.classification import Accuracy, AttackSuccessRate
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim import Adam
from torch.utils.data import DataLoader


def apply_patch(x: torch.Tensor) -> torch.Tensor:
    # stamp a white 4x4 trigger patch in the bottom-right corner of channel 0
    x[:, 0, 24:28, 24:28] = 1.0
    return x


dataset_path = "example_data/datasets/"
device = "cpu"
net = MNISTNet()
net.to(device)
optimizer = Adam(lr=1e-3, params=net.parameters())
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)
target_label = 1
backdoored_mnist = BackdoorDatasetPyTorch(
    training_dataset,
    data_manipulation_func=apply_patch,
    trigger_label=target_label,
    portion=0.1,
)

training_data_loader = DataLoader(backdoored_mnist, batch_size=20, shuffle=False)
test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

trainer = BasePyTorchTrainer(optimizer, epochs=5)
model = BasePytorchClassifier(net, trainer=trainer)
model.train(training_data_loader)

# test accuracy on the clean test set
accuracy = Accuracy()(model, test_data_loader)
print("test accuracy: ", accuracy)

# attack success rate on the backdoored test set
backdoored_test_set = BackdoorDatasetPyTorch(
    test_dataset, data_manipulation_func=apply_patch
)
backdoored_loader = DataLoader(backdoored_test_set, batch_size=20, shuffle=False)

asr = AttackSuccessRate(y_target=target_label)(model, backdoored_loader)
print(f"asr: {asr}")
53 changes: 53 additions & 0 deletions examples/label_flipping_example.py
@@ -0,0 +1,53 @@
import torchvision.datasets
from models.mnist_net import MNISTNet
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch
from secmlt.metrics.classification import Accuracy
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim import Adam
from torch.utils.data import DataLoader


def flip_label(label: int) -> int:
    # flip every non-zero label to 0, and label 0 to 1
    return 0 if label != 0 else 1


dataset_path = "example_data/datasets/"
device = "cpu"
net = MNISTNet()
net.to(device)
optimizer = Adam(lr=1e-3, params=net.parameters())
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)
target_label = 1
poisoned_mnist = PoisoningDatasetPyTorch(
    training_dataset,
    label_manipulation_func=flip_label,
    portion=0.4,
)

training_data_loader = DataLoader(training_dataset, batch_size=20, shuffle=False)
poisoned_data_loader = DataLoader(poisoned_mnist, batch_size=20, shuffle=False)

test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

for k, data_loader in {
    "normal": training_data_loader,
    "poisoned": poisoned_data_loader,
}.items():
    trainer = BasePyTorchTrainer(optimizer, epochs=3)
    model = BasePytorchClassifier(net, trainer=trainer)
    model.train(data_loader)
    # test accuracy on the clean test set
    accuracy = Accuracy()(model, test_data_loader)
    print(f"test accuracy on {k} data: {accuracy.item():.3f}")
3 changes: 2 additions & 1 deletion ruff.toml
@@ -25,7 +25,8 @@ ignore = [
"FBT002", # boolean type default argument
"COM812", # flake8-commas "Trailing comma missing"
"ISC001", # implicitly concatenated string literals on one line
"UP007"
"UP007", # conflict non-pep8 annotations
"S311" # random generator not suitable for cryptographic purposes
]

[lint.per-file-ignores]
1 change: 1 addition & 0 deletions src/secmlt/adv/poisoning/__init__.py
@@ -0,0 +1 @@
"""Backdoor attacks."""
41 changes: 41 additions & 0 deletions src/secmlt/adv/poisoning/backdoor.py
@@ -0,0 +1,41 @@
"""Simple backdoor attack in PyTorch."""

import torch
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch
from torch.utils.data import Dataset


class BackdoorDatasetPyTorch(PoisoningDatasetPyTorch):
    """Dataset class for adding triggers for backdoor attacks."""

    def __init__(
        self,
        dataset: Dataset,
        data_manipulation_func: callable,
        trigger_label: int = 0,
        portion: float | None = None,
        poisoned_indexes: list[int] | torch.Tensor = None,
    ) -> None:
        """
        Create the backdoored dataset.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            PyTorch dataset.
        data_manipulation_func : callable
            Function to manipulate the data and add the backdoor.
        trigger_label : int, optional
            Label to associate with the backdoored data (default 0).
        portion : float, optional
            Percentage of samples on which the backdoor will be injected (default 0.1).
        poisoned_indexes : list[int] | torch.Tensor
            Specific indexes of samples to perturb. Alternative to portion.
        """
        super().__init__(
            dataset=dataset,
            data_manipulation_func=data_manipulation_func,
            label_manipulation_func=lambda _: trigger_label,
            portion=portion,
            poisoned_indexes=poisoned_indexes,
        )
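
A minimal usage sketch, not part of this commit, showing the poisoned_indexes alternative to portion. It reuses training_dataset and apply_patch from the backdoor example above; the index values are arbitrary.

from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorch

# poison only three specific training samples instead of a random portion
backdoored_subset = BackdoorDatasetPyTorch(
    training_dataset,
    data_manipulation_func=apply_patch,
    trigger_label=1,
    poisoned_indexes=[0, 5, 42],
)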
88 changes: 88 additions & 0 deletions src/secmlt/adv/poisoning/base_data_poisoning.py
@@ -0,0 +1,88 @@
"""Base class for data poisoning."""

import random

import torch
from torch.utils.data import Dataset


class PoisoningDatasetPyTorch(Dataset):
    """Dataset class for adding poisoning samples."""

    def __init__(
        self,
        dataset: Dataset,
        data_manipulation_func: callable = lambda x: x,
        label_manipulation_func: callable = lambda x: x,
        portion: float | None = None,
        poisoned_indexes: list[int] | torch.Tensor = None,
    ) -> None:
        """
        Create the poisoned dataset.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            PyTorch dataset.
        data_manipulation_func : callable
            Function that manipulates the data.
        label_manipulation_func : callable
            Function that returns the label to associate with the poisoned data.
        portion : float, optional
            Percentage of samples on which the poisoning will be injected (default 0.1).
        poisoned_indexes : list[int] | torch.Tensor
            Specific indexes of samples to perturb. Alternative to portion.
        """
        self.dataset = dataset
        self.data_len = len(dataset)
        if portion is not None:
            if poisoned_indexes is not None:
                msg = "Specify either portion or poisoned_indexes, not both."
                raise ValueError(msg)
            if portion < 0.0 or portion > 1.0:
                msg = f"Poison ratio should be between 0.0 and 1.0. Passed {portion}."
                raise ValueError(msg)
            # calculate number of samples to poison
            num_poisoned_samples = int(portion * self.data_len)

            # randomly select indices to poison
            self.poisoned_indexes = set(
                random.sample(range(self.data_len), num_poisoned_samples)
            )
        elif poisoned_indexes is not None:
            self.poisoned_indexes = poisoned_indexes
        else:
            self.poisoned_indexes = range(self.data_len)

        self.data_manipulation_func = data_manipulation_func
        self.label_manipulation_func = label_manipulation_func

    def __len__(self) -> int:
        """Get number of samples."""
        return self.data_len

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        """
        Get item from the dataset.

        Parameters
        ----------
        idx : int
            Index of the item to return.

        Returns
        -------
        tuple[torch.Tensor, int]
            Item at position specified by idx.
        """
        x, label = self.dataset[idx]
        # poison only the selected portion of the data
        if idx in self.poisoned_indexes:
            x = self.data_manipulation_func(x=x.unsqueeze(0)).squeeze(0)
            target_label = self.label_manipulation_func(label)
            label = (
                target_label
                if isinstance(label, int)
                else torch.tensor(target_label).type(label.dtype)
            )
        return x, label
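
A minimal sketch, not part of this commit, of how the base class composes the two manipulation functions; the lambdas and tensor shapes are placeholders. Poisoning is applied lazily, inside __getitem__, only for indexes listed in poisoned_indexes.

import torch
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch
from torch.utils.data import TensorDataset

clean = TensorDataset(torch.rand(10, 1, 28, 28), torch.randint(0, 10, (10,)))
poisoned = PoisoningDatasetPyTorch(
    clean,
    data_manipulation_func=lambda x: x * 0.5,  # placeholder data manipulation
    label_manipulation_func=lambda y: 0,       # relabel poisoned samples as class 0
    poisoned_indexes=[0, 1],
)
x0, y0 = poisoned[0]  # manipulated sample with the new label
x5, y5 = poisoned[5]  # sample outside poisoned_indexes is returned unchanged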
19 changes: 15 additions & 4 deletions src/secmlt/tests/fixtures.py
@@ -8,19 +8,30 @@


@pytest.fixture
def data_loader() -> DataLoader[tuple[torch.Tensor]]:
def dataset() -> TensorDataset:
"""Create fake dataset."""
data = torch.randn(100, 3, 32, 32).clamp(0, 1)
labels = torch.randint(0, 10, (100,))
return TensorDataset(data, labels)


@pytest.fixture
def data_loader(dataset: TensorDataset) -> DataLoader[tuple[torch.Tensor]]:
"""
Create fake data loader.
Parameters
----------
dataset : TensorDataset
Dataset to wrap in the loader
Returns
-------
DataLoader[tuple[torch.Tensor]]
A loader with random samples and labels.
"""
# Create a dummy dataset loader for testing
data = torch.randn(100, 3, 32, 32).clamp(0, 1)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
return DataLoader(dataset, batch_size=10)
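
For context, a hypothetical test, not part of this commit, showing how pytest composes these fixtures: dataset is injected into data_loader, and a test can request either one directly. The test name is an assumption.

def test_backdoor_dataset_length(dataset) -> None:
    from secmlt.adv.poisoning.backdoor import BackdoorDatasetPyTorch

    backdoored = BackdoorDatasetPyTorch(
        dataset, data_manipulation_func=lambda x: x, portion=0.1
    )
    # wrapping the dataset does not change its length
    assert len(backdoored) == len(dataset)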

