From 688a9228a820c419d9548ea2b44a6e4fe0a2cc1e Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 11 Apr 2019 15:38:35 -0700 Subject: [PATCH] fix. (#491) --- examples/mxnet/tree_lstm/train.py | 6 +++--- examples/mxnet/tree_lstm/tree_lstm.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/mxnet/tree_lstm/train.py b/examples/mxnet/tree_lstm/train.py index 02ca1322b391..bfe72b97f4f5 100644 --- a/examples/mxnet/tree_lstm/train.py +++ b/examples/mxnet/tree_lstm/train.py @@ -122,7 +122,7 @@ def main(args): dur.append(time.time() - t0) # tok if step > 0 and step % args.log_every == 0: - pred = pred.argmax(axis=1) + pred = pred.argmax(axis=1).astype(batch.label.dtype) acc = (batch.label == pred).sum() root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0] root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]) @@ -139,7 +139,7 @@ def main(args): n = g.number_of_nodes() h = mx.nd.zeros((n, args.h_size), ctx=ctx) c = mx.nd.zeros((n, args.h_size), ctx=ctx) - pred = model(batch, h, c).argmax(1) + pred = model(batch, h, c).argmax(1).astype(batch.label.dtype) acc = (batch.label == pred).sum().asscalar() accs.append([acc, len(batch.label)]) @@ -175,7 +175,7 @@ def main(args): n = g.number_of_nodes() h = mx.nd.zeros((n, args.h_size), ctx=ctx) c = mx.nd.zeros((n, args.h_size), ctx=ctx) - pred = model(batch, h, c).argmax(axis=1) + pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype) acc = (batch.label == pred).sum().asscalar() accs.append([acc, len(batch.label)]) diff --git a/examples/mxnet/tree_lstm/tree_lstm.py b/examples/mxnet/tree_lstm/tree_lstm.py index f4c78d6bac71..523537095be1 100644 --- a/examples/mxnet/tree_lstm/tree_lstm.py +++ b/examples/mxnet/tree_lstm/tree_lstm.py @@ -118,7 +118,8 @@ def forward(self, batch, h, c): g.register_apply_node_func(self.cell.apply_node_func) # feed embedding embeds = self.embedding(batch.wordid * batch.mask) - g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1) + wiou = self.cell.W_iou(self.dropout(embeds)) + g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype) g.ndata['h'] = h g.ndata['c'] = c # propagate