Skip to content

Commit

Permalink
Update 08-seq_classification.ipynb (#814)
Browse files Browse the repository at this point in the history
08-seq_classification.ipynb evaluation was using a wrong output, which was fixed by #19 for the training part.
  • Loading branch information
Gaaaavin authored Feb 22, 2022
1 parent 3c890a2 commit 881d595
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions 08-seq_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,7 @@
" target_decoded = data_generator.decode_y_batch(target.cpu().numpy())\n",
"\n",
" output = model(data)\n",
" sequence_end = torch.tensor([len(sequence) for sequence in data_decoded]) - 1\n",
" output = output[torch.arange(data.shape[0]).long(), sequence_end, :]\n",
" output = output[:, -1, :]\n",
"\n",
" target = target.argmax(dim=1)\n",
" y_pred = output.argmax(dim=1)\n",
Expand Down

0 comments on commit 881d595

Please sign in to comment.