Skip to content

Commit

Permalink
[Bug] Heterogeneous graph convolution bugfix (dmlc#2578)
Browse files Browse the repository at this point in the history
* fix heterograph conv

* remove test cases

* fix test

* fix test

* fix test

Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
BarclayII and jermainewang authored Jan 29, 2021
1 parent e4ddafe commit 1f6eba9
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 94 deletions.
4 changes: 2 additions & 2 deletions python/dgl/nn/mxnet/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class HeteroGraphConv(nn.Block):
``'user'`` and ``'game'`` nodes.
>>> import mxnet.ndarray as nd
>>> h1 = {'user' : nd.randomrandn(g.number_of_nodes('user'), 5)}
>>> h1 = {'user' : nd.random.randn(g.number_of_nodes('user'), 5)}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])
Expand Down Expand Up @@ -167,7 +167,7 @@ def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/tensorflow/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def call(self, g, inputs, mod_args=None, mod_kwargs=None):
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
Expand Down
38 changes: 7 additions & 31 deletions tests/mxnet/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,19 +707,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))

h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
Expand All @@ -728,37 +727,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)

# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv.initialize(ctx=F.ctx())

h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)

# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

# test with mod args
class MyMod(mx.gluon.nn.Block):
Expand All @@ -781,7 +757,7 @@ def forward(self, g, h, arg1=None): # mxnet does not support kwargs
agg)
conv.initialize(ctx=F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args)
h = conv(g, {'user' : uf, 'store' : sf, 'game': gf}, mod_args)
assert mod1.carg1 == 1
assert mod2.carg1 == 1
assert mod3.carg1 == 0
Expand Down
36 changes: 7 additions & 29 deletions tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,16 +939,17 @@ def test_hetero_conv(agg, idtype):
gf = F.randn((4, 4))
sf = F.randn((2, 3))

h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
Expand All @@ -957,37 +958,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)

# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv = conv.to(F.ctx())

h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)

# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

# test with mod args
class MyMod(th.nn.Module):
Expand All @@ -1014,7 +992,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
conv = conv.to(F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
Expand Down
37 changes: 7 additions & 30 deletions tests/tensorflow/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,19 +401,18 @@ def test_hetero_conv(agg, idtype):
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))

h = conv(g, {'user': uf})
h = conv(g, {'user': uf, 'store': sf, 'game': gf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'user': uf, 'store': sf})
block = dgl.to_block(g.to(F.cpu()), {'user': [0, 1, 2, 3], 'game': [0, 1, 2, 3], 'store': []}).to(F.ctx())
h = conv(block, ({'user': uf, 'game': gf, 'store': sf}, {'user': uf, 'game': gf, 'store': sf[0:0]}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
Expand All @@ -422,36 +421,14 @@ def test_hetero_conv(agg, idtype):
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)

h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)

# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)

h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
h = conv(block, {'user': uf, 'game': gf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)

# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
assert h['game'].shape == (4, 2, 4)

# test with mod args
class MyMod(tf.keras.layers.Layer):
Expand All @@ -477,7 +454,7 @@ def call(self, g, h, arg1=None, *, arg2=None):
agg)
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
h = conv(g, {'user' : uf, 'game': gf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
Expand Down

0 comments on commit 1f6eba9

Please sign in to comment.