
Commit

[NN] Added activation function as an optional parameter to GINConv (dmlc#3565)

* Added activation function as an optional parameter

* lint fixes

* Modified the input parameters in tandem with other classes

* lint corrections

* corrected tests

* Reverting back to the old interface

* lint corrections

Co-authored-by: Minjie Wang <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
3 people authored Dec 29, 2021
1 parent 4889c57 commit 9c10654
Showing 1 changed file with 45 additions and 3 deletions.
python/dgl/nn/pytorch/conv/ginconv.py: 45 additions & 3 deletions
@@ -41,9 +41,12 @@ class GINConv(nn.Module):
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
Example
-------
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch as th
@@ -67,15 +70,35 @@ class GINConv(nn.Module):
0.8843, -0.8764],
[-0.1804, 0.0758, -0.5159, 0.3569, -0.1408, -0.1395, -0.2387, 0.7773,
0.5266, -0.4465]], grad_fn=<AddmmBackward>)
>>> # With activation
>>> from torch.nn.functional import relu
>>> conv = GINConv(lin, 'max', activation=relu)
>>> res = conv(g, feat)
>>> res
tensor([[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,
0.0000]], grad_fn=<ReluBackward0>)
"""
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False):
learn_eps=False,
activation=None):
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
self.activation = activation
if aggregator_type not in ('sum', 'max', 'mean'):
raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type))
@@ -85,6 +108,22 @@ def __init__(self,
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))

self.reset_parameters()

def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.

Note
----
The model parameters are initialized using Glorot normal initialization.
"""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.apply_func.weight, gain=gain)

def forward(self, graph, feat, edge_weight=None):
r"""
@@ -129,4 +168,7 @@ def forward(self, graph, feat, edge_weight=None):
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
# apply the optional activation to the updated node features
if self.activation is not None:
rst = self.activation(rst)
return rst
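
Below is a minimal usage sketch of the new interface, assuming a DGL build that includes this commit. The graph, the 10-dimensional features, and the nn.ReLU module are illustrative choices of this note (the doctest's own setup is partially collapsed above), not part of the diff.

import dgl
import torch as th
import torch.nn as nn
from dgl.nn.pytorch.conv import GINConv

# A small illustrative graph with 6 nodes and 6 edges, plus constant node features.
g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
feat = th.ones(6, 10)
lin = nn.Linear(10, 10)

# Old interface, unchanged by this commit: no activation is applied after apply_func.
conv = GINConv(lin, 'max')
res = conv(g, feat)

# New optional parameter: any callable works, e.g. an nn.Module layer or torch.nn.functional.relu.
conv_act = GINConv(lin, 'max', activation=nn.ReLU())
res_act = conv_act(g, feat)
assert (res_act >= 0).all()  # ReLU output is non-negative

With both arguments set, the forward pass computes activation(apply_func((1 + eps) * feat_dst + aggregated neighbor features)); omitting activation reproduces the previous behaviour exactly.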
