Dima fix imitation learning (mila-iqia#93)
* make sure instr_embedding and mask have the same length

* fix typos in the comments
rizar authored Mar 4, 2020
1 parent b91b34b commit 007ce2c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions babyai/model.py
@@ -211,11 +211,17 @@ def forward(self, obs, memory, instr_embedding=None):
 # outputs: B x L x D
 # memory: B x M
 mask = (obs.instr != 0).float()
-# The mask tensor should be always at least a big as instr_embedding.
-# It can be too big though, because instr_embeddings might be shorter than obs.instr
-# when the last position in a batch is all zeros.
+# The mask tensor has the same length as obs.instr, and
+# thus can be both shorter and longer than instr_embedding.
+# It can be longer if instr_embedding is computed
+# for a subbatch of obs.instr.
+# It can be shorter if obs.instr is a subbatch of
+# the batch that instr_embeddings was computed for.
+# Here, we make sure that mask and instr_embeddings
+# have equal length along dimension 1.
 mask = mask[:, :instr_embedding.shape[1]]
+instr_embedding = instr_embedding[:, :mask.shape[1]]

 keys = self.memory2key(memory)
 pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask
 attention = F.softmax(pre_softmax, dim=1)
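
For context, a minimal standalone sketch of what the two slicing lines do; this is not part of the commit, and the tensor sizes and the memory2key linear layer below are made up for illustration:

import torch
import torch.nn.functional as F

B, L_instr, L_emb, D, M = 2, 7, 5, 4, 4

obs_instr = torch.randint(1, 10, (B, L_instr))    # padded instruction tokens
obs_instr[:, 5:] = 0                               # trailing padding positions
instr_embedding = torch.randn(B, L_emb, D)         # may be shorter than obs_instr
memory = torch.randn(B, M)
memory2key = torch.nn.Linear(M, D)                 # stand-in for self.memory2key

mask = (obs_instr != 0).float()                    # B x L_instr
# Truncate both tensors to a common length along dimension 1,
# regardless of which one is longer.
mask = mask[:, :instr_embedding.shape[1]]
instr_embedding = instr_embedding[:, :mask.shape[1]]
assert mask.shape[1] == instr_embedding.shape[1]

keys = memory2key(memory)                          # B x D
pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask
attention = F.softmax(pre_softmax, dim=1)          # B x L attention weights

After the slicing, the masked softmax is well defined because pre_softmax and mask share the same sequence length.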