Skip to content

Commit

Permalink
black linting
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Oct 18, 2022
1 parent c055965 commit c2ccf13
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
17 changes: 15 additions & 2 deletions goli/ipu/ipu_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,9 @@ def fbeta_score_ipu(
weights = tp + fn
fbeta = (weights * fbeta).sum() / weights.sum()
else:
raise ValueError(f"`average={average}` not yet supported. Chose between None, Micro, Macro, or Weighted")
raise ValueError(
f"`average={average}` not yet supported. Chose between None, Micro, Macro, or Weighted"
)

return fbeta

Expand Down Expand Up @@ -810,7 +812,18 @@ def f1_score_ipu(
``average`` parameter)
"""

return fbeta_score_ipu(preds, target, beta=beta, average=average, mdmc_average=mdmc_average, ignore_index=ignore_index, num_classes=num_classes, threshold=threshold, top_k=top_k, multiclass=multiclass,)
return fbeta_score_ipu(
preds,
target,
beta=beta,
average=average,
mdmc_average=mdmc_average,
ignore_index=ignore_index,
num_classes=num_classes,
threshold=threshold,
top_k=top_k,
multiclass=multiclass,
)


def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool) -> Tensor:
Expand Down
18 changes: 13 additions & 5 deletions goli/trainer/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def _parse_target_nan_mask(target_nan_mask):
# Only a few str options are accepted
target_nan_mask = target_nan_mask.lower()
accepted_str = ["ignore", "none"]
assert target_nan_mask in accepted_str, f"Provided {target_nan_mask} not in accepted_str={accepted_str}"
assert (
target_nan_mask in accepted_str
), f"Provided {target_nan_mask} not in accepted_str={accepted_str}"

if target_nan_mask == "none":
target_nan_mask = None
Expand All @@ -211,22 +213,26 @@ def _parse_multitask_handling(multitask_handling, target_nan_mask):
Parse the `multitask_handling` parameter
"""

if (multitask_handling is None):
if multitask_handling is None:
# None is accepted
pass
elif isinstance(multitask_handling, str):
# Only a few str options are accepted
multitask_handling = multitask_handling.lower()
accepted_str = ["flatten", "mean-per-label", "none"]
assert multitask_handling in accepted_str, f"Provided {multitask_handling} not in accepted_str={accepted_str}"
assert (
multitask_handling in accepted_str
), f"Provided {multitask_handling} not in accepted_str={accepted_str}"

if multitask_handling == "none":
multitask_handling = None
else:
raise ValueError(f"Unrecognized option `multitask_handling={multitask_handling}`")

if (target_nan_mask == "ignore") and (multitask_handling is None):
raise ValueError("Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'")
raise ValueError(
"Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'"
)

return multitask_handling

Expand Down Expand Up @@ -261,7 +267,9 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:

if self.multitask_handling is None:
# In case of no multi-task handling, apply the nan filtering, then compute the metrics
assert self.target_nan_mask != "ignore", f"Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'"
assert (
self.target_nan_mask != "ignore"
), f"Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'"
preds, target = self._filter_nans(preds, target)
metric_val = self.metric(preds, target, **self.kwargs)
elif self.multitask_handling == "flatten":
Expand Down
14 changes: 8 additions & 6 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def test_target_nan_mask(self):
self.assertAlmostEqual(score4, mean_squared_error(this_preds, this_target), msg=err_msg)

# Ignore NaNs in each column and average the score
metric = MetricWrapper(metric="mse", target_nan_mask="ignore", multitask_handling="mean-per-label")
metric = MetricWrapper(
metric="mse", target_nan_mask="ignore", multitask_handling="mean-per-label"
)
score5 = metric(preds, target)

this_target = target.clone()
Expand Down Expand Up @@ -153,11 +155,11 @@ def test_pickling(self):
# Raise with incompatible options
with self.assertRaises(ValueError):
MetricWrapper(
metric=metric,
threshold_kwargs=threshold_kwargs,
target_nan_mask=target_nan_mask,
multitask_handling=multitask_handling,
**kwargs,
metric=metric,
threshold_kwargs=threshold_kwargs,
target_nan_mask=target_nan_mask,
multitask_handling=multitask_handling,
**kwargs,
)

else:
Expand Down

0 comments on commit c2ccf13

Please sign in to comment.