From 40caf1ab50f641bc8c3d078c4e2daa4207bc727f Mon Sep 17 00:00:00 2001 From: Mufei Li Date: Thu, 24 Sep 2020 08:54:28 +0800 Subject: [PATCH] [HeteroGraph] Fix the failure of apply_nodes when the input function changes feature size for all nodes (#2223) * Fix * Fix --- python/dgl/heterograph.py | 6 +++--- tests/compute/test_heterograph.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/dgl/heterograph.py b/python/dgl/heterograph.py index 5ff1f4b1eee7..5f846fd2865c 100644 --- a/python/dgl/heterograph.py +++ b/python/dgl/heterograph.py @@ -4016,10 +4016,10 @@ def apply_nodes(self, func, v=ALL, ntype=None, inplace=False): ntid = self.get_ntype_id(ntype) ntype = self.ntypes[ntid] if is_all(v): - v = self.nodes(ntype) + v_id = self.nodes(ntype) else: - v = utils.prepare_tensor(self, v, 'v') - ndata = core.invoke_node_udf(self, v, ntype, func, orig_nid=v) + v_id = utils.prepare_tensor(self, v, 'v') + ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id) self._set_n_repr(ntid, v, ndata) def apply_edges(self, func, edges=ALL, etype=None, inplace=False): diff --git a/tests/compute/test_heterograph.py b/tests/compute/test_heterograph.py index 70b462638061..19f2c2758abe 100644 --- a/tests/compute/test_heterograph.py +++ b/tests/compute/test_heterograph.py @@ -1295,6 +1295,8 @@ def _check_typed_subgraph2(g, sg): def test_apply(idtype): def node_udf(nodes): return {'h': nodes.data['h'] * 2} + def node_udf2(nodes): + return {'h': F.sum(nodes.data['h'], dim=1, keepdims=True)} def edge_udf(edges): return {'h': edges.data['h'] * 2 + edges.src['h']} @@ -1314,6 +1316,11 @@ def edge_udf(edges): g['plays'].apply_edges(edge_udf) assert F.array_equal(g['plays'].edata['h'], F.ones((4, 5)) * 12) + # Test the case that feature size changes + g.nodes['user'].data['h'] = F.ones((3, 5)) + g.apply_nodes(node_udf2, ntype='user') + assert F.array_equal(g.nodes['user'].data['h'], F.ones((3, 1)) * 5) + # test fail case # fail due to multiple types with pytest.raises(DGLError):