Skip to content

Commit

Permalink
Add functional code for multiclass AUPRC
Browse files Browse the repository at this point in the history
Summary: Adding functional code for the multiclass area under precision recall curve, also called average precision.

Reviewed By: ananthsub

Differential Revision: D41512016

fbshipit-source-id: e2361e95cd521e2d4f6a5e1a13bb77f358c3836c
  • Loading branch information
bobakfb authored and facebook-github-bot committed Nov 30, 2022
1 parent 022b906 commit e62e143
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 7 deletions.
202 changes: 196 additions & 6 deletions tests/metrics/functional/classification/test_auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
import unittest
from typing import Optional, Tuple

import numpy as np

import torch
from sklearn.metrics import average_precision_score as sk_ap
from torcheval.metrics.functional import binary_auprc
from torcheval.utils.test_utils.metric_class_tester import (
BATCH_SIZE,
# NUM_PROCESSES,
# NUM_TOTAL_UPDATES,
)
from torcheval.metrics.functional import binary_auprc, multiclass_auprc
from torcheval.utils.test_utils.metric_class_tester import BATCH_SIZE


class TestBinaryAUPRC(unittest.TestCase):
Expand Down Expand Up @@ -149,3 +147,195 @@ def test_binary_auprc_invalid_input(self) -> None:
r"torch.Size\(\[4, 5\]\), target: torch.Size\(\[4, 5\]\).",
):
binary_auprc(torch.rand(4, 5), torch.rand(4, 5), num_tasks=2)


class TestMulticlassAUPRC(unittest.TestCase):
def _get_sklearn_equivalent(
self, input: torch.Tensor, target: torch.Tensor, device: str = "cpu"
) -> torch.Tensor:
# Convert input/target to sklearn style inputs
# run each task once at a time since no multi-task/multiclass
# available for sklearn
skinputs = input.numpy()
sktargets = target.numpy()
auprcs = []
for i in range(input.shape[1]):
skinput = skinputs[:, i]
sktarget = np.where(sktargets == i, 1, 0)
auprcs.append(sk_ap(sktarget, skinput))
return torch.tensor(auprcs, device=device).to(torch.float32)

def _test_multiclass_auprc_with_input(
self,
input: torch.Tensor,
target: torch.Tensor,
num_classes: int,
compute_result: Optional[torch.Tensor] = None,
) -> None:

device = "cpu"
if torch.cuda.is_available():
device = "cuda"

# get sklearn compute result if none given
if compute_result is None:
compute_result = self._get_sklearn_equivalent(input, target, device)

# Get torcheval compute result
te_compute_result = multiclass_auprc(
input.to(device=device),
target.to(device=device),
num_classes=num_classes,
average=None,
)

# test no average
torch.testing.assert_close(
te_compute_result,
compute_result,
equal_nan=True,
atol=1e-8,
rtol=1e-5,
)

