Skip to content

Commit

Permalink
seems to work now with pytorch 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle M. Tarplee committed May 29, 2018
1 parent 500fc66 commit d4bd35d
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 641 deletions.
24 changes: 12 additions & 12 deletions Experiments_FashionMNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,17 @@
" plt.legend(fashion_mnist_classes)\n",
"\n",
"def extract_embeddings(dataloader, model):\n",
" model.eval()\n",
" embeddings = np.zeros((len(dataloader.dataset), 2))\n",
" labels = np.zeros(len(dataloader.dataset))\n",
" k = 0\n",
" for images, target in dataloader:\n",
" images = Variable(images, volatile=True)\n",
" if cuda:\n",
" images = images.cuda()\n",
" embeddings[k:k+len(images)] = model.get_embedding(images).data.cpu().numpy()\n",
" labels[k:k+len(images)] = target.numpy()\n",
" k += len(images)\n",
" with torch.no_grad():\n",
" model.eval()\n",
" embeddings = np.zeros((len(dataloader.dataset), 2))\n",
" labels = np.zeros(len(dataloader.dataset))\n",
" k = 0\n",
" for images, target in dataloader:\n",
" if cuda:\n",
" images = images.cuda()\n",
" embeddings[k:k+len(images)] = model.get_embedding(images).data.cpu().numpy()\n",
" labels[k:k+len(images)] = target.numpy()\n",
" k += len(images)\n",
" return embeddings, labels"
]
},
Expand Down Expand Up @@ -1881,7 +1881,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.5"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit d4bd35d

Please sign in to comment.