diff --git a/metric.py b/metric.py new file mode 100644 index 0000000..fed78d2 --- /dev/null +++ b/metric.py @@ -0,0 +1,8 @@ +import torch +from sklearn.metrics import accuracy_score + +def accuracy(preds, target): + preds = torch.max(preds, 1)[1].float() + acc = accuracy_score(preds.cpu().numpy(), target.cpu().numpy()) + + return acc