From 5ba5106acab6a642e9b790e5331ee519112a5623 Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Sun, 7 Aug 2022 01:33:16 -0700 Subject: [PATCH] [Bugfix] Fix the default value of `num_bases` in RelGraphConv module (#4321) * Fix doc and default settings for RelGraphConv * Add unit test * Split msg in two lines to pass CI-lint --- python/dgl/nn/pytorch/conv/relgraphconv.py | 22 +++++---- tests/pytorch/test_nn.py | 54 ++++++++++++++++++++++ 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/python/dgl/nn/pytorch/conv/relgraphconv.py b/python/dgl/nn/pytorch/conv/relgraphconv.py index 0e479ab9ef6b..a64e3abea4f3 100644 --- a/python/dgl/nn/pytorch/conv/relgraphconv.py +++ b/python/dgl/nn/pytorch/conv/relgraphconv.py @@ -49,16 +49,20 @@ class RelGraphConv(nn.Module): out_feat : int Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`. num_rels : int - Number of relations. . + Number of relations. regularizer : str, optional - Which weight regularizer to use "basis" or "bdd": + Which weight regularizer to use ("basis", "bdd" or ``None``): - - "basis" is short for basis-decomposition. - - "bdd" is short for block-diagonal-decomposition. + - "basis" is for basis-decomposition. + - "bdd" is for block-diagonal-decomposition. + - ``None`` applies no regularization. - Default applies no regularization. + Default: ``None``. num_bases : int, optional - Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + Number of bases. It comes into effect when a regularizer is applied. + If ``None``, it uses number of relations (``num_rels``). Default: ``None``. + Note that ``in_feat`` and ``out_feat`` must be divisible by ``num_bases`` + when applying "bdd" regularizer. bias : bool, optional True if bias is added. Default: ``True``. activation : callable, optional @@ -67,8 +71,8 @@ class RelGraphConv(nn.Module): True to include self loop message. Default: ``True``. dropout : float, optional Dropout rate. Default: ``0.0`` - layer_norm: float, optional - Add layer norm. Default: ``False`` + layer_norm: bool, optional + True to add layer norm. Default: ``False`` Examples -------- @@ -102,6 +106,8 @@ def __init__(self, dropout=0.0, layer_norm=False): super().__init__() + if regularizer is not None and num_bases is None: + num_bases = num_rels self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases) self.bias = bias self.activation = activation diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index 8343fdfe66c2..45bd7469dbd3 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -412,6 +412,60 @@ def test_rgcn(idtype, O): h_new = rgc_bdd(g, h, r, norm) assert h_new.shape == (100, O) +@parametrize_idtype +@pytest.mark.parametrize('O', [1, 10, 40]) +def test_rgcn_default_nbasis(idtype, O): + ctx = F.ctx() + etype = [] + g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)) + g = g.astype(idtype).to(F.ctx()) + # 5 etypes + R = 5 + for i in range(g.number_of_edges()): + etype.append(i % 5) + I = 10 + + h = th.randn((100, I)).to(ctx) + r = th.tensor(etype).to(ctx) + norm = th.rand((g.number_of_edges(), 1)).to(ctx) + sorted_r, idx = th.sort(r) + sorted_g = dgl.reorder_graph(g, edge_permute_algo='custom', permute_config={'edges_perm' : idx.to(idtype)}) + sorted_norm = norm[idx] + + rgc = nn.RelGraphConv(I, O, R).to(ctx) + th.save(rgc, tmp_buffer) # test pickle + rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx) + th.save(rgc_basis, tmp_buffer) # test pickle + if O % R == 0: + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx) + th.save(rgc_bdd, tmp_buffer) # test pickle + + # basic usage + h_new = rgc(g, h, r) + assert h_new.shape == (100, O) + h_new_basis = rgc_basis(g, h, r) + assert h_new_basis.shape == (100, O) + if O % R == 0: + h_new_bdd = rgc_bdd(g, h, r) + assert h_new_bdd.shape == (100, O) + + # sorted input + h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True) + assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4) + h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True) + assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4) + if O % R == 0: + h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True) + assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4) + + # norm input + h_new = rgc(g, h, r, norm) + assert h_new.shape == (100, O) + h_new = rgc_basis(g, h, r, norm) + assert h_new.shape == (100, O) + if O % R == 0: + h_new = rgc_bdd(g, h, r, norm) + assert h_new.shape == (100, O) @parametrize_idtype @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))