Skip to content

Commit

Permalink
Fixed python3.5-related bug in models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexis Conneau committed Jul 6, 2017
1 parent 59d2ea9 commit 260bfd4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions encoder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(self, sent_tuple):
sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
idx_unsort = np.argsort(idx_sort)

idx_sort = torch.cuda.LongTensor(idx_sort) if self.use_cuda else torch.LongTensor(idx_sort)
idx_sort = torch.from_numpy(idx_sort).cuda() if self.use_cuda else torch.from_numpy(idx_sort)
sent = sent.index_select(1, Variable(idx_sort))

# Handling padding in Recurrent Networks
Expand All @@ -47,7 +47,7 @@ def forward(self, sent_tuple):
sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

# Un-sort by length
idx_unsort = torch.cuda.LongTensor(idx_unsort) if self.use_cuda else torch.LongTensor(idx_unsort)
idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.use_cuda else torch.from_numpy(idx_sort)
sent_output = sent_output.index_select(1, Variable(idx_unsort))

# Pooling
Expand Down
6 changes: 3 additions & 3 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def forward(self, sent_tuple):
# Sort by length (keep idx)
sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
idx_unsort = np.argsort(idx_sort)
idx_sort = torch.cuda.LongTensor(idx_sort) if self.use_cuda else torch.LongTensor(idx_sort)

idx_sort = torch.from_numpy(idx_sort).cuda() if self.use_cuda else torch.from_numpy(idx_sort)
sent = sent.index_select(1, Variable(idx_sort))

# Handling padding in Recurrent Networks
Expand All @@ -50,7 +50,7 @@ def forward(self, sent_tuple):
sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

# Un-sort by length
idx_unsort = torch.cuda.LongTensor(idx_unsort) if self.use_cuda else torch.LongTensor(idx_unsort)
idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.use_cuda else torch.from_numpy(idx_sort)
sent_output = sent_output.index_select(1, Variable(idx_unsort))

# Pooling
Expand Down

0 comments on commit 260bfd4

Please sign in to comment.