Skip to content

Commit

Permalink
BinaryAccuracy support for bool target tensor (#136)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #136

Addressing github issue #134

Reviewed By: JKSenthil

Differential Revision: D44179119

fbshipit-source-id: 14d75bd3122776611b11649c50d12de871164f49
  • Loading branch information
bobakfb authored and facebook-github-bot committed Mar 20, 2023
1 parent 8f2f12a commit 007487d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tests/metrics/functional/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def test_binary_accuracy(self) -> None:

self._test_binary_accuracy_with_input(input, target)

def test_binary_accuracy_with_bool_target(self) -> None:
    """Casting a {0, 1} integer target to bool must not change the metric."""
    num_classes = 2
    predictions = torch.randint(high=num_classes, size=(BATCH_SIZE,))
    labels = torch.randint(high=num_classes, size=(BATCH_SIZE,))

    # Compute the metric twice: once with the integer target, once with
    # the same target viewed as a bool tensor (regression test for #134).
    result_int = binary_accuracy(predictions, labels)
    result_bool = binary_accuracy(predictions, labels.bool())

    torch.testing.assert_close(result_int, result_bool)

def test_binary_accuracy_with_rounding(self) -> None:
num_classes = 2
input = torch.rand(size=(BATCH_SIZE,))
Expand Down
11 changes: 10 additions & 1 deletion torcheval/metrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,16 @@ def _binary_accuracy_update(
input = torch.where(input < threshold, 0, 1)

num_correct = (input == target).sum()
num_total = target.new_tensor(target.shape[0])
if target.dtype == torch.bool:
num_total = torch.tensor(
target.shape[0],
dtype=torch.int64,
device=target.device,
requires_grad=False,
)
else:
# this is faster than using torch.tensor, but breaks for bool tensors because the shape will be cast to 1 in a bool tensor
num_total = target.new_tensor(target.shape[0])
return num_correct, num_total


Expand Down

0 comments on commit 007487d

Please sign in to comment.