Skip to content

Commit

Permalink
changed interfaces of methods and abstractmethods to allow custom arg…
Browse files Browse the repository at this point in the history
…s and kwargs
maurapintor committed Oct 2, 2024
1 parent 09cf7fa commit 253d7c0
Showing 2 changed files with 12 additions and 10 deletions.
14 changes: 8 additions & 6 deletions src/secmlt/models/base_model.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ def __init__(
)

@abstractmethod
def predict(self, x: torch.Tensor) -> torch.Tensor:
def predict(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Return output predictions for given model.
@@ -52,10 +52,12 @@ def predict(self, x: torch.Tensor) -> torch.Tensor:
"""
...

def decision_function(self, x: torch.Tensor) -> torch.Tensor:
def decision_function(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Return the decision function from the model.
Requires override to specify custom args and kwargs passing.
Parameters
----------
x : torch.Tensor
@@ -71,7 +73,7 @@ def decision_function(self, x: torch.Tensor) -> torch.Tensor:
return self._postprocessing(x)

@abstractmethod
def _decision_function(self, x: torch.Tensor) -> torch.Tensor:
def _decision_function(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Specific decision function of the model (data already preprocessed).
@@ -88,7 +90,7 @@ def _decision_function(self, x: torch.Tensor) -> torch.Tensor:
...

@abstractmethod
def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor:
def gradient(self, x: torch.Tensor, y: int, *args, **kwargs) -> torch.Tensor:
"""
Compute gradients of the score y w.r.t. x.
@@ -118,7 +120,7 @@ def train(self, dataloader: DataLoader) -> "BaseModel":
"""
...

def __call__(self, x: torch.Tensor) -> torch.Tensor:
def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Forward function of the model.
@@ -132,4 +134,4 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor
Model ouptut scores.
"""
return self.decision_function(x)
return self.decision_function(x, *args, **kwargs)
8 changes: 4 additions & 4 deletions src/secmlt/models/data_processing/data_processing.py
Original file line number Diff line number Diff line change
@@ -9,10 +9,10 @@ class DataProcessing(ABC):
"""Abstract data processing class."""

@abstractmethod
def _process(self, x: torch.Tensor) -> torch.Tensor: ...
def _process(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: ...

@abstractmethod
def invert(self, x: torch.Tensor) -> torch.Tensor:
def invert(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Apply the inverted transform (if defined).
@@ -28,7 +28,7 @@ def invert(self, x: torch.Tensor) -> torch.Tensor:
"""
...

def __call__(self, x: torch.Tensor) -> torch.Tensor:
def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Apply the forward transformation.
@@ -42,4 +42,4 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor
The samples after transformation.
"""
return self._process(x)
return self._process(x, *args, **kwargs)

0 comments on commit 253d7c0

Please sign in to comment.