-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutil.py
24 lines (20 loc) · 845 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch.autograd import Variable
def accuracy(y_pred, y):
    """Return the fraction of predictions in `y_pred` that match `y`.

    Args:
        y_pred: tensor of scores; the predicted label of each entry is the
            index of the largest value along dimension 1 (presumably shaped
            ``(batch, num_classes)`` -- confirm against callers).
        y: tensor of target labels, compared element-wise with the
            predicted indices.

    Returns:
        The mean of the element-wise matches as a plain Python float.
    """
    predicted = torch.max(y_pred, 1)[1]
    # `.item()` replaces the legacy `.data[0]` idiom: `mean()` yields a 0-dim
    # tensor on current PyTorch, and indexing it with [0] raises IndexError.
    return (predicted == y).float().mean().item()
def eval_on_test(test_dataloader, model_fn):
    """Return the mean per-batch accuracy of `model_fn` over a test loader.

    Args:
        test_dataloader: iterable yielding ``(x, y)`` batches. Batches are
            counted while iterating, so the object does not need to
            support ``len()``.
        model_fn: callable mapping a batch ``x`` to predictions that are
            passed to `accuracy` together with ``y``.

    Returns:
        The average of the per-batch accuracies, rounded to 3 decimals.

    Raises:
        ValueError: if `test_dataloader` yields no batches.
    """
    use_cuda = torch.cuda.is_available()  # loop-invariant, so checked once
    total = 0.0
    num_batches = 0
    for x, y in test_dataloader:
        # Kept for older PyTorch releases; on current releases wrapping a
        # tensor in Variable simply gives back a tensor.
        x, y = Variable(x), Variable(y)
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        total += accuracy(model_fn(x), y)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("test_dataloader yielded no batches")
    # NOTE: every batch carries equal weight, so the samples of a smaller
    # final batch count slightly more than the others.
    return round(total / num_batches, 3)
def into_tensor(data, into_vars=True):
    """Stack a sequence of tensor pairs into two batched tensors.

    Args:
        data: sequence of indexable pairs of tensors; element 0 of every
            pair goes into the first output, element 1 into the second.
        into_vars: when true (the default) the stacked tensors are wrapped
            in `Variable`, matching the previous behaviour; when false the
            plain stacked tensors are returned.

    Returns:
        A tuple ``(X1, X2)`` of tensors stacked along a new leading
        dimension. Both are moved to the GPU when CUDA is available.
    """
    firsts = torch.stack([pair[0] for pair in data])
    seconds = torch.stack([pair[1] for pair in data])
    if into_vars:
        firsts, seconds = Variable(firsts), Variable(seconds)
    if torch.cuda.is_available():
        firsts, seconds = firsts.cuda(), seconds.cuda()
    return firsts, seconds