diff --git a/python/dgl/nn/mxnet/conv/gatconv.py b/python/dgl/nn/mxnet/conv/gatconv.py index 81e31f8199ec..6d86e0de9c61 100644 --- a/python/dgl/nn/mxnet/conv/gatconv.py +++ b/python/dgl/nn/mxnet/conv/gatconv.py @@ -269,6 +269,8 @@ def forward(self, graph, feat, get_attention=False): *src_prefix_shape, self._num_heads, self._out_feats) if graph.is_block: feat_dst = feat_src[:graph.number_of_dst_nodes()] + h_dst = h_dst[:graph.number_of_dst_nodes()] + dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] # NOTE: GAT paper uses "first concatenation then linear projection" # to compute attention scores, while ours is "first projection then # addition", the two approaches are mathematically equivalent: diff --git a/python/dgl/nn/pytorch/conv/gatconv.py b/python/dgl/nn/pytorch/conv/gatconv.py index 47e8936c45da..528045eda152 100644 --- a/python/dgl/nn/pytorch/conv/gatconv.py +++ b/python/dgl/nn/pytorch/conv/gatconv.py @@ -287,6 +287,8 @@ def forward(self, graph, feat, get_attention=False): *src_prefix_shape, self._num_heads, self._out_feats) if graph.is_block: feat_dst = feat_src[:graph.number_of_dst_nodes()] + h_dst = h_dst[:graph.number_of_dst_nodes()] + dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] # NOTE: GAT paper uses "first concatenation then linear projection" # to compute attention scores, while ours is "first projection then # addition", the two approaches are mathematically equivalent: diff --git a/python/dgl/nn/pytorch/conv/sageconv.py b/python/dgl/nn/pytorch/conv/sageconv.py index 4a92bc0716da..0899cfce04db 100644 --- a/python/dgl/nn/pytorch/conv/sageconv.py +++ b/python/dgl/nn/pytorch/conv/sageconv.py @@ -238,7 +238,10 @@ def forward(self, graph, feat, edge_weight=None): if isinstance(feat, tuple): # heterogeneous graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst else: - graph.dstdata['h'] = graph.srcdata['h'] + if graph.is_block: + graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()] + else: + graph.dstdata['h'] = graph.srcdata['h'] graph.update_all(msg_fn, fn.sum('m', 'neigh')) # divide in_degrees degs = graph.in_degrees().to(feat_dst) diff --git a/python/dgl/nn/tensorflow/conv/gatconv.py b/python/dgl/nn/tensorflow/conv/gatconv.py index 1b6500dc2f11..45f099d4b66f 100644 --- a/python/dgl/nn/tensorflow/conv/gatconv.py +++ b/python/dgl/nn/tensorflow/conv/gatconv.py @@ -263,6 +263,8 @@ def call(self, graph, feat, get_attention=False): self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats)) if graph.is_block: feat_dst = feat_src[:graph.number_of_dst_nodes()] + h_dst = h_dst[:graph.number_of_dst_nodes()] + dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] # NOTE: GAT paper uses "first concatenation then linear projection" # to compute attention scores, while ours is "first projection then # addition", the two approaches are mathematically equivalent: diff --git a/tests/mxnet/test_nn.py b/tests/mxnet/test_nn.py index 88a6569be0cb..c6aa703e2f03 100644 --- a/tests/mxnet/test_nn.py +++ b/tests/mxnet/test_nn.py @@ -90,7 +90,8 @@ def test_graph_conv2(idtype, g, norm, weight, bias, out_dim): conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias) conv.initialize(ctx=F.ctx()) ext_w = F.randn((5, out_dim)).as_in_context(F.ctx()) - nsrc = ndst = g.number_of_nodes() + nsrc = g.number_of_src_nodes() + ndst = g.number_of_dst_nodes() h = F.randn((nsrc, 5)).as_in_context(F.ctx()) if weight: h_out = conv(g, h) @@ -170,12 +171,17 @@ def test_gat_conv(g, idtype, out_dim, num_heads): gat = nn.GATConv(10, out_dim, num_heads) # n_heads = 5 gat.initialize(ctx=ctx) print(gat) - feat = F.randn((g.number_of_nodes(), 10)) + feat = F.randn((g.number_of_src_nodes(), 10)) h = gat(g, feat) - assert h.shape == (g.number_of_nodes(), num_heads, out_dim) + assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim) _, a = gat(g, feat, True) assert a.shape == (g.number_of_edges(), num_heads, 1) + # test residual connection + gat = nn.GATConv(10, out_dim, num_heads, residual=True) + gat.initialize(ctx=ctx) + h = gat(g, feat) + @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('out_dim', [1, 2]) @@ -199,7 +205,7 @@ def test_sage_conv(idtype, g, aggre_type, out_dim): g = g.astype(idtype).to(F.ctx()) ctx = F.ctx() sage = nn.SAGEConv(5, out_dim, aggre_type) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) sage.initialize(ctx=ctx) h = sage(g, feat) assert h.shape[-1] == out_dim @@ -277,9 +283,9 @@ def test_agnn_conv(g, idtype): agnn_conv = nn.AGNNConv(0.1, True) agnn_conv.initialize(ctx=ctx) print(agnn_conv) - feat = F.randn((g.number_of_nodes(), 10)) + feat = F.randn((g.number_of_src_nodes(), 10)) h = agnn_conv(g, feat) - assert h.shape == (g.number_of_nodes(), 10) + assert h.shape == (g.number_of_dst_nodes(), 10) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -387,9 +393,9 @@ def test_edge_conv(g, idtype, out_dim): edge_conv.initialize(ctx=ctx) print(edge_conv) # test #1: basic - h0 = F.randn((g.number_of_nodes(), 5)) + h0 = F.randn((g.number_of_src_nodes(), 5)) h1 = edge_conv(g, h0) - assert h1.shape == (g.number_of_nodes(), out_dim) + assert h1.shape == (g.number_of_dst_nodes(), out_dim) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -418,9 +424,9 @@ def test_gin_conv(g, idtype, aggregator_type): print(gin_conv) # test #1: basic - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) h = gin_conv(g, feat) - assert h.shape == (g.number_of_nodes(), 5) + assert h.shape == (g.number_of_dst_nodes(), 5) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'])) @@ -446,10 +452,10 @@ def test_gmm_conv(g, idtype): ctx = F.ctx() gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max') gmm_conv.initialize(ctx=ctx) - h0 = F.randn((g.number_of_nodes(), 5)) + h0 = F.randn((g.number_of_src_nodes(), 5)) pseudo = F.randn((g.number_of_edges(), 5)) h1 = gmm_conv(g, h0, pseudo) - assert h1.shape == (g.number_of_nodes(), 2) + assert h1.shape == (g.number_of_dst_nodes(), 2) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -473,10 +479,10 @@ def test_nn_conv(g, idtype): nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max') nn_conv.initialize(ctx=ctx) # test #1: basic - h0 = F.randn((g.number_of_nodes(), 5)) + h0 = F.randn((g.number_of_src_nodes(), 5)) etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx) h1 = nn_conv(g, h0, etypes) - assert h1.shape == (g.number_of_nodes(), 2) + assert h1.shape == (g.number_of_dst_nodes(), 2) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'])) diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index ddbca5af5195..2e3d0397a615 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -533,14 +533,14 @@ def test_gat_conv(g, idtype, out_dim, num_heads): g = g.astype(idtype).to(F.ctx()) ctx = F.ctx() gat = nn.GATConv(5, out_dim, num_heads) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) gat = gat.to(ctx) h = gat(g, feat) # test pickle th.save(gat, tmp_buffer) - assert h.shape == (g.number_of_nodes(), num_heads, out_dim) + assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim) _, a = gat(g, feat, get_attention=True) assert a.shape == (g.number_of_edges(), num_heads, 1) @@ -570,7 +570,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads): def test_sage_conv(idtype, g, aggre_type): g = g.astype(idtype).to(F.ctx()) sage = nn.SAGEConv(5, 10, aggre_type) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) sage = sage.to(F.ctx()) # test pickle th.save(sage, tmp_buffer) @@ -664,14 +664,14 @@ def test_gin_conv(g, idtype, aggregator_type): th.nn.Linear(5, 12), aggregator_type ) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) gin = gin.to(ctx) h = gin(g, feat) # test pickle th.save(h, tmp_buffer) - assert h.shape == (g.number_of_nodes(), 12) + assert h.shape == (g.number_of_dst_nodes(), 12) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -694,10 +694,10 @@ def test_agnn_conv(g, idtype): g = g.astype(idtype).to(F.ctx()) ctx = F.ctx() agnn = nn.AGNNConv(1) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) agnn = agnn.to(ctx) h = agnn(g, feat) - assert h.shape == (g.number_of_nodes(), 5) + assert h.shape == (g.number_of_dst_nodes(), 5) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -732,7 +732,7 @@ def test_nn_conv(g, idtype): ctx = F.ctx() edge_func = th.nn.Linear(4, 5 * 10) nnconv = nn.NNConv(5, 10, edge_func, 'mean') - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) efeat = F.randn((g.number_of_edges(), 4)) nnconv = nnconv.to(ctx) h = nnconv(g, feat, efeat) @@ -837,9 +837,9 @@ def test_edge_conv(g, idtype, out_dim): # test pickle th.save(edge_conv, tmp_buffer) - h0 = F.randn((g.number_of_nodes(), 5)) + h0 = F.randn((g.number_of_src_nodes(), 5)) h1 = edge_conv(g, h0) - assert h1.shape == (g.number_of_nodes(), out_dim) + assert h1.shape == (g.number_of_dst_nodes(), out_dim) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @@ -862,14 +862,14 @@ def test_dotgat_conv(g, idtype, out_dim, num_heads): g = g.astype(idtype).to(F.ctx()) ctx = F.ctx() dotgat = nn.DotGatConv(5, out_dim, num_heads) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) dotgat = dotgat.to(ctx) # test pickle th.save(dotgat, tmp_buffer) h = dotgat(g, feat) - assert h.shape == (g.number_of_nodes(), num_heads, out_dim) + assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim) _, a = dotgat(g, feat, get_attention=True) assert a.shape == (g.number_of_edges(), num_heads, 1) diff --git a/tests/tensorflow/test_nn.py b/tests/tensorflow/test_nn.py index ff9e695f817e..261cc0b244a1 100644 --- a/tests/tensorflow/test_nn.py +++ b/tests/tensorflow/test_nn.py @@ -270,12 +270,16 @@ def test_gat_conv(g, idtype, out_dim, num_heads): g = g.astype(idtype).to(F.ctx()) ctx = F.ctx() gat = nn.GATConv(5, out_dim, num_heads) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) h = gat(g, feat) - assert h.shape == (g.number_of_nodes(), num_heads, out_dim) + assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim) _, a = gat(g, feat, get_attention=True) assert a.shape == (g.number_of_edges(), num_heads, 1) + # test residual connection + gat = nn.GATConv(5, out_dim, num_heads, residual=True) + h = gat(g, feat) + @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('out_dim', [1, 2]) @@ -297,7 +301,7 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads): def test_sage_conv(idtype, g, aggre_type, out_dim): g = g.astype(idtype).to(F.ctx()) sage = nn.SAGEConv(5, out_dim, aggre_type) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) h = sage(g, feat) assert h.shape[-1] == out_dim @@ -374,9 +378,9 @@ def test_gin_conv(g, idtype, aggregator_type): tf.keras.layers.Dense(12), aggregator_type ) - feat = F.randn((g.number_of_nodes(), 5)) + feat = F.randn((g.number_of_src_nodes(), 5)) h = gin(g, feat) - assert h.shape == (g.number_of_nodes(), 12) + assert h.shape == (g.number_of_dst_nodes(), 12) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'])) @@ -398,9 +402,9 @@ def test_edge_conv(g, idtype, out_dim): g = g.astype(idtype).to(F.ctx()) edge_conv = nn.EdgeConv(out_dim) - h0 = F.randn((g.number_of_nodes(), 5)) + h0 = F.randn((g.number_of_src_nodes(), 5)) h1 = edge_conv(g, h0) - assert h1.shape == (g.number_of_nodes(), out_dim) + assert h1.shape == (g.number_of_dst_nodes(), out_dim) @parametrize_dtype @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) diff --git a/tests/test_utils/graph_cases.py b/tests/test_utils/graph_cases.py index 39a1bd5f18fb..864d4accc9d8 100644 --- a/tests/test_utils/graph_cases.py +++ b/tests/test_utils/graph_cases.py @@ -96,7 +96,7 @@ def batched_graph0(): g3 = dgl.add_self_loop(dgl.graph(([0], [1]))) return dgl.batch([g1, g2, g3]) -@register_case(['block', 'bipartite', 'block-biparitite']) +@register_case(['block', 'bipartite', 'block-bipartite']) def block_graph0(): g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100) g = g.to(F.cpu())