diff --git a/babyai/model.py b/babyai/model.py index b954e398..33e6a72a 100644 --- a/babyai/model.py +++ b/babyai/model.py @@ -216,7 +216,6 @@ def forward(self, obs, memory, instr_embedding=None): # It can be too big though, because instr_embeddings might be shorter than obs.instr # when the last position in a batch is all zeros. mask = mask[:, :instr_embedding.shape[1]] - instr_embedding = instr_embedding[:, :mask.shape[1]] keys = self.memory2key(memory) pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask attention = F.softmax(pre_softmax, dim=1)