Skip to content

Commit

Permalink
fix dtype error
Browse files Browse the repository at this point in the history
sent1 was float64 and sent2 float32 in some cases
  • Loading branch information
mschrimpf committed Feb 19, 2021
1 parent b13982e commit a95bded
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions neural_nlp/models/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,8 @@ def glue_dataset(self, examples, label_list, output_mode):
sents1 = [self._sent_mean(self._encode_sentence(sent)) for sent in tqdm(text_a)]
sents2 = [self._sent_mean(self._encode_sentence(sent)) for sent in tqdm(text_b)]
for sent1, sent2 in zip(sents1, sents2):
sent1 = torch.tensor(sent1)
sent2 = torch.tensor(sent2)
sent1 = torch.tensor(sent1, dtype=torch.float64)
sent2 = torch.tensor(sent2, dtype=torch.float64)
f = torch.cat([sent1, sent2, torch.abs(sent1 - sent2), sent1 * sent2], -1)
features.append(PytorchWrapper._tensor_to_numpy(f))
all_features = torch.tensor(features).float()
Expand Down

0 comments on commit a95bded

Please sign in to comment.