[NN] Add EGATConv nn.module (dmlc#3425)
* added nn pytorch egatconv

* aligned with test build

* aligned with test build

* fixed white spaces

* fixed white spaces

* fixed white spaces

* added missing egatconv in imports

* added indentation in forward

* GATConv based implementation

* removed **kw_args

* added dgl relative imports

* PR corrections

* added DGL Error to EGATConv imports

* Update test_nn.py

Co-authored-by: Argusmocny <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
3 people authored Oct 27, 2021
1 parent a9c83bc commit 51c6509
Showing 5 changed files with 217 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -270,7 +270,7 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba

1. [**Covid-19 Detection from Chest X-ray and Patient Metadata using Graph Convolutional Neural Networks**](https://arxiv.org/abs/2105.09720), *Thosini Bamunu Mudiyanselage, Nipuna Senanayake, Chunyan Ji, Yi Pan, Yanqing Zhang*

-1. [**Graph neural networks and sequence embeddings enable the prediction and design of the cofactor specificity of Rossmann fold proteins**](https://www.biorxiv.org/content/10.1101/2021.05.05.440912v2), bioRxiv'21, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*
+1. [**Rossmann-toolbox: a deep learning-based protocol for the prediction and design of cofactor specificity in Rossmann fold proteins**](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059), Briefings in Bioinformatics, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*

1. [**LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations**](https://arxiv.org/pdf/2106.01093.pdf), ACL'21, *Ruisheng Cao, Lu Chen, Zhi Chen, Yanbin Zhao, Su Zhu, Kai Yu*

8 changes: 8 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -44,6 +44,14 @@ GATConv
.. autoclass:: dgl.nn.pytorch.conv.GATConv
:members: forward
:show-inheritance:


EGATConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.conv.EGATConv
:members: forward
:show-inheritance:

EdgeConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5 changes: 3 additions & 2 deletions python/dgl/nn/pytorch/conv/__init__.py
@@ -6,6 +6,7 @@
from .chebconv import ChebConv
from .edgeconv import EdgeConv
from .gatconv import GATConv
from .egatconv import EGATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .graphconv import GraphConv, EdgeWeightNorm
@@ -24,8 +25,8 @@
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv

-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
-           'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
+           'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
           'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
           'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
           'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
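
With the registration above, the new module resolves both from the conv subpackage and from the backend-dispatched top-level namespace. A quick sanity-check sketch (not part of this diff):

from dgl.nn.pytorch.conv import EGATConv
from dgl.nn import EGATConv  # same class, re-exported for the PyTorch backend
assert EGATConv.__module__ == 'dgl.nn.pytorch.conv.egatconv'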
182 changes: 182 additions & 0 deletions python/dgl/nn/pytorch/conv/egatconv.py
@@ -0,0 +1,182 @@
"""Torch modules for graph attention networks with fully valuable edges (EGAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from torch.nn import init

from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError

# pylint: enable=W0235
class EGATConv(nn.Module):
r"""
Description
-----------
Apply Graph Attention Layer over input graph. EGAT is an extension
of regular `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
handling edge features, detailed description is available in `Rossmann-Toolbox
<https://pubmed.ncbi.nlm.nih.gov/34571541/>`__ (see supplementary data).
The difference appears in the method how unnormalized attention scores :math:`e_{ij}`
are obtained:
.. math::
e_{ij} &= \vec{F} (f_{ij}^{\prime})
f_{ij}^{\prime} &= \mathrm{LeakyReLU}\left(A [ h_{i} \| f_{ij} \| h_{j}]\right)
where :math:`f_{ij}^{\prime}` are edge features, :math:`\mathrm{A}` is weight matrix and
:math: `\vec{F}` is weight vector. After that resulting node features
:math:`h_{i}^{\prime}` are updated in the same way as in regular GAT.
Parameters
----------
in_node_feats : int
Input node feature size :math:`h_{i}`.
in_edge_feats : int
Input edge feature size :math:`f_{ij}`.
out_node_feats : int
Output node feature size.
out_edge_feats : int
Output edge feature size :math:`f_{ij}^{\prime}`.
num_heads : int
Number of attention heads.
bias : bool, optional
If True, add bias term to :math: `f_{ij}^{\prime}`. Defaults: ``True``.
Examples
----------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGATConv
>>> num_nodes, num_edges = 8, 30
>>> # generate a graph
>>> graph = dgl.rand_graph((num_nodes,num_edges))
>>> node_feats = th.rand((num_nodes, 20))
>>> edge_feats = th.rand((num_edges, 12))
>>> egat = EGATConv(in_node_feats=20,
in_edge_feats=12,
out_node_feats=15,
out_edge_feats=10,
num_heads=3)
>>> #forward pass
>>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
>>> new_node_feats.shape, new_edge_feats.shape
((8, 3, 12), (30, 3, 10))
"""

    def __init__(self,
                 in_node_feats,
                 in_edge_feats,
                 out_node_feats,
                 out_edge_feats,
                 num_heads,
                 bias=True):
        super().__init__()
        self._num_heads = num_heads
        self._out_node_feats = out_node_feats
        self._out_edge_feats = out_edge_feats
        self.fc_node = nn.Linear(in_node_feats, out_node_feats*num_heads, bias=True)
        self.fc_ni = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
        self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
        self.fc_nj = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
        if bias:
            self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
        else:
            self.register_buffer('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reinitialize learnable parameters.
        """
        gain = init.calculate_gain('relu')
        init.xavier_normal_(self.fc_node.weight, gain=gain)
        init.xavier_normal_(self.fc_ni.weight, gain=gain)
        init.xavier_normal_(self.fc_fij.weight, gain=gain)
        init.xavier_normal_(self.fc_nj.weight, gain=gain)
        init.xavier_normal_(self.attn, gain=gain)
        if self.bias is not None:
            init.constant_(self.bias, 0)

    def forward(self, graph, nfeats, efeats, get_attention=False):
        r"""
        Compute new node and edge features.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        nfeats : torch.Tensor
            The input node feature of shape :math:`(N, D_{in})`,
            where :math:`N` is the number of nodes and
            :math:`D_{in}` is the size of the input node feature.
        efeats : torch.Tensor
            The input edge feature of shape :math:`(E, F_{in})`,
            where :math:`E` is the number of edges and
            :math:`F_{in}` is the size of the input edge feature.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        pair of torch.Tensor
            Node output features followed by edge output features.
            The node output feature is of shape :math:`(N, H, D_{out})`;
            the edge output feature is of shape :math:`(E, H, F_{out})`,
            where :math:`H` is the number of heads,
            :math:`D_{out}` is the size of the output node feature and
            :math:`F_{out}` is the size of the output edge feature.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`.
            This is returned only when :attr:`get_attention` is ``True``.
        """
        with graph.local_scope():
            if (graph.in_degrees() == 0).any():
                raise DGLError('There are 0-in-degree nodes in the graph, '
                               'output for those nodes will be invalid. '
                               'This is harmful for some applications, '
                               'causing silent performance regression. '
                               'Adding self-loop on the input graph by '
                               'calling `g = dgl.add_self_loop(g)` will resolve '
                               'the issue.')

            # TODO allow node src and dst feats
            # calc edge attention
            # same trick as in dgl.nn.pytorch.GATConv, but also includes edge feats
            # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
            f_ni = self.fc_ni(nfeats)
            f_nj = self.fc_nj(nfeats)
            f_fij = self.fc_fij(efeats)
            graph.srcdata.update({'f_ni': f_ni})
            graph.dstdata.update({'f_nj': f_nj})
            # add ni, nj factors
            graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
            # add fij to node factor
            f_out = graph.edata.pop('f_tmp') + f_fij
            if self.bias is not None:
                f_out = f_out + self.bias
            f_out = nn.functional.leaky_relu(f_out)
            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
            # compute attention factor
            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
            graph.edata['a'] = edge_softmax(graph, e)
            graph.ndata['h_out'] = self.fc_node(nfeats).view(-1, self._num_heads,
                                                             self._out_node_feats)
            # calc weighted sum
            graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
                             fn.sum('m', 'h_out'))

            h_out = graph.ndata['h_out'].view(-1, self._num_heads, self._out_node_feats)
            if get_attention:
                return h_out, f_out, graph.edata.pop('a')
            else:
                return h_out, f_out
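
The "same trick as in dgl.nn.pytorch.GATConv" comment refers to splitting the single weight matrix A applied to the concatenation [h_i || f_ij || h_j] into three separate linear maps (fc_ni, fc_fij, fc_nj), so no per-edge concatenated tensor is ever materialized. A minimal sketch of the equivalence (not part of the commit; sizes are illustrative assumptions):

import torch as th

in_node, in_edge, out_edge = 20, 12, 10
# one big weight matrix over the concatenation, then split column-wise
A = th.randn(out_edge, 2 * in_node + in_edge)
A_ni, A_fij, A_nj = A.split([in_node, in_edge, in_node], dim=1)

h_i, h_j, f_ij = th.randn(in_node), th.randn(in_node), th.randn(in_edge)

concat = A @ th.cat([h_i, f_ij, h_j])           # A [h_i || f_ij || h_j]
split = A_ni @ h_i + A_fij @ f_ij + A_nj @ h_j  # what fc_ni/fc_fij/fc_nj compute
assert th.allclose(concat, split, atol=1e-5)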
23 changes: 23 additions & 0 deletions tests/pytorch/test_nn.py
@@ -564,6 +564,28 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
_, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=10,
                       in_edge_feats=5,
                       out_node_feats=out_node_feats,
                       out_edge_feats=out_edge_feats,
                       num_heads=num_heads)
    nfeat = F.randn((g.number_of_nodes(), 10))
    efeat = F.randn((g.number_of_edges(), 5))

    egat = egat.to(ctx)
    h, f = egat(g, nfeat, efeat)
    h, f, attn = egat(g, nfeat, efeat, True)

    th.save(egat, tmp_buffer)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@@ -1137,6 +1159,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
    test_rgcn_sorted()
    test_tagconv()
    test_gat_conv()
    test_egat_conv()
    test_sage_conv()
    test_sgc_conv()
    test_appnp_conv()
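
For reference, the shape contract that test_egat_conv above exercises, as a standalone hedged sketch in plain torch (not part of the committed test; feature sizes are illustrative):

import dgl
import torch as th
from dgl.nn import EGATConv

g = dgl.add_self_loop(dgl.rand_graph(8, 30))  # self-loops avoid the 0-in-degree DGLError
n, e = g.num_nodes(), g.num_edges()
egat = EGATConv(in_node_feats=10, in_edge_feats=5,
                out_node_feats=4, out_edge_feats=6, num_heads=3)
h, f, attn = egat(g, th.randn(n, 10), th.randn(e, 5), get_attention=True)
assert h.shape == (n, 3, 4)     # (N, H, out_node_feats)
assert f.shape == (e, 3, 6)     # (E, H, out_edge_feats)
assert attn.shape == (e, 3, 1)  # (E, H, 1)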
