Skip to content

Commit

Permalink
foolbox imports conditional to foolbox installation
Browse files Browse the repository at this point in the history
  • Loading branch information
maurapintor committed Feb 20, 2024
1 parent eb14ac9 commit 28157c1
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 51 deletions.
2 changes: 2 additions & 0 deletions secml2/adv/evasion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@


10 changes: 9 additions & 1 deletion secml2/adv/evasion/base_evasion_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ def check_perturbation_model_available(perturbation_model: str) -> bool:
if not PerturbationModels.is_perturbation_model_available(perturbation_model):
raise NotImplementedError("Unsupported or not-implemented threat model.")

def get_foolbox_implementation(self):
try:
import foolbox
except ImportError:
raise ImportError("Foolbox extra not installed.")
else:
return self._get_foolbox_implementation()

@staticmethod
def get_foolbox_implementation():
def _get_foolbox_implementation():
raise NotImplementedError("Foolbox implementation not available.")

@staticmethod
Expand Down
6 changes: 6 additions & 0 deletions secml2/adv/evasion/foolbox_attacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
try:
import foolbox
except ImportError:
pass # foolbox is an extra component and requires the foolbox library
else:
from .foolbox_pgd import *
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Optional
from secml2.adv.evasion.base_evasion_attack import BaseEvasionAttack
from foolbox.attacks.base import Attack
from torch.utils.data import DataLoader
from secml2.models.base_model import BaseModel
from secml2.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secml2.models.base_model import BaseModel
from foolbox.models.pytorch import PyTorchModel
from foolbox.criteria import Misclassification, TargetedMisclassification
import torch
from torch.utils.data import TensorDataset


class BaseFoolboxEvasionAttack(BaseEvasionAttack):
Expand Down
46 changes: 46 additions & 0 deletions secml2/adv/evasion/foolbox_attacks/foolbox_pgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

from typing import Optional
from secml2.adv.evasion.foolbox_attacks.foolbox_base import BaseFoolboxEvasionAttack
from secml2.adv.evasion.perturbation_models import PerturbationModels

from foolbox.attacks.projected_gradient_descent import (
L1ProjectedGradientDescentAttack,
L2ProjectedGradientDescentAttack,
LinfProjectedGradientDescentAttack,
)

class PGDFoolbox(BaseFoolboxEvasionAttack):
def __init__(
self,
perturbation_model: str,
epsilon: float,
num_steps: int,
step_size: float,
random_start: bool,
y_target: Optional[int] = None,
lb: float = 0.0,
ub: float = 1.0,
**kwargs
) -> None:
perturbation_models = {
PerturbationModels.L1: L1ProjectedGradientDescentAttack,
PerturbationModels.L2: L2ProjectedGradientDescentAttack,
PerturbationModels.LINF: LinfProjectedGradientDescentAttack,
}
foolbox_attack_cls = perturbation_models.get(perturbation_model, None)
if foolbox_attack_cls is None:
raise NotImplementedError(
"This threat model is not implemented in foolbox."
)

foolbox_attack = foolbox_attack_cls(
abs_stepsize=step_size, steps=num_steps, random_start=random_start
)

super().__init__(
foolbox_attack=foolbox_attack,
epsilon=epsilon,
y_target=y_target,
lb=lb,
ub=ub,
)
53 changes: 6 additions & 47 deletions secml2/adv/evasion/pgd.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
from typing import Optional, List

from foolbox.attacks.projected_gradient_descent import (
L1ProjectedGradientDescentAttack,
L2ProjectedGradientDescentAttack,
LinfProjectedGradientDescentAttack,
)

from secml2.adv.backends import Backends
from secml2.adv.evasion.base_evasion_attack import (
BaseEvasionAttackCreator,
)
from secml2.adv.evasion.composite_attack import CompositeEvasionAttack, CE_LOSS
from secml2.adv.evasion.foolbox import BaseFoolboxEvasionAttack
from secml2.adv.evasion.foolbox_attacks.foolbox_base import BaseFoolboxEvasionAttack
from secml2.adv.evasion.perturbation_models import PerturbationModels
from secml2.manipulations.manipulation import AdditiveManipulation
from secml2.optimization.constraints import (
ClipConstraint,
L1Constraint,
L2Constraint,
LInfConstraint,
Constraint,
)
from secml2.optimization.gradient_processing import LinearProjectionGradientProcessing
from secml2.optimization.initializer import Initializer
Expand Down Expand Up @@ -55,51 +48,17 @@ def __new__(
)

@staticmethod
def get_foolbox_implementation():
def _get_foolbox_implementation():
try:
from .foolbox_attacks.foolbox_pgd import PGDFoolbox
except ImportError:
raise ImportError("Foolbox extra not installed")
return PGDFoolbox

@staticmethod
def get_native_implementation():
return PGDNative


class PGDFoolbox(BaseFoolboxEvasionAttack):
def __init__(
self,
perturbation_model: str,
epsilon: float,
num_steps: int,
step_size: float,
random_start: bool,
y_target: Optional[int] = None,
lb: float = 0.0,
ub: float = 1.0,
**kwargs
) -> None:
perturbation_models = {
PerturbationModels.L1: L1ProjectedGradientDescentAttack,
PerturbationModels.L2: L2ProjectedGradientDescentAttack,
PerturbationModels.LINF: LinfProjectedGradientDescentAttack,
}
foolbox_attack_cls = perturbation_models.get(perturbation_model, None)
if foolbox_attack_cls is None:
raise NotImplementedError(
"This threat model is not implemented in foolbox."
)

foolbox_attack = foolbox_attack_cls(
abs_stepsize=step_size, steps=num_steps, random_start=random_start
)

super().__init__(
foolbox_attack=foolbox_attack,
epsilon=epsilon,
y_target=y_target,
lb=lb,
ub=ub,
)


class PGDNative(CompositeEvasionAttack):
def __init__(
self,
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
author_email="[email protected], [email protected]",
install_requires=[],
extras_require={
"pytorch": ["torch>=1.4,!=1.5.*", "torchvision>=0.5,!=0.6.*"],
"foolbox": ["foolbox>=3.3.0", "torch>=1.4,!=1.5.*", "torchvision>=0.5,!=0.6.*"],
},
python_requires=">=3.7"
Expand Down

0 comments on commit 28157c1

Please sign in to comment.