Skip to content

Commit

Permalink
bug fix: vocab.load_vectors signature update
Browse files Browse the repository at this point in the history
  • Loading branch information
sivareddyg authored and soumith committed Oct 26, 2017
1 parent 9a02f2a commit 23f8abf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if os.path.isfile(args.vector_cache):
inputs.vocab.vectors = torch.load(args.vector_cache)
else:
inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed)
inputs.vocab.load_vectors(args.word_vectors)
makedirs(os.path.dirname(args.vector_cache))
torch.save(inputs.vocab.vectors, args.vector_cache)
answers.build_vocab(train)
Expand Down
5 changes: 2 additions & 3 deletions snli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_args():
parser = ArgumentParser(description='PyTorch/torchtext SNLI example')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--d_embed', type=int, default=300)
parser.add_argument('--d_embed', type=int, default=100)
parser.add_argument('--d_proj', type=int, default=300)
parser.add_argument('--d_hidden', type=int, default=300)
parser.add_argument('--n_layers', type=int, default=1)
Expand All @@ -37,9 +37,8 @@ def get_args():
parser.add_argument('--train_embed', action='store_false', dest='fix_emb')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--save_path', type=str, default='results')
parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache'))
parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt'))
parser.add_argument('--word_vectors', type=str, default='glove.42B')
parser.add_argument('--word_vectors', type=str, default='glove.6B.100d')
parser.add_argument('--resume_snapshot', type=str, default='')
args = parser.parse_args()
return args

0 comments on commit 23f8abf

Please sign in to comment.