diff --git a/python/dgl/nn/pytorch/conv/relgraphconv.py b/python/dgl/nn/pytorch/conv/relgraphconv.py index 9848dc5f5aa8..b10e4c105ec6 100644 --- a/python/dgl/nn/pytorch/conv/relgraphconv.py +++ b/python/dgl/nn/pytorch/conv/relgraphconv.py @@ -347,6 +347,8 @@ def forward(self, g, feat, etypes, norm=None): pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device)) num = th.tensor([len(etypes)], device=g.device) etypes = (th.cat([pos[1:], num]) - pos).tolist() + if norm is not None: + norm = norm[index] with g.local_scope(): g.srcdata['h'] = feat diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index a672d6df5bd5..9e4451b69036 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -328,7 +328,7 @@ def test_rgcn(): assert F.allclose(h_new, h_new_low) # with norm - norm = th.zeros((g.number_of_edges(), 1)).to(ctx) + norm = th.rand((g.number_of_edges(), 1)).to(ctx) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx) rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx) @@ -408,7 +408,7 @@ def test_rgcn_sorted(): assert F.allclose(h_new, h_new_low) # with norm - norm = th.zeros((g.number_of_edges(), 1)).to(ctx) + norm = th.rand((g.number_of_edges(), 1)).to(ctx) rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx) rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx) @@ -965,6 +965,7 @@ def forward(self, g, h, arg1=None, *, arg2=None): test_simple_pool() test_set_trans() test_rgcn() + test_rgcn_sorted() test_tagconv() test_gat_conv() test_sage_conv()