Skip to content

Commit

Permalink
[NN] Fix GINConv (dmlc#3692)
Browse files Browse the repository at this point in the history
* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* Update ginconv.py

* Update test_nn.py

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
mufeili and Ubuntu authored Jan 27, 2022
1 parent c8fef62 commit 05c6c3c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 21 deletions.
24 changes: 4 additions & 20 deletions python/dgl/nn/pytorch/conv/ginconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class GINConv(nn.Module):
----------
apply_func : callable activation function/layer or None
If not None, apply this function to the updated node feature,
the :math:`f_\Theta` in the formula.
the :math:`f_\Theta` in the formula, default: None.
aggregator_type : str
Aggregator type to use (``sum``, ``max`` or ``mean``).
Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
Expand Down Expand Up @@ -90,8 +90,8 @@ class GINConv(nn.Module):
0.0000]], grad_fn=<ReluBackward0>)
"""
def __init__(self,
apply_func,
aggregator_type,
apply_func=None,
aggregator_type='sum',
init_eps=0,
learn_eps=False,
activation=None):
Expand All @@ -108,22 +108,6 @@ def __init__(self,
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))

self.reset_parameters()

def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
Note
----
The model parameters are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.apply_func.weight, gain=gain)

def forward(self, graph, feat, edge_weight=None):
r"""
Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,11 @@ def test_gin_conv(g, idtype, aggregator_type):
th.save(gin, tmp_buffer)

assert h.shape == (g.number_of_dst_nodes(), 12)

gin = nn.GINConv(None, aggregator_type)
th.save(gin, tmp_buffer)
gin = gin.to(ctx)
h = gin(g, feat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
Expand Down Expand Up @@ -1383,4 +1388,4 @@ def test_twirls():
test_atomic_conv()
test_cf_conv()
test_hetero_conv()
test_twirls()
test_twirls()

0 comments on commit 05c6c3c

Please sign in to comment.