Skip to content

Commit

Permalink
Fixed PyTorch 0.4.0 changes (changes torch.cuda.float tensors to float)
Browse files Browse the repository at this point in the history
  • Loading branch information
prabhakar267 committed May 18, 2018
1 parent 4f60395 commit 3d1fbde
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/dico_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def build_dictionary(src_emb, tgt_emb, params, s2t_candidates=None, t2s_candidat
elif params.dico_build == 'T2S':
dico = t2s_candidates
else:
s2t_candidates = s2t_candidates.numpy() if isinstance(s2t_candidates, torch.Tensor) else s2t_candidates
t2s_candidates = t2s_candidates.numpy() if isinstance(t2s_candidates, torch.Tensor) else t2s_candidates
s2t_candidates = set([(a, b) for a, b in s2t_candidates])
t2s_candidates = set([(a, b) for a, b in t2s_candidates])
if params.dico_build == 'S2T|T2S':
Expand Down
2 changes: 2 additions & 0 deletions src/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from copy import deepcopy
import numpy as np
from torch.autograd import Variable
from torch import Tensor as torch_tensor

from . import get_wordsim_scores, get_crosslingual_wordsim_scores, get_wordanalogy_scores
from . import get_word_translation_accuracy
Expand Down Expand Up @@ -202,6 +203,7 @@ def dist_mean_cosine(self, to_log):
mean_cosine = -1e9
else:
mean_cosine = (src_emb[dico[:dico_max_size, 0]] * tgt_emb[dico[:dico_max_size, 1]]).sum(1).mean()
mean_cosine = mean_cosine.item() if isinstance(mean_cosine, torch_tensor) else mean_cosine
logger.info("Mean cosine (%s method, %s build, %i max size): %.5f"
% (dico_method, _params.dico_build, dico_max_size, mean_cosine))
to_log['mean_cosine-%s-%s-%i' % (dico_method, _params.dico_build, dico_max_size)] = mean_cosine
Expand Down

0 comments on commit 3d1fbde

Please sign in to comment.