Skip to content

Commit

Permalink
[HeteroGraph] Fix the failure of apply_nodes when the input function …
Browse files Browse the repository at this point in the history
…changes feature size for all nodes (dmlc#2223)

* Fix

* Fix
  • Loading branch information
mufeili authored Sep 24, 2020
1 parent 7e0107c commit 40caf1a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4016,10 +4016,10 @@ def apply_nodes(self, func, v=ALL, ntype=None, inplace=False):
ntid = self.get_ntype_id(ntype)
ntype = self.ntypes[ntid]
if is_all(v):
v = self.nodes(ntype)
v_id = self.nodes(ntype)
else:
v = utils.prepare_tensor(self, v, 'v')
ndata = core.invoke_node_udf(self, v, ntype, func, orig_nid=v)
v_id = utils.prepare_tensor(self, v, 'v')
ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
self._set_n_repr(ntid, v, ndata)

def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
Expand Down
7 changes: 7 additions & 0 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,8 @@ def _check_typed_subgraph2(g, sg):
def test_apply(idtype):
def node_udf(nodes):
return {'h': nodes.data['h'] * 2}
def node_udf2(nodes):
return {'h': F.sum(nodes.data['h'], dim=1, keepdims=True)}
def edge_udf(edges):
return {'h': edges.data['h'] * 2 + edges.src['h']}

Expand All @@ -1314,6 +1316,11 @@ def edge_udf(edges):
g['plays'].apply_edges(edge_udf)
assert F.array_equal(g['plays'].edata['h'], F.ones((4, 5)) * 12)

# Test the case that feature size changes
g.nodes['user'].data['h'] = F.ones((3, 5))
g.apply_nodes(node_udf2, ntype='user')
assert F.array_equal(g.nodes['user'].data['h'], F.ones((3, 1)) * 5)

# test fail case
# fail due to multiple types
with pytest.raises(DGLError):
Expand Down

0 comments on commit 40caf1a

Please sign in to comment.