Skip to content

Commit

Permalink
[NN] GINEConv (dmlc#3934)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
mufeili and jermainewang authored Apr 25, 2022
1 parent df7a612 commit 248bece
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/nn-pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Conv Layers
~dgl.nn.pytorch.conv.SGConv
~dgl.nn.pytorch.conv.APPNPConv
~dgl.nn.pytorch.conv.GINConv
~dgl.nn.pytorch.conv.GINEConv
~dgl.nn.pytorch.conv.GatedGraphConv
~dgl.nn.pytorch.conv.GMMConv
~dgl.nn.pytorch.conv.ChebConv
Expand Down
11 changes: 6 additions & 5 deletions python/dgl/nn/pytorch/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .gatv2conv import GATv2Conv
from .egatconv import EGATConv
from .ginconv import GINConv
from .gineconv import GINEConv
from .gmmconv import GMMConv
from .graphconv import GraphConv, EdgeWeightNorm
from .nnconv import NNConv
Expand All @@ -31,8 +32,8 @@
from .pnaconv import PNAConv

__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes', 'EGNNConv',
'PNAConv']
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GINEConv',
'GatedGraphConv', 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv',
'DenseSAGEConv', 'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv',
'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes',
'EGNNConv', 'PNAConv']
98 changes: 98 additions & 0 deletions python/dgl/nn/pytorch/conv/gineconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Torch Module for Graph Isomorphism Network layer variant with edge features"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
import torch.nn.functional as F
from torch import nn

from .... import function as fn
from ....utils import expand_as_pair

class GINEConv(nn.Module):
r"""Graph Isomorphism Network with Edge Features, introduced by
`Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\sum_{j\in\mathcal{N}(i)}\mathrm{ReLU}(h_j^{l} + e_{j,i}^{l})\right)
where :math:`e_{j,i}^{l}` is the edge feature.
Parameters
----------
apply_func : callable module or None
The :math:`f_\Theta` in the formula. If not None, it will be applied to
the updated node features. The default value is None.
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
Examples
--------
>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GINEConv
>>> g = dgl.graph(([0, 1, 2], [1, 1, 3]))
>>> in_feats = 10
>>> out_feats = 20
>>> nfeat = torch.randn(g.num_nodes(), in_feats)
>>> efeat = torch.randn(g.num_edges(), in_feats)
>>> conv = GINEConv(nn.Linear(in_feats, out_feats))
>>> res = conv(g, nfeat, efeat)
>>> print(res.shape)
torch.Size([4, 20])
"""
def __init__(self,
apply_func=None,
init_eps=0,
learn_eps=False):
super(GINEConv, self).__init__()
self.apply_func = apply_func
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))

def message(self, edges):
r"""User-defined Message Function"""
return {'m': F.relu(edges.src['hn'] + edges.data['he'])}

def forward(self, graph, node_feat, edge_feat):
r"""Forward computation.
Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it is the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input feature size requirement of ``apply_func``.
edge_feat : torch.Tensor
Edge feature. It is a tensor of shape :math:`(E, D_{in})` where :math:`E`
is the number of edges.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output feature size of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as :math:`D_{in}`.
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(node_feat, graph)
graph.srcdata['hn'] = feat_src
graph.edata['he'] = edge_feat
graph.update_all(self.message, fn.sum('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
25 changes: 24 additions & 1 deletion tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,29 @@ def test_gin_conv(g, idtype, aggregator_type):
gin = gin.to(ctx)
h = gin(g, feat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
def test_gine_conv(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gine = nn.GINEConv(
th.nn.Linear(5, 12)
)
th.save(gine, tmp_buffer)
nfeat = F.randn((g.number_of_src_nodes(), 5))
efeat = F.randn((g.num_edges(), 5))
gine = gine.to(ctx)
h = gine(g, nfeat, efeat)

# test pickle
th.save(gine, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12)

gine = nn.GINEConv(None)
th.save(gine, tmp_buffer)
gine = gine.to(ctx)
h = gine(g, nfeat, efeat)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
Expand Down Expand Up @@ -1441,7 +1464,7 @@ def test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):

@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('out_size', [16, 32])
@pytest.mark.parametrize('aggregators',
@pytest.mark.parametrize('aggregators',
[['mean', 'max', 'sum'], ['min', 'std', 'var'], ['moment3', 'moment4', 'moment5']])
@pytest.mark.parametrize('scalers', [['identity'], ['amplification', 'attenuation']])
@pytest.mark.parametrize('delta', [2.5, 7.4])
Expand Down

0 comments on commit 248bece

Please sign in to comment.