Skip to content

Commit

Permalink
Update test_capsnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jindongwang authored Apr 9, 2018
1 parent d5b3d5e commit f16ffdd
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions test_capsnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import datasets, transforms
from capsnet import CapsNet
from data_loader import Dataset
Expand Down Expand Up @@ -65,7 +63,7 @@ def train(model, optimizer, train_loader, epoch):

train_loss += loss.data[0]
if batch_id % 100 == 0:
tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy:{:.6f}".format(
tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
epoch,
N_EPOCHS,
batch_id + 1,
Expand All @@ -75,7 +73,6 @@ def train(model, optimizer, train_loader, epoch):
train_loss / len(train_loader)))



def test(capsule_net, test_loader, epoch):
capsule_net.eval()
test_loss = 0
Expand All @@ -95,7 +92,9 @@ def test(capsule_net, test_loader, epoch):
correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==
np.argmax(target.data.cpu().numpy(), 1))

tqdm.write("Epoch: [{}/{}], test accuracy:{:.6f},loss:{:.6f}".format(epoch, N_EPOCHS, correct / len(test_loader.dataset),test_loss / len(test_loader)))
tqdm.write(
"Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(epoch, N_EPOCHS, correct / len(test_loader.dataset),
test_loss / len(test_loader)))


if __name__ == '__main__':
Expand All @@ -110,7 +109,7 @@ def test(capsule_net, test_loader, epoch):
capsule_net = capsule_net.cuda()
capsule_net = capsule_net.module

optimizer = Adam(capsule_net.parameters())
optimizer = torch.optim.Adam(capsule_net.parameters())

for e in range(1, N_EPOCHS + 1):
train(capsule_net, optimizer, mnist.train_loader, e)
Expand Down

0 comments on commit f16ffdd

Please sign in to comment.