Skip to content

Commit

Permalink
[Bugfix] Fix the default value of num_bases in RelGraphConv module (d…
Browse files Browse the repository at this point in the history
…mlc#4321)

* Fix doc and default settings for RelGraphConv

* Add unit test

* Split msg in two lines to pass CI-lint
  • Loading branch information
chang-l authored Aug 7, 2022
1 parent 43ba94e commit 5ba5106
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
22 changes: 14 additions & 8 deletions python/dgl/nn/pytorch/conv/relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,20 @@ class RelGraphConv(nn.Module):
out_feat : int
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_rels : int
Number of relations. .
Number of relations.
regularizer : str, optional
Which weight regularizer to use "basis" or "bdd":
Which weight regularizer to use ("basis", "bdd" or ``None``):
- "basis" is short for basis-decomposition.
- "bdd" is short for block-diagonal-decomposition.
- "basis" is for basis-decomposition.
- "bdd" is for block-diagonal-decomposition.
- ``None`` applies no regularization.
Default applies no regularization.
Default: ``None``.
num_bases : int, optional
Number of bases. Needed when ``regularizer`` is specified. Default: ``None``.
Number of bases. It comes into effect when a regularizer is applied.
If ``None``, it uses number of relations (``num_rels``). Default: ``None``.
Note that ``in_feat`` and ``out_feat`` must be divisible by ``num_bases``
when applying "bdd" regularizer.
bias : bool, optional
True if bias is added. Default: ``True``.
activation : callable, optional
Expand All @@ -67,8 +71,8 @@ class RelGraphConv(nn.Module):
True to include self loop message. Default: ``True``.
dropout : float, optional
Dropout rate. Default: ``0.0``
layer_norm: float, optional
Add layer norm. Default: ``False``
layer_norm: bool, optional
True to add layer norm. Default: ``False``
Examples
--------
Expand Down Expand Up @@ -102,6 +106,8 @@ def __init__(self,
dropout=0.0,
layer_norm=False):
super().__init__()
if regularizer is not None and num_bases is None:
num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.bias = bias
self.activation = activation
Expand Down
54 changes: 54 additions & 0 deletions tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,60 @@ def test_rgcn(idtype, O):
h_new = rgc_bdd(g, h, r, norm)
assert h_new.shape == (100, O)

@parametrize_idtype
@pytest.mark.parametrize('O', [1, 10, 40])
def test_rgcn_default_nbasis(idtype, O):
    """Check RelGraphConv works when ``num_bases`` is left as ``None``.

    With a regularizer specified but no ``num_bases``, the module should
    fall back to ``num_bases = num_rels``. Exercises the basic, presorted,
    and edge-norm call paths for the no-regularizer, "basis", and — when
    ``O`` is divisible by ``R`` — "bdd" variants.
    """
    ctx = F.ctx()
    R = 5   # number of relation types
    I = 10  # input feature size

    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))
    g = g.astype(idtype).to(ctx)
    # Assign each edge a relation id, cycling through the R types.
    etype = [i % R for i in range(g.number_of_edges())]

    h = th.randn((100, I)).to(ctx)
    r = th.tensor(etype).to(ctx)
    norm = th.rand((g.number_of_edges(), 1)).to(ctx)
    # Build a relation-sorted copy of the graph for the presorted path.
    sorted_r, idx = th.sort(r)
    sorted_g = dgl.reorder_graph(
        g, edge_permute_algo='custom',
        permute_config={'edges_perm': idx.to(idtype)})

    # num_bases is intentionally omitted everywhere below.
    rgc = nn.RelGraphConv(I, O, R).to(ctx)
    th.save(rgc, tmp_buffer)  # test pickle
    rgc_basis = nn.RelGraphConv(I, O, R, "basis").to(ctx)
    th.save(rgc_basis, tmp_buffer)  # test pickle
    if O % R == 0:
        # "bdd" requires in_feat and out_feat divisible by num_bases,
        # which defaults to R here — hence the guard.
        rgc_bdd = nn.RelGraphConv(I, O, R, "bdd").to(ctx)
        th.save(rgc_bdd, tmp_buffer)  # test pickle

    # basic usage
    h_new = rgc(g, h, r)
    assert h_new.shape == (100, O)
    h_new_basis = rgc_basis(g, h, r)
    assert h_new_basis.shape == (100, O)
    if O % R == 0:
        h_new_bdd = rgc_bdd(g, h, r)
        assert h_new_bdd.shape == (100, O)

    # sorted input must match the unsorted result
    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)
    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)
    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)
    if O % R == 0:
        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, presorted=True)
        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)

    # norm input
    h_new = rgc(g, h, r, norm)
    assert h_new.shape == (100, O)
    h_new = rgc_basis(g, h, r, norm)
    assert h_new.shape == (100, O)
    if O % R == 0:
        h_new = rgc_bdd(g, h, r, norm)
        assert h_new.shape == (100, O)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
Expand Down

0 comments on commit 5ba5106

Please sign in to comment.