Skip to content

Commit

Permalink
fix. (dmlc#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da authored and szha committed Apr 11, 2019
1 parent e2e0432 commit 688a922
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions examples/mxnet/tree_lstm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def main(args):
dur.append(time.time() - t0) # tok

if step > 0 and step % args.log_every == 0:
pred = pred.argmax(axis=1)
pred = pred.argmax(axis=1).astype(batch.label.dtype)
acc = (batch.label == pred).sum()
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
Expand All @@ -139,7 +139,7 @@ def main(args):
n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(1)
pred = model(batch, h, c).argmax(1).astype(batch.label.dtype)

acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)])
Expand Down Expand Up @@ -175,7 +175,7 @@ def main(args):
n = g.number_of_nodes()
h = mx.nd.zeros((n, args.h_size), ctx=ctx)
c = mx.nd.zeros((n, args.h_size), ctx=ctx)
pred = model(batch, h, c).argmax(axis=1)
pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)

acc = (batch.label == pred).sum().asscalar()
accs.append([acc, len(batch.label)])
Expand Down
3 changes: 2 additions & 1 deletion examples/mxnet/tree_lstm/tree_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def forward(self, batch, h, c):
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1)
wiou = self.cell.W_iou(self.dropout(embeds))
g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
Expand Down

0 comments on commit 688a922

Please sign in to comment.