
Commit fcfe52a

fix residual (dmlc#2962)
BarclayII authored Jun 1, 2021
1 parent 2ad7a9e commit fcfe52a
Showing 8 changed files with 54 additions and 35 deletions.
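
This commit repairs the residual connection of GATConv on block (message-flow) graphs in all three backends (MXNet, PyTorch, TensorFlow), plus the analogous path in SAGEConv's 'gcn' aggregator. On a block the destination side has fewer nodes than the source side: the attention output has number_of_dst_nodes() rows, but the residual branch was still built from the unsliced h_dst and the source-sized dst_prefix_shape, so the residual addition had mismatched shapes. The hunks below slice both down to the destination side. Here is a minimal sketch of the failure case, assuming the PyTorch backend and a block built with dgl.to_block (the graph mirrors the block_graph0 test case changed at the end of this commit; the exact helper may differ):

import dgl
import torch as th
from dgl.nn.pytorch import GATConv

# Build a block whose source side (destination nodes plus their in-neighbors)
# is strictly larger than its destination side.
g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
block = dgl.to_block(g, dst_nodes=th.tensor([5, 6, 7]))
assert block.number_of_src_nodes() > block.number_of_dst_nodes()

gat = GATConv(10, 4, num_heads=2, residual=True)
feat = th.randn(block.number_of_src_nodes(), 10)

# Before this fix the residual term kept number_of_src_nodes() rows while the
# attention output had number_of_dst_nodes() rows, so this call raised a
# shape error; afterwards it returns one row per destination node.
h = gat(block, feat)
assert h.shape == (block.number_of_dst_nodes(), 2, 4)
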
python/dgl/nn/mxnet/conv/gatconv.py: 2 additions & 0 deletions
@@ -269,6 +269,8 @@ def forward(self, graph, feat, get_attention=False):
                     *src_prefix_shape, self._num_heads, self._out_feats)
                 if graph.is_block:
                     feat_dst = feat_src[:graph.number_of_dst_nodes()]
+                    h_dst = h_dst[:graph.number_of_dst_nodes()]
+                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
             # NOTE: GAT paper uses "first concatenation then linear projection"
             # to compute attention scores, while ours is "first projection then
             # addition", the two approaches are mathematically equivalent:
python/dgl/nn/pytorch/conv/gatconv.py: 2 additions & 0 deletions
@@ -287,6 +287,8 @@ def forward(self, graph, feat, get_attention=False):
                     *src_prefix_shape, self._num_heads, self._out_feats)
                 if graph.is_block:
                     feat_dst = feat_src[:graph.number_of_dst_nodes()]
+                    h_dst = h_dst[:graph.number_of_dst_nodes()]
+                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
             # NOTE: GAT paper uses "first concatenation then linear projection"
             # to compute attention scores, while ours is "first projection then
             # addition", the two approaches are mathematically equivalent:
python/dgl/nn/pytorch/conv/sageconv.py: 4 additions & 1 deletion
@@ -238,7 +238,10 @@ def forward(self, graph, feat, edge_weight=None):
                 if isinstance(feat, tuple): # heterogeneous
                     graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                 else:
-                    graph.dstdata['h'] = graph.srcdata['h']
+                    if graph.is_block:
+                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
+                    else:
+                        graph.dstdata['h'] = graph.srcdata['h']
                 graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                 # divide in_degrees
                 degs = graph.in_degrees().to(feat_dst)
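
The SAGEConv hunk above has the same root cause: with the 'gcn' aggregator and a single input tensor, dstdata['h'] on a block must be the leading num_dst_nodes() rows of srcdata['h'], not the whole tensor. A hedged sketch under the same assumptions as the example near the top:

import dgl
import torch as th
from dgl.nn.pytorch import SAGEConv

g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
block = dgl.to_block(g, dst_nodes=th.tensor([5, 6, 7]))

sage = SAGEConv(5, 8, 'gcn')
feat = th.randn(block.number_of_src_nodes(), 5)

# The 'gcn' branch now copies only the destination rows of srcdata['h'] into
# dstdata['h'] before message passing; previously the assignment failed on
# blocks because the source and destination node sets have different sizes.
h = sage(block, feat)
assert h.shape == (block.number_of_dst_nodes(), 8)
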
python/dgl/nn/tensorflow/conv/gatconv.py: 2 additions & 0 deletions
@@ -263,6 +263,8 @@ def call(self, graph, feat, get_attention=False):
                     self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats))
                 if graph.is_block:
                     feat_dst = feat_src[:graph.number_of_dst_nodes()]
+                    h_dst = h_dst[:graph.number_of_dst_nodes()]
+                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
             # NOTE: GAT paper uses "first concatenation then linear projection"
             # to compute attention scores, while ours is "first projection then
             # addition", the two approaches are mathematically equivalent:
tests/mxnet/test_nn.py: 20 additions & 14 deletions
@@ -90,7 +90,8 @@ def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
     conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
     conv.initialize(ctx=F.ctx())
     ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())
-    nsrc = ndst = g.number_of_nodes()
+    nsrc = g.number_of_src_nodes()
+    ndst = g.number_of_dst_nodes()
     h = F.randn((nsrc, 5)).as_in_context(F.ctx())
     if weight:
         h_out = conv(g, h)
@@ -170,12 +171,17 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     gat = nn.GATConv(10, out_dim, num_heads) # n_heads = 5
     gat.initialize(ctx=ctx)
     print(gat)
-    feat = F.randn((g.number_of_nodes(), 10))
+    feat = F.randn((g.number_of_src_nodes(), 10))
     h = gat(g, feat)
-    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)

+    # test residual connection
+    gat = nn.GATConv(10, out_dim, num_heads, residual=True)
+    gat.initialize(ctx=ctx)
+    h = gat(g, feat)
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
 @pytest.mark.parametrize('out_dim', [1, 2])
@@ -199,7 +205,7 @@ def test_sage_conv(idtype, g, aggre_type, out_dim):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     sage = nn.SAGEConv(5, out_dim, aggre_type)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     sage.initialize(ctx=ctx)
     h = sage(g, feat)
     assert h.shape[-1] == out_dim
@@ -277,9 +283,9 @@ def test_agnn_conv(g, idtype):
     agnn_conv = nn.AGNNConv(0.1, True)
     agnn_conv.initialize(ctx=ctx)
     print(agnn_conv)
-    feat = F.randn((g.number_of_nodes(), 10))
+    feat = F.randn((g.number_of_src_nodes(), 10))
     h = agnn_conv(g, feat)
-    assert h.shape == (g.number_of_nodes(), 10)
+    assert h.shape == (g.number_of_dst_nodes(), 10)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -387,9 +393,9 @@ def test_edge_conv(g, idtype, out_dim):
     edge_conv.initialize(ctx=ctx)
     print(edge_conv)
     # test #1: basic
-    h0 = F.randn((g.number_of_nodes(), 5))
+    h0 = F.randn((g.number_of_src_nodes(), 5))
     h1 = edge_conv(g, h0)
-    assert h1.shape == (g.number_of_nodes(), out_dim)
+    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -418,9 +424,9 @@ def test_gin_conv(g, idtype, aggregator_type):
     print(gin_conv)

     # test #1: basic
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     h = gin_conv(g, feat)
-    assert h.shape == (g.number_of_nodes(), 5)
+    assert h.shape == (g.number_of_dst_nodes(), 5)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite']))
@@ -446,10 +452,10 @@ def test_gmm_conv(g, idtype):
     ctx = F.ctx()
     gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
     gmm_conv.initialize(ctx=ctx)
-    h0 = F.randn((g.number_of_nodes(), 5))
+    h0 = F.randn((g.number_of_src_nodes(), 5))
     pseudo = F.randn((g.number_of_edges(), 5))
     h1 = gmm_conv(g, h0, pseudo)
-    assert h1.shape == (g.number_of_nodes(), 2)
+    assert h1.shape == (g.number_of_dst_nodes(), 2)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -473,10 +479,10 @@ def test_nn_conv(g, idtype):
     nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
     nn_conv.initialize(ctx=ctx)
     # test #1: basic
-    h0 = F.randn((g.number_of_nodes(), 5))
+    h0 = F.randn((g.number_of_src_nodes(), 5))
     etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
     h1 = nn_conv(g, h0, etypes)
-    assert h1.shape == (g.number_of_nodes(), 2)
+    assert h1.shape == (g.number_of_dst_nodes(), 2)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite']))
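
All test updates in this file, and in the PyTorch and TensorFlow test files below, follow one pattern: input features are allocated per source node and output shapes are asserted per destination node. A short sketch of why this is a strict generalization of the old tests (PyTorch backend assumed, as above):

import dgl
import torch as th

# On an ordinary graph the two counts coincide with number_of_nodes(), so the
# rewritten tests are equivalent to the old ones there.
g = dgl.graph(([0, 1], [1, 2]))
assert g.number_of_src_nodes() == g.number_of_dst_nodes() == g.number_of_nodes()

# On a block they differ, which is the case the new assertions exercise.
block = dgl.to_block(g, dst_nodes=th.tensor([1, 2]))
assert block.number_of_src_nodes() >= block.number_of_dst_nodes()
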
tests/pytorch/test_nn.py: 12 additions & 12 deletions
@@ -533,14 +533,14 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     gat = nn.GATConv(5, out_dim, num_heads)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     gat = gat.to(ctx)
     h = gat(g, feat)

     # test pickle
     th.save(gat, tmp_buffer)

-    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)

@@ -570,7 +570,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
 def test_sage_conv(idtype, g, aggre_type):
     g = g.astype(idtype).to(F.ctx())
     sage = nn.SAGEConv(5, 10, aggre_type)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     sage = sage.to(F.ctx())
     # test pickle
     th.save(sage, tmp_buffer)
@@ -664,14 +664,14 @@ def test_gin_conv(g, idtype, aggregator_type):
         th.nn.Linear(5, 12),
         aggregator_type
     )
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     gin = gin.to(ctx)
     h = gin(g, feat)

     # test pickle
     th.save(h, tmp_buffer)

-    assert h.shape == (g.number_of_nodes(), 12)
+    assert h.shape == (g.number_of_dst_nodes(), 12)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -694,10 +694,10 @@ def test_agnn_conv(g, idtype):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     agnn = nn.AGNNConv(1)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     agnn = agnn.to(ctx)
     h = agnn(g, feat)
-    assert h.shape == (g.number_of_nodes(), 5)
+    assert h.shape == (g.number_of_dst_nodes(), 5)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -732,7 +732,7 @@ def test_nn_conv(g, idtype):
     ctx = F.ctx()
     edge_func = th.nn.Linear(4, 5 * 10)
     nnconv = nn.NNConv(5, 10, edge_func, 'mean')
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     efeat = F.randn((g.number_of_edges(), 4))
     nnconv = nnconv.to(ctx)
     h = nnconv(g, feat, efeat)
@@ -837,9 +837,9 @@ def test_edge_conv(g, idtype, out_dim):
     # test pickle
     th.save(edge_conv, tmp_buffer)

-    h0 = F.randn((g.number_of_nodes(), 5))
+    h0 = F.randn((g.number_of_src_nodes(), 5))
     h1 = edge_conv(g, h0)
-    assert h1.shape == (g.number_of_nodes(), out_dim)
+    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@@ -862,14 +862,14 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     dotgat = nn.DotGatConv(5, out_dim, num_heads)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     dotgat = dotgat.to(ctx)

     # test pickle
     th.save(dotgat, tmp_buffer)

     h = dotgat(g, feat)
-    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
     _, a = dotgat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)

tests/tensorflow/test_nn.py: 11 additions & 7 deletions
@@ -270,12 +270,16 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
     gat = nn.GATConv(5, out_dim, num_heads)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     h = gat(g, feat)
-    assert h.shape == (g.number_of_nodes(), num_heads, out_dim)
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)

+    # test residual connection
+    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
+    h = gat(g, feat)
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
 @pytest.mark.parametrize('out_dim', [1, 2])
@@ -297,7 +301,7 @@ def test_sage_conv(idtype, g, aggre_type, out_dim):
 def test_sage_conv(idtype, g, aggre_type, out_dim):
     g = g.astype(idtype).to(F.ctx())
     sage = nn.SAGEConv(5, out_dim, aggre_type)
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     h = sage(g, feat)
     assert h.shape[-1] == out_dim

@@ -374,9 +378,9 @@ def test_gin_conv(g, idtype, aggregator_type):
         tf.keras.layers.Dense(12),
         aggregator_type
     )
-    feat = F.randn((g.number_of_nodes(), 5))
+    feat = F.randn((g.number_of_src_nodes(), 5))
     h = gin(g, feat)
-    assert h.shape == (g.number_of_nodes(), 12)
+    assert h.shape == (g.number_of_dst_nodes(), 12)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite']))
@@ -398,9 +402,9 @@ def test_edge_conv(g, idtype, out_dim):
     g = g.astype(idtype).to(F.ctx())
     edge_conv = nn.EdgeConv(out_dim)

-    h0 = F.randn((g.number_of_nodes(), 5))
+    h0 = F.randn((g.number_of_src_nodes(), 5))
     h1 = edge_conv(g, h0)
-    assert h1.shape == (g.number_of_nodes(), out_dim)
+    assert h1.shape == (g.number_of_dst_nodes(), out_dim)

 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
tests/test_utils/graph_cases.py: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def batched_graph0():
     g3 = dgl.add_self_loop(dgl.graph(([0], [1])))
     return dgl.batch([g1, g2, g3])

-@register_case(['block', 'bipartite', 'block-biparitite'])
+@register_case(['block', 'bipartite', 'block-bipartite'])
 def block_graph0():
     g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)
     g = g.to(F.cpu())
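
The last hunk only corrects the misspelled case tag 'block-biparitite' to 'block-bipartite'. Since get_cases selects graph cases by tag string, block_graph0 was unreachable under the intended 'block-bipartite' name; with the tag fixed, the block case actually runs in the parametrized tests, which is presumably what exercises the GATConv and SAGEConv block fixes above.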
