Commit

New pytorch + embedding fix
Charles-Emmanuel Dias authored and Charles-Emmanuel Dias committed Oct 10, 2017
1 parent 7e99c81 commit 87ab47c
Showing 2 changed files with 17 additions and 9 deletions.
Nets.py (4 changes: 2 additions & 2 deletions)

@@ -32,7 +32,7 @@ def forward(self, packed_batch):
        all_att = self._masked_softmax(attend,self._list_to_bytemask(list(len_s))).transpose(0,1) # attW,sent
        attended = all_att.unsqueeze(2).expand_as(enc_sents) * enc_sents

-       return attended.sum(0).squeeze(0)
+       return attended.sum(0,True).squeeze(0)

    def _list_to_bytemask(self,l):
        mask = self._buffers['mask'].resize_(len(l),l[0]).fill_(1)
@@ -45,7 +45,7 @@ def _list_to_bytemask(self,l):

    def _masked_softmax(self,mat,mask):
        exp = torch.exp(mat) * Variable(mask)
-       sum_exp = exp.sum(1)+0.0001
+       sum_exp = exp.sum(1,True)+0.0001

        return exp/sum_exp.expand_as(exp)

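Both sum(..., True) edits above track the PyTorch 0.2 API break that the commit message refers to: reductions such as Tensor.sum(dim) stopped keeping the reduced dimension by default, which silently changes the shapes fed to expand_as and squeeze. Passing True as the second positional argument (keepdim) restores the old shape. A minimal sketch of the difference, assuming PyTorch >= 0.2:

    import torch

    x = torch.ones(4, 5)
    print(x.sum(1).size())        # torch.Size([4]): the reduced dim is dropped by default
    print(x.sum(1, True).size())  # torch.Size([4, 1]): keepdim=True preserves it

    # _masked_softmax relies on the kept dim: a (rows, 1) sum can be
    # expand_as'd back to (rows, cols) to normalize each row.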
main.py (22 changes: 15 additions & 7 deletions)

@@ -154,7 +154,7 @@ def test(epoch,net,dataset,cuda):
def accuracy(out,truth):
    def sm(mat):
        exp = torch.exp(mat)
-       sum_exp = exp.sum(1)+0.0001
+       sum_exp = exp.sum(1,True)+0.0001
        return exp/sum_exp.expand_as(exp)

    _,max_i = torch.max(sm(out),1)
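Since softmax is strictly monotonic, the sm normalization does not change which index wins the max; the same keepdim fix is applied here anyway so sm stays shape-correct. A leaner equivalent of the prediction step, sketched under the assumption that truth holds class indices:

    # argmax over raw logits picks the same class as argmax over softmax(logits)
    _, max_i = torch.max(out, 1)
    acc = (max_i == truth).float().mean()  # fraction of correct predictions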
@@ -206,7 +206,16 @@ def main(args):
    print(25*"-" + "\nBuilding word vectors: \n"+"-"*25)

    vectorizer = Vectorizer(max_word_len=args.max_words,max_sent_len=args.max_sents)
-   vectorizer.build_dict(train_set.field_iter(0),args.max_feat)

+   if args.emb:
+       tensor,dic = load_embeddings(args.emb)
+       net = HierarchicalDoc(ntoken=len(dic),num_class=num_class)
+       net.set_emb_tensor(torch.FloatTensor(tensor))
+       vectorizer.word_dict = dic
+   else:
+       vectorizer.build_dict(train_set.field_iter(0),args.max_feat)
+       net = HierarchicalDoc(ntoken=len(vectorizer.word_dict),num_class=num_class)


    tuple_batch = tuple_batcher_builder(vectorizer,train=True)
    tuple_batch_test = tuple_batcher_builder(vectorizer,train=False)
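load_embeddings itself is not part of this diff; from its call site it returns a float matrix plus a word-to-index dict, which then seeds both the embedding layer (via set_emb_tensor) and the vectorizer, so token ids line up with embedding rows. A hypothetical sketch of such a loader for text-format vectors, one word plus its values per line; the repo's real helper may differ:

    import numpy as np

    def load_embeddings_sketch(path):
        word_dict, vectors = {}, []
        with open(path, encoding="utf-8") as f:
            for line in f:
                word, *values = line.rstrip().split(" ")
                word_dict[word] = len(word_dict)   # index words in file order
                vectors.append([float(v) for v in values])
        return np.asarray(vectors, dtype="float32"), word_dict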
@@ -216,20 +225,19 @@ def main(args):
    sampler = None
    if args.balance:
        sampler = BucketSampler(train_set)
+       sampler_t = BucketSampler(test_set)


-   dataloader = DataLoader(train_set, batch_size=args.b_size, shuffle=True, sampler=sampler, num_workers=1, collate_fn=tuple_batch)#, collate_fn=<function default_collate>, pin_memory=False)
-   dataloader_test = DataLoader(test_set, batch_size=args.b_size, shuffle=True, num_workers=1, collate_fn=tuple_batch_test)#, collate_fn=<function default_collate>, pin_memory=False)
+   dataloader = DataLoader(train_set, batch_size=args.b_size, shuffle=False, sampler=sampler, num_workers=1, collate_fn=tuple_batch)#, collate_fn=<function default_collate>, pin_memory=False)
+   dataloader_test = DataLoader(test_set, batch_size=args.b_size, shuffle=False, sampler=sampler_t, num_workers=1, collate_fn=tuple_batch_test)#, collate_fn=<function default_collate>, pin_memory=False)


    criterion = torch.nn.CrossEntropyLoss()



-   net = HierarchicalDoc(ntoken=len(vectorizer.word_dict),num_class=num_class)
-
-   if args.emb:
-       net.set_emb_tensor(torch.FloatTensor(tensor))


    if args.cuda:
        net.cuda()
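The shuffle=True to shuffle=False flips pair with the new sampler arguments: in current PyTorch, DataLoader treats sampler and shuffle as mutually exclusive (a sampler already fixes the iteration order) and raises a ValueError when both are set, while sampler=None with shuffle=False falls back to sequential order. One way to keep random order in the no-sampler case, sketched with the same names as the diff above:

    from torch.utils.data import DataLoader

    # shuffle only when no custom sampler is supplied; DataLoader rejects
    # shuffle=True together with a non-None sampler.
    dataloader = DataLoader(train_set, batch_size=args.b_size,
                            shuffle=(sampler is None), sampler=sampler,
                            num_workers=1, collate_fn=tuple_batch)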
