[NN] Support scalar edge weight for GraphConv, SAGEConv and GINConv (dmlc#2557)

* add edge weight in forward

* fix lint

* fix

* fix

* address comments

* add utils

* add util to normalize in gcn way

* fix lint

* add unittest

* fix lint

* fix docstring

* fix docstring

* address comments

* improve notation consistency

* use preferred fn
hetong007 authored Jan 26, 2021
1 parent 8900450 commit 0855d25
Showing 8 changed files with 260 additions and 21 deletions.
7 changes: 7 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -17,6 +17,13 @@ GraphConv
:members: weight, bias, forward, reset_parameters
:show-inheritance:

EdgeWeightNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.conv.EdgeWeightNorm
:members: forward
:show-inheritance:

RelGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 3 additions & 3 deletions python/dgl/function/message.py
@@ -235,15 +235,15 @@ def src_mul_edge(src, edge, out):
----------
src : str
The source feature field.
-dst : str
-    The destination feature field.
+edge : str
+    The edge feature field.
out : str
The output message field.
Examples
--------
>>> import dgl
->>> message_func = dgl.function.src_mul_edge('h', 'h', 'm')
+>>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
"""
return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)

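For reference, a minimal doctest-style sketch (the toy graph, feature field names, and shapes here are invented for illustration) showing that ``src_mul_edge`` simply forwards to ``u_mul_e``, so the two builtins produce identical messages:

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
>>> g.ndata['h'] = th.ones(3, 4)
>>> g.edata['e'] = th.tensor([[0.5], [2.0], [1.0]])
>>> g.update_all(fn.src_mul_edge('h', 'e', 'm'), fn.sum('m', 'out1'))
>>> g.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'out2'))
>>> assert th.equal(g.ndata['out1'], g.ndata['out2'])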
4 changes: 2 additions & 2 deletions python/dgl/nn/pytorch/conv/__init__.py
@@ -8,7 +8,7 @@
from .gatconv import GATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
-from .graphconv import GraphConv
+from .graphconv import GraphConv, EdgeWeightNorm
from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
@@ -22,7 +22,7 @@
from .cfconv import CFConv
from .dotgatconv import DotGatConv

-__all__ = ['GraphConv', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv']
23 changes: 21 additions & 2 deletions python/dgl/nn/pytorch/conv/ginconv.py
@@ -20,6 +20,16 @@ class GINConv(nn.Module):
\mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
If a weight tensor on each edge is provided, the weighted graph convolution is defined as:
.. math::
h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
\mathrm{aggregate}\left(\left\{e_{ji} h_j^{l}, j\in\mathcal{N}(i)
\right\}\right)\right)
where :math:`e_{ji}` is the weight on the edge from node :math:`j` to node :math:`i`.
Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.
Parameters
----------
apply_func : callable activation function/layer or None
@@ -80,7 +90,7 @@ def __init__(self,
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))

-def forward(self, graph, feat):
+def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -98,6 +108,9 @@ def forward(self, graph, feat):
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``.
edge_weight : torch.Tensor, optional
Optional edge weight tensor. If given, each message is weighted
by its corresponding edge weight before aggregation.
Returns
-------
@@ -108,9 +121,15 @@
as input dimensionality.
"""
with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
-graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
+graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
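A minimal usage sketch of the new ``edge_weight`` argument on ``GINConv`` (the toy graph, feature sizes, and the linear ``apply_func`` are assumptions for illustration):

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> from dgl.nn.pytorch.conv import GINConv
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
>>> feat = th.ones(4, 5)
>>> edge_weight = th.tensor([0.5, 1.0, 2.0, 0.1])
>>> conv = GINConv(nn.Linear(5, 8), 'sum')
>>> res = conv(g, feat, edge_weight=edge_weight)
>>> res.shape
torch.Size([4, 8])

With ``edge_weight`` given, each message becomes :math:`e_{ji} h_j^{l}` before the ``sum`` aggregation, matching the weighted formulation in the docstring above.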
165 changes: 156 additions & 9 deletions python/dgl/nn/pytorch/conv/graphconv.py
@@ -7,6 +7,135 @@
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ....transform import reverse
from ....convert import block_to_graph
from ....heterograph import DGLBlock

class EdgeWeightNorm(nn.Module):
r"""
Description
-----------
This module normalizes positive scalar edge weights on a graph
following the form in `GCN <https://arxiv.org/abs/1609.02907>`__.
Mathematically, setting ``norm='both'`` yields the following normalization term:

.. math::
    c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}})

And, setting ``norm='right'`` yields the following normalization term:

.. math::
    c_{ji} = \sum_{k\in\mathcal{N}(i)}e_{ki}

where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
The module returns the normalized weight :math:`e_{ji} / c_{ji}`.
Parameters
----------
norm : str, optional
The normalizer as specified above. Default is `'both'`.
eps : float, optional
A small offset value in the denominator. Default is 0.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import EdgeWeightNorm, GraphConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])
>>> norm = EdgeWeightNorm(norm='both')
>>> norm_edge_weight = norm(g, edge_weight)
>>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)
>>> res = conv(g, feat, edge_weight=norm_edge_weight)
>>> print(res)
tensor([[-1.1849, -0.7525],
[-1.3514, -0.8582],
[-1.2384, -0.7865],
[-1.9949, -1.2669],
[-1.3658, -0.8674],
[-0.8323, -0.5286]], grad_fn=<AddBackward0>)
"""
def __init__(self, norm='both', eps=0.):
super(EdgeWeightNorm, self).__init__()
self._norm = norm
self._eps = eps

def forward(self, graph, edge_weight):
r"""
Description
-----------
Compute normalized edge weight for the GCN model.
Parameters
----------
graph : DGLGraph
The graph.
edge_weight : torch.Tensor
Unnormalized scalar weights on the edges.
The shape is expected to be :math:`(|E|)`.
Returns
-------
torch.Tensor
The normalized edge weight.
Raises
------
DGLError
Case 1:
The edge weight is multi-dimensional. Currently this module
only supports a scalar weight on each edge.
Case 2:
The edge weight has non-positive values with ``norm='both'``.
This will trigger square root and division by a non-positive number.
"""
with graph.local_scope():
if isinstance(graph, DGLBlock):
graph = block_to_graph(graph)
if len(edge_weight.shape) > 1:
raise DGLError('Currently the normalization is only defined '
'on scalar edge weight. Please customize the '
'normalization for your high-dimensional weights.')
if self._norm == 'both' and th.any(edge_weight <= 0).item():
raise DGLError('Non-positive edge weight detected with `norm="both"`. '
'This leads to square root of zero or negative values.')

dev = graph.device
graph.srcdata['_src_out_w'] = th.ones((graph.number_of_src_nodes())).float().to(dev)
graph.dstdata['_dst_in_w'] = th.ones((graph.number_of_dst_nodes())).float().to(dev)
graph.edata['_edge_w'] = edge_weight

if self._norm == 'both':
reversed_g = reverse(graph)
reversed_g.edata['_edge_w'] = edge_weight
reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight'))
degs = reversed_g.dstdata['out_weight'] + self._eps
norm = th.pow(degs, -0.5)
graph.srcdata['_src_out_w'] = norm

if self._norm != 'none':
graph.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'in_weight'))
degs = graph.dstdata['in_weight'] + self._eps
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
graph.dstdata['_dst_in_w'] = norm

graph.apply_edges(lambda e: {'_norm_edge_weights': e.src['_src_out_w'] * \
e.dst['_dst_in_w'] * \
e.data['_edge_w']})
return graph.edata['_norm_edge_weights']

# pylint: disable=W0235
class GraphConv(nn.Module):
@@ -18,13 +147,25 @@ class GraphConv(nn.Module):
and mathematically is defined as follows:
.. math::
-h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})
+h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})
where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
-:math:`c_{ij}` is the product of the square root of node degrees
-(i.e., :math:`c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}`),
+:math:`c_{ji}` is the product of the square root of node degrees
+(i.e., :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
and :math:`\sigma` is an activation function.
If a weight tensor on each edge is provided, the weighted graph convolution is defined as:
.. math::
h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})
where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
This is NOT equivalent to the weighted graph convolutional network formulation in the paper.
To customize the normalization term :math:`c_{ji}`, one can first set ``norm='none'`` for
the model, and send the pre-normalized :math:`e_{ji}` to the forward computation. We provide
:class:`~dgl.nn.pytorch.EdgeWeightNorm` to normalize scalar edge weight following the GCN paper.
Parameters
----------
in_feats : int
Expand All @@ -35,7 +176,7 @@ class GraphConv(nn.Module):
How to apply the normalizer. If `'right'`, divide the aggregated messages
by each node's in-degree, which is equivalent to averaging the received messages.
If `'none'`, no normalization is applied. Default is `'both'`,
-where the :math:`c_{ij}` in the paper is applied.
+where the :math:`c_{ji}` in the paper is applied.
weight : bool, optional
If True, apply a linear layer. Otherwise, aggregate the messages
without a weight matrix.
@@ -185,7 +326,7 @@ def set_allow_zero_in_degree(self, set_value):
"""
self._allow_zero_in_degree = set_value

-def forward(self, graph, feat, weight=None):
+def forward(self, graph, feat, weight=None, edge_weight=None):
r"""
Description
@@ -205,6 +346,9 @@ def forward(self, graph, feat, weight=None):
:math:`(N_{out}, D_{in_{dst}})`.
weight : torch.Tensor, optional
Optional external weight tensor.
edge_weight : torch.Tensor, optional
Optional edge weight tensor. If given, each message is weighted
by its corresponding edge weight before aggregation.
Returns
-------
Expand Down Expand Up @@ -243,6 +387,11 @@ def forward(self, graph, feat, weight=None):
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph)
@@ -266,14 +415,12 @@
if weight is not None:
feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
-graph.update_all(fn.copy_src(src='h', out='m'),
-    fn.sum(msg='m', out='h'))
+graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
-graph.update_all(fn.copy_src(src='h', out='m'),
-    fn.sum(msg='m', out='h'))
+graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
if weight is not None:
rst = th.matmul(rst, weight)
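To make the ``norm='both'`` term concrete, here is a sketch (toy graph and weights are assumptions for illustration) that checks ``EdgeWeightNorm`` against a direct computation of :math:`e_{ji} / (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}})` using weighted out- and in-degree sums:

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.conv import EdgeWeightNorm
>>> src, dst = th.tensor([0, 1, 2, 0]), th.tensor([1, 2, 0, 2])
>>> g = dgl.graph((src, dst))
>>> w = th.tensor([0.5, 0.6, 0.4, 0.7])
>>> norm_w = EdgeWeightNorm(norm='both')(g, w)
>>> out_sum = th.zeros(3).scatter_add_(0, src, w)
>>> in_sum = th.zeros(3).scatter_add_(0, dst, w)
>>> expected = w / (out_sum[src].sqrt() * in_sum[dst].sqrt())
>>> assert th.allclose(norm_w, expected)

The normalized weights can then be passed to ``GraphConv`` with ``norm='none'``, as in the docstring example above.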
