# test.py
# Forked from SeventeenChen/fashionMNIST_ResNet_master.
from torch.utils.tensorboard import SummaryWriter
from data_loader import *
from model import *
import os
# Device: use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TensorBoard writer, created at import time with the run comment 'ResNet34'.
writer = SummaryWriter(comment='ResNet34')
# Directory that holds the checkpoint files used by this script.
checkpoint_dir = './results'
# File name of the trained checkpoint to evaluate.
ckpt_model_filename = 'ckpt_valid_acc_88.56999969482422_epoch_180.pth'
# Full path of the checkpoint that is loaded into the model.
PATH = os.path.join(checkpoint_dir, ckpt_model_filename)
# Number of output classes, passed as the first argument to resnet34().
NUM_CLASSES = 10
# Other
# Passed as the second argument to resnet34(); presumably selects
# single-channel (grayscale) input -- confirm against model.py.
GRAYSCALE = True
def compute_acc(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    The model is put into evaluation mode and every batch produced by
    ``data_loader`` is scored; correct predictions and example counts are
    accumulated across all batches before the percentage is computed.

    Args:
        model: Module whose forward call returns a ``(logits, probas)`` pair;
            the predicted class is the index of the largest value of
            ``probas`` along dimension 1.
        data_loader: Iterable yielding ``(features, targets)`` batches.
        device: Device the features and targets are moved to before the
            forward pass.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.

    Raises:
        ValueError: If ``data_loader`` yields no examples, since the accuracy
            is undefined in that case.
    """
    model.eval()
    correct_pred, num_examples = 0, 0
    # Inference only: make the function safe to call without relying on the
    # caller to disable autograd.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            _, probas = model(features)
            _, predicted_labels = torch.max(probas, 1)
            assert predicted_labels.size() == targets.size()
            # Accumulate over batches; overwriting here would report the
            # accuracy of a single batch only.
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    if num_examples == 0:
        raise ValueError("data_loader yielded no examples; accuracy is undefined")
    return correct_pred.float() / num_examples * 100
if __name__ == '__main__':
    # Rebuild the network and restore the trained weights. map_location makes
    # a checkpoint that was saved on a GPU loadable on a CPU-only machine,
    # matching the CPU fallback used when selecting `device`.
    model = resnet34(NUM_CLASSES, GRAYSCALE)
    model.load_state_dict(torch.load(PATH, map_location=device))
    model.to(device)
    model.eval()

    # compute_acc iterates over test_loader itself, so no outer batch loop is
    # needed here; gradients are not required for evaluation.
    with torch.no_grad():
        test_acc = float(compute_acc(model, test_loader, device))
    print("\nTest accuracy: %.2f%%" % test_acc)
    # writer.add_scalar('Test accuracy', test_acc)  # TensorBoard logging (disabled)

    # Save the evaluated weights under a name that records the test accuracy.
    # A separate name is used so the input checkpoint file name is not
    # overwritten.
    test_ckpt_filename = 'ckpt_test_acc_{}.pth'.format(test_acc)
    ckpt_model_path = os.path.join(checkpoint_dir, test_ckpt_filename)  # model_save
    torch.save(model.state_dict(), ckpt_model_path)
    # The original passed the path as a second print argument, leaving a
    # literal "{}" in the output.
    print("\nDone, save model at {}".format(ckpt_model_path))