def _get_rand_inputs_multiclass(
self, num_classes: int, batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
input = torch.rand(size=[batch_size, num_classes])
targets = torch.randint(low=0, high=num_classes, size=[batch_size])
return input, targets

def test_multiclass_auprc_with_good_input(self) -> None:
num_classes = 2
input, target = self._get_rand_inputs_multiclass(num_classes, BATCH_SIZE)
self._test_multiclass_auprc_with_input(input, target, num_classes)

num_classes = 4
input, target = self._get_rand_inputs_multiclass(num_classes, BATCH_SIZE)
self._test_multiclass_auprc_with_input(input, target, num_classes)

num_classes = 5
input, target = self._get_rand_inputs_multiclass(num_classes, BATCH_SIZE)
self._test_multiclass_auprc_with_input(input, target, num_classes)

num_classes = 8
input, target = self._get_rand_inputs_multiclass(num_classes, BATCH_SIZE)
self._test_multiclass_auprc_with_input(input, target, num_classes)

def test_multiclass_auprc_options(self) -> None:
### average = macro, num_classes not given
input, target = self._get_rand_inputs_multiclass(5, BATCH_SIZE)
compute_result = torch.mean(self._get_sklearn_equivalent(input, target))
te_compute_result = multiclass_auprc(input, target, average="macro")
torch.testing.assert_close(
te_compute_result,
compute_result,
equal_nan=True,
atol=1e-8,
rtol=1e-5,
)

### average = macro (not given), num_classes given
input, target = self._get_rand_inputs_multiclass(5, BATCH_SIZE)
compute_result = torch.mean(self._get_sklearn_equivalent(input, target))
te_compute_result = multiclass_auprc(input, target)
print(compute_result, te_compute_result)
torch.testing.assert_close(
te_compute_result,
compute_result,
equal_nan=True,
atol=1e-8,
rtol=1e-5,
)

### average = none
input, target = self._get_rand_inputs_multiclass(5, BATCH_SIZE)
compute_result = self._get_sklearn_equivalent(input, target)
te_compute_result = multiclass_auprc(input, target, average="none")
torch.testing.assert_close(
te_compute_result,
compute_result,
equal_nan=True,
atol=1e-8,
rtol=1e-5,
)
te_compute_result = multiclass_auprc(input, target, average=None)
torch.testing.assert_close(
te_compute_result,
compute_result,
equal_nan=True,
atol=1e-8,
rtol=1e-5,
)

def test_multiclass_auprc_docstring_examples(self) -> None:
input = torch.tensor([[0.5647, 0.2726], [0.9143, 0.1895], [0.7782, 0.3082]])
target = torch.tensor([0, 1, 0])
output = torch.tensor([0.5833, 0.3333])
result = multiclass_auprc(input, target, average=None)
torch.testing.assert_close(
result,
output,
equal_nan=True,
atol=1e-4,
rtol=1e-3,
)

avg_result = multiclass_auprc(input, target)
avg_output = torch.tensor(0.4583)
torch.testing.assert_close(
avg_result,
avg_output,
equal_nan=True,
atol=1e-4,
rtol=1e-3,
)

input = torch.tensor([[0.1, 1], [0.5, 1], [0.7, 1], [0.8, 0]])
target = torch.tensor([1, 0, 0, 1])
result = multiclass_auprc(input, target, 2, average=None)
output = torch.tensor([0.5833, 0.4167])
torch.testing.assert_close(
avg_result,
avg_output,
equal_nan=True,
atol=1e-4,
rtol=1e-3,
)

def test_multiclass_auroc_invalid_input(self) -> None:
with self.assertRaisesRegex(
ValueError, "`average` was not in the allowed value of .*, got micro."
):
num_classes = 4
input, target = self._get_rand_inputs_multiclass(num_classes, BATCH_SIZE)
multiclass_auprc(
input,
target,
num_classes=num_classes,
average="micro",
)

with self.assertRaisesRegex(ValueError, "`num_classes` has to be at least 2."):
multiclass_auprc(torch.rand(4, 2), torch.rand(2), num_classes=1)

with self.assertRaisesRegex(
ValueError,
"The `input` and `target` should have the same first dimension, "
r"got shapes torch.Size\(\[4, 2\]\) and torch.Size\(\[3\]\).",
):
multiclass_auprc(torch.rand(4, 2), torch.rand(3), num_classes=2)

with self.assertRaisesRegex(
ValueError,
"target should be a one-dimensional tensor, "
r"got shape torch.Size\(\[3, 2\]\).",
):
multiclass_auprc(torch.rand(3, 2), torch.rand(3, 2), num_classes=2)

with self.assertRaisesRegex(
ValueError,
r"input should have shape of \(num_sample, num_classes\), "
r"got torch.Size\(\[3, 4\]\) and num_classes=2.",
):
multiclass_auprc(torch.rand(3, 4), torch.rand(3), num_classes=2)
2 changes: 2 additions & 0 deletions torcheval/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
binary_recall,
binary_recall_at_fixed_precision,
multiclass_accuracy,
multiclass_auprc,
multiclass_auroc,
multiclass_binned_precision_recall_curve,
multiclass_confusion_matrix,
Expand Down Expand Up @@ -59,6 +60,7 @@
"mean",
"mean_squared_error",
"multiclass_accuracy",
"multiclass_auprc",
"multiclass_auroc",
"multiclass_binned_precision_recall_curve",
"multiclass_confusion_matrix",
Expand Down
6 changes: 5 additions & 1 deletion torcheval/metrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
multilabel_accuracy,
topk_multilabel_accuracy,
)
from torcheval.metrics.functional.classification.auprc import binary_auprc
from torcheval.metrics.functional.classification.auprc import (
binary_auprc,
multiclass_auprc,
)

from torcheval.metrics.functional.classification.auroc import (
binary_auroc,
Expand Down Expand Up @@ -67,6 +70,7 @@
"binary_recall",
"binary_recall_at_fixed_precision",
"multiclass_accuracy",
"multiclass_auprc",
"multiclass_auroc",
"multiclass_binned_auroc",
"multiclass_binned_precision_recall_curve",
Expand Down
Loading

0 comments on commit e62e143

Please sign in to comment.