Skip to content

Commit

Permalink
black formatting upgraded
Browse files Browse the repository at this point in the history
  • Loading branch information
maurapintor committed Feb 20, 2024
1 parent a23583f commit a703b0a
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 25 deletions.
10 changes: 6 additions & 4 deletions secml2/manipulations/manipulation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from typing import Tuple

import torch

Expand All @@ -24,12 +25,13 @@ def _apply_perturbation_constraints(self, delta: torch.Tensor) -> torch.Tensor:
delta = constraint(delta)
return delta

def _apply_manipulation(self, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
...
def _apply_manipulation(
self, x: torch.Tensor, delta: torch.Tensor
) -> torch.Tensor: ...

def __call__(
self, x: torch.Tensor, delta: torch.Tensor
) -> (torch.Tensor, torch.Tensor):
) -> Tuple[torch.Tensor, torch.Tensor]:
delta.data = self._apply_perturbation_constraints(delta.data)
x_adv, delta = self._apply_manipulation(x, delta)
x_adv.data = self._apply_domain_constraints(x_adv.data)
Expand All @@ -39,5 +41,5 @@ def __call__(
class AdditiveManipulation(Manipulation):
def _apply_manipulation(
self, x: torch.Tensor, delta: torch.Tensor
) -> (torch.Tensor, torch.Tensor):
) -> Tuple[torch.Tensor, torch.Tensor]:
return x + delta, delta
12 changes: 4 additions & 8 deletions secml2/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def __init__(
)

@abstractmethod
def predict(self, x: torch.Tensor) -> torch.Tensor:
...
def predict(self, x: torch.Tensor) -> torch.Tensor: ...

def decision_function(self, x: torch.Tensor) -> torch.Tensor:
x = self._preprocessing(x)
Expand All @@ -42,16 +41,13 @@ def decision_function(self, x: torch.Tensor) -> torch.Tensor:
return x

@abstractmethod
def _decision_function(self, x: torch.Tensor) -> torch.Tensor:
...
def _decision_function(self, x: torch.Tensor) -> torch.Tensor: ...

@abstractmethod
def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor:
...
def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor: ...

@abstractmethod
def train(self, dataloader: DataLoader):
...
def train(self, dataloader: DataLoader): ...

def __call__(self, x: torch.Tensor) -> torch.Tensor:
return self.decision_function(x)
6 changes: 2 additions & 4 deletions secml2/models/data_processing/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

class DataProcessing(ABC):
@abstractmethod
def process(self, x: torch.Tensor) -> torch.Tensor:
...
def process(self, x: torch.Tensor) -> torch.Tensor: ...

def invert(self, x: torch.Tensor) -> torch.Tensor:
...
def invert(self, x: torch.Tensor) -> torch.Tensor: ...

def __call__(self, x: torch.Tensor) -> torch.Tensor:
return self.process(x)
7 changes: 2 additions & 5 deletions secml2/optimization/constraints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import math
from abc import abstractmethod

import torch


class Constraint:
def __call__(self, x: torch.Tensor, *args, **kwargs):
...
def __call__(self, x: torch.Tensor, *args, **kwargs): ...


class ClipConstraint(Constraint):
Expand All @@ -25,8 +23,7 @@ def __init__(self, radius=0, center=0, p=torch.inf):
self.radius = radius

@abstractmethod
def project(self, x: torch.Tensor) -> torch.Tensor:
...
def project(self, x: torch.Tensor) -> torch.Tensor: ...

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x = x + self.center
Expand Down
5 changes: 1 addition & 4 deletions secml2/optimization/gradient_processing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import math

import torch.linalg
from torch.nn.functional import normalize

from secml2.adv.evasion.perturbation_models import PerturbationModels


class GradientProcessing:
def __call__(self, grad: torch.Tensor) -> torch.Tensor:
...
def __call__(self, grad: torch.Tensor) -> torch.Tensor: ...


class LinearProjectionGradientProcessing(GradientProcessing):
Expand Down

0 comments on commit a703b0a

Please sign in to comment.