Skip to content

Commit a9ef65e

Browse files
[Bug fix] 401_CNN.py
1. fix typo 2. remove .squeeze for 1-D tensor
1 parent 0f4219c commit a9ef65e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tutorial-contents/401_CNN.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self):
6565
out_channels=16, # n_filters
6666
kernel_size=5, # filter size
6767
stride=1, # filter movement/step
68-
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
68+
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
6969
), # output shape (16, 28, 28)
7070
nn.ReLU(), # activation
7171
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
@@ -115,7 +115,7 @@ def plot_with_labels(lowDWeights, labels):
115115

116116
if step % 50 == 0:
117117
test_output, last_layer = cnn(test_x)
118-
pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy()
118+
pred_y = torch.max(test_output, 1)[1].data.numpy()
119119
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
120120
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
121121
if HAS_SK:
@@ -129,6 +129,6 @@ def plot_with_labels(lowDWeights, labels):
129129

130130
# print 10 predictions from test data
131131
test_output, _ = cnn(test_x[:10])
132-
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
132+
pred_y = torch.max(test_output, 1)[1].data.numpy()
133133
print(pred_y, 'prediction number')
134134
print(test_y[:10].numpy(), 'real number')

0 commit comments

Comments
 (0)