diff --git a/snli/train.py b/snli/train.py index e64eb540f5..5b518a6ec8 100644 --- a/snli/train.py +++ b/snli/train.py @@ -26,7 +26,7 @@ if os.path.isfile(args.vector_cache): inputs.vocab.vectors = torch.load(args.vector_cache) else: - inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed) + inputs.vocab.load_vectors(args.word_vectors) makedirs(os.path.dirname(args.vector_cache)) torch.save(inputs.vocab.vectors, args.vector_cache) answers.build_vocab(train) diff --git a/snli/util.py b/snli/util.py index 509fdbdf1f..d1e23feaf2 100644 --- a/snli/util.py +++ b/snli/util.py @@ -22,7 +22,7 @@ def get_args(): parser = ArgumentParser(description='PyTorch/torchtext SNLI example') parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--batch_size', type=int, default=128) - parser.add_argument('--d_embed', type=int, default=300) + parser.add_argument('--d_embed', type=int, default=100) parser.add_argument('--d_proj', type=int, default=300) parser.add_argument('--d_hidden', type=int, default=300) parser.add_argument('--n_layers', type=int, default=1) @@ -37,9 +37,8 @@ def get_args(): parser.add_argument('--train_embed', action='store_false', dest='fix_emb') parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--save_path', type=str, default='results') - parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache')) parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt')) - parser.add_argument('--word_vectors', type=str, default='glove.42B') + parser.add_argument('--word_vectors', type=str, default='glove.6B.100d') parser.add_argument('--resume_snapshot', type=str, default='') args = parser.parse_args() return args