Skip to content

Commit

Permalink
Fix metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
santiag0m committed May 16, 2022
1 parent 2265bf2 commit 9611ea7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def compute_accuracy(model: nn.Module, dataloader: DataLoader) -> torch.Tensor:
preds = torch.argmax(preds, dim=-1)
correct = preds == targets
accuracy.update(correct.cpu())
return accuracy.value
return accuracy.value.item()


def compute_mse(
Expand All @@ -39,7 +39,7 @@ def compute_mse(
batch_mse = (targets - preds) ** 2
batch_mse = batch_mse.sum(axis=1)
mse.update(batch_mse.cpu())
return mse.value
return mse.value.item()


def compute_bias(
Expand Down

0 comments on commit 9611ea7

Please sign in to comment.