Skip to content

Commit

Permalink
mnist 0.2 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Aug 6, 2017
1 parent 9053040 commit fb9ca4d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test():
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
Expand Down

0 comments on commit fb9ca4d

Please sign in to comment.