From 51c65097049026e482642f005decce1e04318537 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kamil=20Kami=C5=84ski?= <63782865+Argusmocny@users.noreply.github.com>
Date: Wed, 27 Oct 2021 18:38:08 +0200
Subject: [PATCH] [NN] Add EGATConv nn.module (#3425)

* added nn pytorch egatconv

* aligned with test build

* aligned with test build

* fixed wihite spaces

* fixed wihite spaces

* fixed wihite spaces

* added missing egatconv in imports

* added indentation in forward

* GATConv based implementation

* removed **kw_args

* added dgl relative imports

* PR corrections

* added DGL Error to EGATConv imports

* Update test_nn.py

Co-authored-by: Argusmocny
Co-authored-by: Mufei Li
---
 README.md                              |   2 +-
 docs/source/api/python/nn.pytorch.rst  |   8 ++
 python/dgl/nn/pytorch/conv/__init__.py |   5 +-
 python/dgl/nn/pytorch/conv/egatconv.py | 182 +++++++++++++++++++++++++
 tests/pytorch/test_nn.py               |  23 ++++
 5 files changed, 217 insertions(+), 3 deletions(-)
 create mode 100644 python/dgl/nn/pytorch/conv/egatconv.py

diff --git a/README.md b/README.md
index 072960ff4b7d..22611be014fb 100644
--- a/README.md
+++ b/README.md
@@ -270,7 +270,7 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba
 
 1. [**Covid-19 Detection from Chest X-ray and Patient Metadata using Graph Convolutional Neural Networks**](https://arxiv.org/abs/2105.09720), *Thosini Bamunu Mudiyanselage, Nipuna Senanayake, Chunyan Ji, Yi Pan, Yanqing Zhang*
 
-1. [**Graph neural networks and sequence embeddings enable the prediction and design of the cofactor specificity of Rossmann fold proteins**](https://www.biorxiv.org/content/10.1101/2021.05.05.440912v2), bioRxiv'21, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*
+1. [**Rossmann-toolbox: a deep learning-based protocol for the prediction and design of cofactor specificity in Rossmann fold proteins**](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059), Briefings in Bioinformatics, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*
 
 1. [**LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations**](https://arxiv.org/pdf/2106.01093.pdf), ACL'21, *Ruisheng Cao, Lu Chen, Zhi Chen, Yanbin Zhao, Su Zhu, Kai Yu*
 
diff --git a/docs/source/api/python/nn.pytorch.rst b/docs/source/api/python/nn.pytorch.rst
index 97832f5a717c..a929f2a82f53 100644
--- a/docs/source/api/python/nn.pytorch.rst
+++ b/docs/source/api/python/nn.pytorch.rst
@@ -44,6 +44,14 @@ GATConv
 .. autoclass:: dgl.nn.pytorch.conv.GATConv
     :members: forward
     :show-inheritance:
+
+
+EGATConv
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: dgl.nn.pytorch.conv.EGATConv
+    :members: forward
+    :show-inheritance:
 
 EdgeConv
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/python/dgl/nn/pytorch/conv/__init__.py b/python/dgl/nn/pytorch/conv/__init__.py
index 66e573fc7f68..65fbe7913903 100644
--- a/python/dgl/nn/pytorch/conv/__init__.py
+++ b/python/dgl/nn/pytorch/conv/__init__.py
@@ -6,6 +6,7 @@
 from .chebconv import ChebConv
 from .edgeconv import EdgeConv
 from .gatconv import GATConv
+from .egatconv import EGATConv
 from .ginconv import GINConv
 from .gmmconv import GMMConv
 from .graphconv import GraphConv, EdgeWeightNorm
@@ -24,8 +25,8 @@
 from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
 from .gcn2conv import GCN2Conv
 
-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
-           'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
+           'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
            'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
            'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv',
            'TWIRLSConv', 'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
diff --git a/python/dgl/nn/pytorch/conv/egatconv.py b/python/dgl/nn/pytorch/conv/egatconv.py
new file mode 100644
index 000000000000..4c0bc3b94e2c
--- /dev/null
+++ b/python/dgl/nn/pytorch/conv/egatconv.py
@@ -0,0 +1,182 @@
+"""Torch modules for graph attention networks with fully valuable edges (EGAT)."""
+# pylint: disable= no-member, arguments-differ, invalid-name
+import torch as th
+from torch import nn
+from torch.nn import init
+
+from .... import function as fn
+from ...functional import edge_softmax
+from ....base import DGLError
+
+# pylint: enable=W0235
+class EGATConv(nn.Module):
+    r"""
+    Description
+    -----------
+    Apply a Graph Attention Layer over an input graph. EGAT is an extension of the
+    regular `Graph Attention Network <https://arxiv.org/abs/1710.10903>`__ that also
+    handles edge features; a detailed description is available in `Rossmann-Toolbox
+    <https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059>`__
+    (see the supplementary data). The difference lies in how the unnormalized
+    attention scores :math:`e_{ij}` are obtained:
+
+    .. math::
+        e_{ij} &= \vec{F} (f_{ij}^{\prime})
+
+        f_{ij}^{\prime} &= \mathrm{LeakyReLU}\left(A [ h_{i} \| f_{ij} \| h_{j}]\right)
+
+    where :math:`f_{ij}^{\prime}` are the intermediate edge features, :math:`A` is a
+    weight matrix and :math:`\vec{F}` is a weight vector. After that, the resulting
+    node features :math:`h_{i}^{\prime}` are updated in the same way as in the
+    regular GAT.
+
+    Parameters
+    ----------
+    in_node_feats : int
+        Input node feature size :math:`h_{i}`.
+    in_edge_feats : int
+        Input edge feature size :math:`f_{ij}`.
+    out_node_feats : int
+        Output node feature size.
+    out_edge_feats : int
+        Output edge feature size :math:`f_{ij}^{\prime}`.
+    num_heads : int
+        Number of attention heads.
+    bias : bool, optional
+        If True, add a bias term to :math:`f_{ij}^{\prime}`. Default: ``True``.
+
+    Examples
+    ----------
+    >>> import dgl
+    >>> import torch as th
+    >>> from dgl.nn import EGATConv
+
+    >>> num_nodes, num_edges = 8, 30
+    >>> # generate a graph
+    >>> graph = dgl.rand_graph(num_nodes, num_edges)
+
+    >>> node_feats = th.rand((num_nodes, 20))
+    >>> edge_feats = th.rand((num_edges, 12))
+    >>> egat = EGATConv(in_node_feats=20,
+    ...                 in_edge_feats=12,
+    ...                 out_node_feats=15,
+    ...                 out_edge_feats=10,
+    ...                 num_heads=3)
+    >>> # forward pass
+    >>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
+    >>> new_node_feats.shape, new_edge_feats.shape
+    (torch.Size([8, 3, 15]), torch.Size([30, 3, 10]))
+    """
+
+    def __init__(self,
+                 in_node_feats,
+                 in_edge_feats,
+                 out_node_feats,
+                 out_edge_feats,
+                 num_heads,
+                 bias=True):
+
+        super().__init__()
+        self._num_heads = num_heads
+        self._out_node_feats = out_node_feats
+        self._out_edge_feats = out_edge_feats
+        self.fc_node = nn.Linear(in_node_feats, out_node_feats*num_heads, bias=True)
+        self.fc_ni = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
+        self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
+        self.fc_nj = nn.Linear(in_node_feats, out_edge_feats*num_heads, bias=False)
+        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
+        if bias:
+            self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
+        else:
+            self.register_buffer('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """
+        Reinitialize learnable parameters.
+        """
+        gain = init.calculate_gain('relu')
+        init.xavier_normal_(self.fc_node.weight, gain=gain)
+        init.xavier_normal_(self.fc_ni.weight, gain=gain)
+        init.xavier_normal_(self.fc_fij.weight, gain=gain)
+        init.xavier_normal_(self.fc_nj.weight, gain=gain)
+        init.xavier_normal_(self.attn, gain=gain)
+        if self.bias is not None:
+            init.constant_(self.bias, 0)
+
+    def forward(self, graph, nfeats, efeats, get_attention=False):
+        r"""
+        Compute new node and edge features.
+
+        Parameters
+        ----------
+        graph : DGLGraph
+            The graph.
+        nfeats : torch.Tensor
+            The input node feature of shape :math:`(N, D_{in})`,
+            where:
+                :math:`D_{in}` is the size of the input node feature,
+                :math:`N` is the number of nodes.
+        efeats : torch.Tensor
+            The input edge feature of shape :math:`(E, F_{in})`,
+            where:
+                :math:`F_{in}` is the size of the input edge feature,
+                :math:`E` is the number of edges.
+        get_attention : bool, optional
+            Whether to return the attention values. Defaults to ``False``.
+
+        Returns
+        -------
+        pair of torch.Tensor
+            Node output features followed by edge output features.
+            The node output feature has shape :math:`(N, H, D_{out})`,
+            the edge output feature has shape :math:`(E, H, F_{out})`,
+            where:
+                :math:`H` is the number of heads,
+                :math:`D_{out}` is the size of the output node feature,
+                :math:`F_{out}` is the size of the output edge feature.
+        torch.Tensor, optional
+            The attention values of shape :math:`(E, H, 1)`.
+            This is returned only when :attr:`get_attention` is ``True``.
+        """
+
+        with graph.local_scope():
+            if (graph.in_degrees() == 0).any():
+                raise DGLError('There are 0-in-degree nodes in the graph, '
+                               'output for those nodes will be invalid. '
+                               'This is harmful for some applications, '
+                               'causing silent performance regression. '
+                               'Adding self-loop on the input graph by '
+                               'calling `g = dgl.add_self_loop(g)` will resolve '
+                               'the issue.')
+
+            # TODO allow node src and dst feats
+            # calc edge attention
+            # same trick as in dgl.nn.pytorch.GATConv, but also includes edge feats
+            # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
+            f_ni = self.fc_ni(nfeats)
+            f_nj = self.fc_nj(nfeats)
+            f_fij = self.fc_fij(efeats)
+            graph.srcdata.update({'f_ni': f_ni})
+            graph.dstdata.update({'f_nj': f_nj})
+            # add ni, nj factors
+            graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
+            # add fij to node factor
+            f_out = graph.edata.pop('f_tmp') + f_fij
+            if self.bias is not None:
+                f_out = f_out + self.bias
+            f_out = nn.functional.leaky_relu(f_out)
+            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
+            # compute attention factor
+            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
+            graph.edata['a'] = edge_softmax(graph, e)
+            graph.ndata['h_out'] = self.fc_node(nfeats).view(-1, self._num_heads,
+                                                             self._out_node_feats)
+            # calc weighted sum
+            graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
+                             fn.sum('m', 'h_out'))
+
+            h_out = graph.ndata['h_out'].view(-1, self._num_heads, self._out_node_feats)
+            if get_attention:
+                return h_out, f_out, graph.edata.pop('a')
+            else:
+                return h_out, f_out
diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py
index 4ae4407a12d4..9be94b1aa874 100644
--- a/tests/pytorch/test_nn.py
+++ b/tests/pytorch/test_nn.py
@@ -564,6 +564,28 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)
 
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_node_feats', [1, 5])
+@pytest.mark.parametrize('out_edge_feats', [1, 5])
+@pytest.mark.parametrize('num_heads', [1, 4])
+def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
+    g = g.astype(idtype).to(F.ctx())
+    ctx = F.ctx()
+    egat = nn.EGATConv(in_node_feats=10,
+                       in_edge_feats=5,
+                       out_node_feats=out_node_feats,
+                       out_edge_feats=out_edge_feats,
+                       num_heads=num_heads)
+    nfeat = F.randn((g.number_of_nodes(), 10))
+    efeat = F.randn((g.number_of_edges(), 5))
+
+    egat = egat.to(ctx)
+    h, f = egat(g, nfeat, efeat)
+    h, f, attn = egat(g, nfeat, efeat, True)
+
+    th.save(egat, tmp_buffer)
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
 @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
@@ -1137,6 +1159,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
     test_rgcn_sorted()
     test_tagconv()
     test_gat_conv()
+    test_egat_conv()
    test_sage_conv()
     test_sgc_conv()
     test_appnp_conv()
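For reference, the snippet below is a minimal usage sketch of the new EGATConv module, assuming the API introduced in this patch. The toy graph sizes, the `dgl.add_self_loop` call (which avoids the zero-in-degree error raised in `forward`), and the final head-flattening step are illustrative choices and not part of the patch itself.

    import dgl
    import torch as th
    from dgl.nn import EGATConv

    # Toy graph: 8 nodes and 30 random edges, plus self-loops so that no node
    # has zero in-degree (EGATConv.forward raises a DGLError otherwise).
    graph = dgl.rand_graph(8, 30)
    graph = dgl.add_self_loop(graph)

    node_feats = th.rand(graph.number_of_nodes(), 20)
    edge_feats = th.rand(graph.number_of_edges(), 12)

    egat = EGATConv(in_node_feats=20, in_edge_feats=12,
                    out_node_feats=15, out_edge_feats=10, num_heads=3)

    # Returns per-head node features, per-head edge features and attention weights.
    new_node_feats, new_edge_feats, attn = egat(graph, node_feats, edge_feats,
                                                get_attention=True)
    print(new_node_feats.shape)   # (num_nodes, num_heads, out_node_feats)
    print(new_edge_feats.shape)   # (num_edges, num_heads, out_edge_feats)
    print(attn.shape)             # (num_edges, num_heads, 1)

    # One common way to consume the multi-head output: flatten the head dimension.
    h = new_node_feats.flatten(1)  # (num_nodes, num_heads * out_node_feats)
    f = new_edge_feats.flatten(1)  # (num_edges, num_heads * out_edge_feats)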