Skip to content

Commit

Permalink
Update test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jun 16, 2020
1 parent 3390634 commit 816cba3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
Binary file not shown.
Binary file not shown.
27 changes: 18 additions & 9 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,30 @@
GRAYSCALE = True


def compute_acc(model, data_loader, device):
def compute_accuracy(model, data_loader, device):
correct_pred, num_examples = 0, 0
model.eval()
top1_acc, top2_acc = 0.0, 0.0
top2_correct = 0

for i, (features, targets) in enumerate(data_loader):

features = features.to(device)
targets = targets.to(device)

logits, probas = model(features)
_, predicted_labels = torch.max(probas, 1)
top2_indices = probas.topk(2, 1)[1]
targets_trans = targets.t().view(-1,1)
# print(top2_indices.shape, targets_trans.repeat(1, 2).shape)
top2_predict = (top2_indices == targets_trans.repeat(1, 2))
top2_correct += top2_predict.sum()
num_examples = targets.size(0)
assert predicted_labels.size() == targets.size()
correct_pred = (predicted_labels == targets).sum()
# print('num_examples', num_examples)
top1_acc = correct_pred.float() / num_examples * 100
top2_acc = top2_correct.float() / num_examples * 100
break
return correct_pred.float() / num_examples * 100

return top1_acc, top2_acc

if __name__ == '__main__':
model = resnet34(NUM_CLASSES, GRAYSCALE)
Expand All @@ -41,12 +50,12 @@ def compute_acc(model, data_loader, device):

with torch.set_grad_enabled(False):

test_acc = compute_acc(model, test_loader, device)
print("\nTest accuracy: %.2f%%" %(test_acc))
test_acc_top1, test_acc_top2= compute_accuracy(model, test_loader, device)
print("\nTest accuracy: %.2f%%" %(test_acc_top1))
# writer.add_scalar('Test accuracy', test_acc)
ckpt_model_filename = 'ckpt_test_acc_{}.pth'.format(test_acc)
ckpt_model_filename = 'ckpt_test_acc_top1_{}_top2_{}.pth'.format(test_acc_top1, test_acc_top2)
ckpt_model_path = os.path.join(checkpoint_dir, ckpt_model_filename) # model_save
torch.save(model.state_dict(), ckpt_model_path)
print("\nDone, save model at {}", ckpt_model_path)

break
break

0 comments on commit 816cba3

Please sign in to comment.