Skip to content

Commit

Permalink
share lengths computation (line 257) between gru, bigru, attgru
Browse files Browse the repository at this point in the history
use lengths to get hidden layer from unidirectional gru (line 260)
  • Loading branch information
dyth committed Apr 9, 2019
1 parent 4245099 commit d2559fc
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions babyai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,13 @@ def forward(self, obs, memory, instr_embedding=None):
return {'dist': dist, 'value': value, 'memory': memory, 'extra_predictions': extra_predictions}

def _get_instr_embedding(self, instr):
lengths = (instr != 0).sum(1).long()
if self.lang_model == 'gru':
out, _ = self.instr_rnn(self.word_embedding(instr))
index = -1 - (instr <= 0).sum(dim=1)
hidden = out[range(len(index)), index, :]
hidden = out[range(len(lengths)), lengths-1, :]
return hidden

elif self.lang_model in ['bigru', 'attgru']:
lengths = (instr != 0).sum(1).long()
masks = (instr != 0).float()

if lengths.shape[0] > 1:
Expand Down

0 comments on commit d2559fc

Please sign in to comment.