Skip to content

Commit

Permalink
Bugs In MNIST (pytorch#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored Jul 7, 2017
2 parents d6e6324 + 53f25e0 commit cab5705
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train(epoch):
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0]))

def test(epoch):
def test():
model.eval()
test_loss = 0
correct = 0
Expand All @@ -98,17 +98,16 @@ def test(epoch):
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(output, target).data[0]
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()

test_loss = test_loss
test_loss /= len(test_loader) # loss function already averages over batch size
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
test()

0 comments on commit cab5705

Please sign in to comment.