diff --git a/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py b/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py index a0ecd773..ef892893 100644 --- a/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py +++ b/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py @@ -53,8 +53,12 @@ def forward(self, x): # Forward propagate LSTM out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2) + out1, out2 = torch.chunk(out, 2, dim=2) + out_cat = torch.cat((out1[:, -1, :], out2[:, 0, :]), 1) + # Decode the hidden state of the last time step - out = self.fc(out[:, -1, :]) + out = self.fc(out_cat) + return out model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device) @@ -99,4 +103,4 @@ def forward(self, x): print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint -torch.save(model.state_dict(), 'model.ckpt') \ No newline at end of file +torch.save(model.state_dict(), 'model.ckpt')