From a703b0aea090025beedc06e7cdeed1d626646299 Mon Sep 17 00:00:00 2001 From: maurapintor Date: Tue, 20 Feb 2024 18:02:04 +0100 Subject: [PATCH] black formatting upgraded --- secml2/manipulations/manipulation.py | 10 ++++++---- secml2/models/base_model.py | 12 ++++-------- secml2/models/data_processing/data_processing.py | 6 ++---- secml2/optimization/constraints.py | 7 ++----- secml2/optimization/gradient_processing.py | 5 +---- 5 files changed, 15 insertions(+), 25 deletions(-) diff --git a/secml2/manipulations/manipulation.py b/secml2/manipulations/manipulation.py index 07b4a0f..2af07d4 100644 --- a/secml2/manipulations/manipulation.py +++ b/secml2/manipulations/manipulation.py @@ -1,4 +1,5 @@ from abc import ABC +from typing import Tuple import torch @@ -24,12 +25,13 @@ def _apply_perturbation_constraints(self, delta: torch.Tensor) -> torch.Tensor: delta = constraint(delta) return delta - def _apply_manipulation(self, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: - ... + def _apply_manipulation( + self, x: torch.Tensor, delta: torch.Tensor + ) -> torch.Tensor: ... def __call__( self, x: torch.Tensor, delta: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + ) -> Tuple[torch.Tensor, torch.Tensor]: delta.data = self._apply_perturbation_constraints(delta.data) x_adv, delta = self._apply_manipulation(x, delta) x_adv.data = self._apply_domain_constraints(x_adv.data) @@ -39,5 +41,5 @@ def __call__( class AdditiveManipulation(Manipulation): def _apply_manipulation( self, x: torch.Tensor, delta: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + ) -> Tuple[torch.Tensor, torch.Tensor]: return x + delta, delta diff --git a/secml2/models/base_model.py b/secml2/models/base_model.py index 4840be1..f1cb06f 100644 --- a/secml2/models/base_model.py +++ b/secml2/models/base_model.py @@ -32,8 +32,7 @@ def __init__( ) @abstractmethod - def predict(self, x: torch.Tensor) -> torch.Tensor: - ... + def predict(self, x: torch.Tensor) -> torch.Tensor: ... def decision_function(self, x: torch.Tensor) -> torch.Tensor: x = self._preprocessing(x) @@ -42,16 +41,13 @@ def decision_function(self, x: torch.Tensor) -> torch.Tensor: return x @abstractmethod - def _decision_function(self, x: torch.Tensor) -> torch.Tensor: - ... + def _decision_function(self, x: torch.Tensor) -> torch.Tensor: ... @abstractmethod - def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor: - ... + def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor: ... @abstractmethod - def train(self, dataloader: DataLoader): - ... + def train(self, dataloader: DataLoader): ... def __call__(self, x: torch.Tensor) -> torch.Tensor: return self.decision_function(x) diff --git a/secml2/models/data_processing/data_processing.py b/secml2/models/data_processing/data_processing.py index 335b609..2ea30c2 100644 --- a/secml2/models/data_processing/data_processing.py +++ b/secml2/models/data_processing/data_processing.py @@ -5,11 +5,9 @@ class DataProcessing(ABC): @abstractmethod - def process(self, x: torch.Tensor) -> torch.Tensor: - ... + def process(self, x: torch.Tensor) -> torch.Tensor: ... - def invert(self, x: torch.Tensor) -> torch.Tensor: - ... + def invert(self, x: torch.Tensor) -> torch.Tensor: ... def __call__(self, x: torch.Tensor) -> torch.Tensor: return self.process(x) diff --git a/secml2/optimization/constraints.py b/secml2/optimization/constraints.py index 2ca9bff..30ac6ad 100644 --- a/secml2/optimization/constraints.py +++ b/secml2/optimization/constraints.py @@ -1,12 +1,10 @@ -import math from abc import abstractmethod import torch class Constraint: - def __call__(self, x: torch.Tensor, *args, **kwargs): - ... + def __call__(self, x: torch.Tensor, *args, **kwargs): ... class ClipConstraint(Constraint): @@ -25,8 +23,7 @@ def __init__(self, radius=0, center=0, p=torch.inf): self.radius = radius @abstractmethod - def project(self, x: torch.Tensor) -> torch.Tensor: - ... + def project(self, x: torch.Tensor) -> torch.Tensor: ... def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: x = x + self.center diff --git a/secml2/optimization/gradient_processing.py b/secml2/optimization/gradient_processing.py index b77f0df..b4039b0 100644 --- a/secml2/optimization/gradient_processing.py +++ b/secml2/optimization/gradient_processing.py @@ -1,5 +1,3 @@ -import math - import torch.linalg from torch.nn.functional import normalize @@ -7,8 +5,7 @@ class GradientProcessing: - def __call__(self, grad: torch.Tensor) -> torch.Tensor: - ... + def __call__(self, grad: torch.Tensor) -> torch.Tensor: ... class LinearProjectionGradientProcessing(GradientProcessing):