Skip to content

Commit 5b1d191

Browse files
302 - remove .squeeze() from 1-D numpy array
prediction.data.numpy() is already 1-D, so the .squeeze() is unnecessary.
1 parent 906cf71 commit 5b1d191

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tutorial-contents/302_classification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ def forward(self, x):
5959
# plot and show learning process
6060
plt.cla()
6161
prediction = torch.max(out, 1)[1]
62-
pred_y = prediction.data.numpy().squeeze()
62+
pred_y = prediction.data.numpy()
6363
target_y = y.data.numpy()
6464
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
6565
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
6666
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
6767
plt.pause(0.1)
6868

6969
plt.ioff()
70-
plt.show()
70+
plt.show()

0 commit comments

Comments
 (0)