Skip to content

Commit

Permalink
adding glove vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann committed Jan 24, 2017
1 parent 623cd7f commit 92d566a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 3 additions & 1 deletion snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

train, val, test = datasets.SNLI.splits(inputs, answers)

inputs.build_vocab(train)
inputs.build_vocab(train, vectors=(args.data_cache, args.word_vectors, args.d_embed))
answers.build_vocab(train)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
Expand All @@ -33,6 +33,8 @@
config.n_cells *= 2

model = SNLIClassifier(config)
if args.wv_path:
model.embed.weight = inputs.vocab.vectors
model.cuda()
criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=args.lr)
Expand Down
3 changes: 3 additions & 0 deletions snli/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser

def get_args():
Expand All @@ -15,5 +16,7 @@ def get_args():
parser.add_argument('--bidirectional', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache'))
parser.add_argument('--word_vectors', type=str, default='glove.42B')
args = parser.parse_args()
return args

0 comments on commit 92d566a

Please sign in to comment.