diff --git a/README.md b/README.md
index 22611be014fb..242a055fff34 100644
--- a/README.md
+++ b/README.md
@@ -288,6 +288,8 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba
1. [**GNNLens: A Visual Analytics Approach for Prediction Error Diagnosis of Graph Neural Networks**](https://arxiv.org/abs/2011.11048v5), *Zhihua Jin, Yong Wang, Qianwen Wang, Yao Ming, Tengfei Ma, Huamin Qu*
+1. [**How Attentive are Graph Attention Networks?**](https://arxiv.org/pdf/2105.14491.pdf), *Shaked Brody, Uri Alon, Eran Yahav*, [code](https://github.com/tech-srl/how_attentive_are_gats)
+
## Contributing
diff --git a/docs/source/api/python/nn.pytorch.rst b/docs/source/api/python/nn.pytorch.rst
index a929f2a82f53..ff840448208d 100644
--- a/docs/source/api/python/nn.pytorch.rst
+++ b/docs/source/api/python/nn.pytorch.rst
@@ -45,7 +45,13 @@ GATConv
:members: forward
:show-inheritance:
-
+GATv2Conv
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: dgl.nn.pytorch.conv.GATv2Conv
+ :members: forward
+ :show-inheritance:
+
EGATConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/examples/README.md b/examples/README.md
index 576151cf543a..1e0fe5607c9f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -23,6 +23,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- Guo et al. PCT: Point cloud transformer. [Paper link](http://arxiv.org/abs/2012.09688).
- Example code: [PyTorch](../examples/pytorch/pointcloud/pct)
- Tags: point cloud classification, point cloud part-segmentation
+- Brody et al. How Attentive are Graph Attention Networks? [Paper link](https://arxiv.org/abs/2105.14491).
+ - Example code: [PyTorch](../examples/pytorch/gatv2)
+ - Tags: graph attention, gat, gatv2, attention
## 2020
- Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
diff --git a/examples/pytorch/gatv2/README.md b/examples/pytorch/gatv2/README.md
new file mode 100644
index 000000000000..6b1687a4ead2
--- /dev/null
+++ b/examples/pytorch/gatv2/README.md
@@ -0,0 +1,40 @@
+Graph Attention Networks v2 (GATv2)
+============
+
+- Paper link: [How Attentive are Graph Attention Networks?](https://arxiv.org/pdf/2105.14491.pdf)
+- Author's code repo: [https://github.com/tech-srl/how_attentive_are_gats](https://github.com/tech-srl/how_attentive_are_gats).
+- Annotated implementation: [https://nn.labml.ai/graphs/gatv2/index.html](https://nn.labml.ai/graphs/gatv2/index.html)
+
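+The `GATv2Conv` layer used by this example can also be called directly; below is a minimal
+sketch (the toy graph, feature sizes, and head count are purely illustrative):
+
+```python
+import dgl
+import torch as th
+from dgl.nn import GATv2Conv
+
+# toy homogeneous graph; self-loops ensure every node has at least one in-edge
+g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
+feat = th.ones(3, 10)                 # 3 nodes, 10 input features
+conv = GATv2Conv(10, 4, num_heads=2)  # project 10 -> 4 features per head, 2 heads
+out = conv(g, feat)                   # output shape: (3, 2, 4)
+```
+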
+Dependencies
+------------
+- torch
+- requests
+- sklearn
+
+How to run
+----------
+
+Run with one of the following commands:
+
+```bash
+python3 train.py --dataset=cora
+```
+
+```bash
+python3 train.py --dataset=citeseer
+```
+
+```bash
+python3 train.py --dataset=pubmed
+```
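+
+Any of the optional flags defined in `train.py` (see `python3 train.py --help`) can be
+appended; for example, a run on GPU 0 with early stopping would look like:
+
+```bash
+python3 train.py --dataset=cora --gpu=0 --early-stop
+```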
+
+Results
+-------
+
+| Dataset  | Test Accuracy |
+| -------- | ------------- |
+| Cora     | 82.10         |
+| Citeseer | 70.00         |
+| Pubmed   | 77.20         |
+
+* All the accuracy numbers are obtained after 200 epochs.
\ No newline at end of file
diff --git a/examples/pytorch/gatv2/gatv2.py b/examples/pytorch/gatv2/gatv2.py
new file mode 100644
index 000000000000..ed486d0a8d26
--- /dev/null
+++ b/examples/pytorch/gatv2/gatv2.py
@@ -0,0 +1,51 @@
+"""
+Graph Attention Networks v2 (GATv2) in DGL, using SPMV optimization.
+References
+----------
+Paper: https://arxiv.org/pdf/2105.14491.pdf
+Author's code: https://github.com/tech-srl/how_attentive_are_gats
+"""
+
+import torch
+import torch.nn as nn
+from dgl.nn import GATv2Conv
+
+
+class GATv2(nn.Module):
+ def __init__(self,
+ num_layers,
+ in_dim,
+ num_hidden,
+ num_classes,
+ heads,
+ activation,
+ feat_drop,
+ attn_drop,
+ negative_slope,
+ residual):
+ super(GATv2, self).__init__()
+ self.num_layers = num_layers
+ self.gatv2_layers = nn.ModuleList()
+ self.activation = activation
+ # input projection (no residual)
+ self.gatv2_layers.append(GATv2Conv(
+ in_dim, num_hidden, heads[0],
+ feat_drop, attn_drop, negative_slope, False, self.activation, bias=False, share_weights=True))
+ # hidden layers
+ for l in range(1, num_layers):
+ # due to multi-head, the in_dim = num_hidden * num_heads
+ self.gatv2_layers.append(GATv2Conv(
+ num_hidden * heads[l-1], num_hidden, heads[l],
+ feat_drop, attn_drop, negative_slope, residual, self.activation, bias=False, share_weights=True))
+ # output projection
+ self.gatv2_layers.append(GATv2Conv(
+ num_hidden * heads[-2], num_classes, heads[-1],
+ feat_drop, attn_drop, negative_slope, residual, None, bias=False, share_weights=True))
+
+ def forward(self, g, inputs):
+ h = inputs
+ for l in range(self.num_layers):
+            h = self.gatv2_layers[l](g, h).flatten(1)
+        # output projection
+        logits = self.gatv2_layers[-1](g, h).mean(1)
+ return logits
diff --git a/examples/pytorch/gatv2/train.py b/examples/pytorch/gatv2/train.py
new file mode 100644
index 000000000000..b52406018393
--- /dev/null
+++ b/examples/pytorch/gatv2/train.py
@@ -0,0 +1,198 @@
+"""
+Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
+Multiple heads are also batched together for faster training.
+"""
+
+import argparse
+import numpy as np
+import time
+import torch
+import torch.nn.functional as F
+import dgl
+from dgl.data import register_data_args
+from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
+
+from gatv2 import GATv2
+
+
+class EarlyStopping:
+ def __init__(self, patience=10):
+ self.patience = patience
+ self.counter = 0
+ self.best_score = None
+ self.early_stop = False
+
+ def step(self, acc, model):
+ score = acc
+ if self.best_score is None:
+ self.best_score = score
+ self.save_checkpoint(model)
+ elif score < self.best_score:
+ self.counter += 1
+ print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+ if self.counter >= self.patience:
+ self.early_stop = True
+ else:
+ self.best_score = score
+ self.save_checkpoint(model)
+ self.counter = 0
+ return self.early_stop
+
+ def save_checkpoint(self, model):
+        '''Save the model when the validation score improves.'''
+ torch.save(model.state_dict(), 'es_checkpoint.pt')
+
+def accuracy(logits, labels):
+ _, indices = torch.max(logits, dim=1)
+ correct = torch.sum(indices == labels)
+ return correct.item() * 1.0 / len(labels)
+
+
+def evaluate(model, g, features, labels, mask):
+ model.eval()
+ with torch.no_grad():
+ logits = model(g, features)
+ logits = logits[mask]
+ labels = labels[mask]
+ return accuracy(logits, labels)
+
+
+def main(args):
+ # load and preprocess dataset
+ if args.dataset == 'cora':
+ data = CoraGraphDataset()
+ elif args.dataset == 'citeseer':
+ data = CiteseerGraphDataset()
+ elif args.dataset == 'pubmed':
+ data = PubmedGraphDataset()
+ else:
+ raise ValueError('Unknown dataset: {}'.format(args.dataset))
+
+ g = data[0]
+ if args.gpu < 0:
+ cuda = False
+ else:
+ cuda = True
+ g = g.int().to(args.gpu)
+
+ features = g.ndata['feat']
+ labels = g.ndata['label']
+ train_mask = g.ndata['train_mask']
+ val_mask = g.ndata['val_mask']
+ test_mask = g.ndata['test_mask']
+ num_feats = features.shape[1]
+ n_classes = data.num_labels
+    n_edges = g.number_of_edges()
+    print("""----Data statistics------
+ #Edges %d
+ #Classes %d
+ #Train samples %d
+ #Val samples %d
+ #Test samples %d""" %
+ (n_edges, n_classes,
+ train_mask.int().sum().item(),
+ val_mask.int().sum().item(),
+ test_mask.int().sum().item()))
+
+ # add self loop
+ g = dgl.remove_self_loop(g)
+ g = dgl.add_self_loop(g)
+ n_edges = g.number_of_edges()
+ # create model
+ heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
+ model = GATv2(args.num_layers,
+ num_feats,
+ args.num_hidden,
+ n_classes,
+ heads,
+ F.elu,
+ args.in_drop,
+ args.attn_drop,
+ args.negative_slope,
+ args.residual)
+ print(model)
+ if args.early_stop:
+ stopper = EarlyStopping(patience=100)
+ if cuda:
+ model.cuda()
+ loss_fcn = torch.nn.CrossEntropyLoss()
+
+ # use optimizer
+ optimizer = torch.optim.Adam(
+ model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+ # initialize graph
+ dur = []
+ for epoch in range(args.epochs):
+ model.train()
+ if epoch >= 3:
+ t0 = time.time()
+ # forward
+ logits = model(g, features)
+ loss = loss_fcn(logits[train_mask], labels[train_mask])
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ if epoch >= 3:
+ dur.append(time.time() - t0)
+
+ train_acc = accuracy(logits[train_mask], labels[train_mask])
+
+ if args.fastmode:
+ val_acc = accuracy(logits[val_mask], labels[val_mask])
+ else:
+            val_acc = evaluate(model, g, features, labels, val_mask)
+ if args.early_stop:
+ if stopper.step(val_acc, model):
+ break
+
+ print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
+ " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
+ format(epoch, np.mean(dur), loss.item(), train_acc,
+ val_acc, n_edges / np.mean(dur) / 1000))
+
+ print()
+ if args.early_stop:
+ model.load_state_dict(torch.load('es_checkpoint.pt'))
+    acc = evaluate(model, g, features, labels, test_mask)
+ print("Test Accuracy {:.4f}".format(acc))
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description='GATv2')
+ register_data_args(parser)
+ parser.add_argument("--gpu", type=int, default=-1,
+ help="which GPU to use. Set -1 to use CPU.")
+ parser.add_argument("--epochs", type=int, default=200,
+ help="number of training epochs")
+ parser.add_argument("--num-heads", type=int, default=8,
+ help="number of hidden attention heads")
+ parser.add_argument("--num-out-heads", type=int, default=1,
+ help="number of output attention heads")
+ parser.add_argument("--num-layers", type=int, default=1,
+ help="number of hidden layers")
+ parser.add_argument("--num-hidden", type=int, default=8,
+ help="number of hidden units")
+ parser.add_argument("--residual", action="store_true", default=False,
+ help="use residual connection")
+ parser.add_argument("--in-drop", type=float, default=.7,
+ help="input feature dropout")
+ parser.add_argument("--attn-drop", type=float, default=.7,
+ help="attention dropout")
+ parser.add_argument("--lr", type=float, default=0.005,
+ help="learning rate")
+ parser.add_argument('--weight-decay', type=float, default=5e-4,
+ help="weight decay")
+ parser.add_argument('--negative-slope', type=float, default=0.2,
+ help="the negative slope of leaky relu")
+ parser.add_argument('--early-stop', action='store_true', default=False,
+ help="indicates whether to use early stop or not")
+ parser.add_argument('--fastmode', action="store_true", default=False,
+                        help="skip re-evaluation of the validation set")
+ args = parser.parse_args()
+ print(args)
+
+ main(args)
diff --git a/python/dgl/nn/pytorch/conv/__init__.py b/python/dgl/nn/pytorch/conv/__init__.py
index 65fbe7913903..fe3599fdada8 100644
--- a/python/dgl/nn/pytorch/conv/__init__.py
+++ b/python/dgl/nn/pytorch/conv/__init__.py
@@ -6,6 +6,7 @@
from .chebconv import ChebConv
from .edgeconv import EdgeConv
from .gatconv import GATConv
+from .gatv2conv import GATv2Conv
from .egatconv import EGATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
@@ -25,8 +26,8 @@
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
- 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
- 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
+ 'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
+ 'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
diff --git a/python/dgl/nn/pytorch/conv/gatv2conv.py b/python/dgl/nn/pytorch/conv/gatv2conv.py
new file mode 100644
index 000000000000..7ae71f8b4fbb
--- /dev/null
+++ b/python/dgl/nn/pytorch/conv/gatv2conv.py
@@ -0,0 +1,312 @@
+"""Torch modules for graph attention networks v2 (GATv2)."""
+# pylint: disable= no-member, arguments-differ, invalid-name
+import torch as th
+from torch import nn
+
+from .... import function as fn
+from ...functional import edge_softmax
+from ....base import DGLError
+from ..utils import Identity
+from ....utils import expand_as_pair
+
+# pylint: enable=W0235
+class GATv2Conv(nn.Module):
+ r"""
+
+ Description
+ -----------
+ Apply GATv2 from
+    `How Attentive are Graph Attention Networks? <https://arxiv.org/pdf/2105.14491.pdf>`__
+ over an input signal.
+
+ .. math::
+        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij} W^{(l)}_{right} h_j^{(l)}
+
+    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
+ node :math:`j`:
+
+ .. math::
+ \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})
+
+        e_{ij}^{l} &= \vec{a}^T \mathrm{LeakyReLU}\left(
+            W^{(l)}_{left} h_{i} + W^{(l)}_{right} h_{j}\right)
+
+ Parameters
+ ----------
+ in_feats : int, or pair of ints
+ Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
+ If the layer is to be applied to a unidirectional bipartite graph, `in_feats`
+ specifies the input feature size on both the source and destination nodes.
+ If a scalar is given, the source and destination node feature size
+ would take the same value.
+ out_feats : int
+ Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
+ num_heads : int
+ Number of heads in Multi-Head Attention.
+ feat_drop : float, optional
+ Dropout rate on feature. Defaults: ``0``.
+ attn_drop : float, optional
+ Dropout rate on attention weight. Defaults: ``0``.
+ negative_slope : float, optional
+ LeakyReLU angle of negative slope. Defaults: ``0.2``.
+ residual : bool, optional
+ If True, use residual connection. Defaults: ``False``.
+ activation : callable activation function/layer or None, optional.
+ If not None, applies an activation function to the updated node features.
+ Default: ``None``.
+ allow_zero_in_degree : bool, optional
+ If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
+ since no message will be passed to those nodes. This is harmful for some applications
+ causing silent performance regression. This module will raise a DGLError if it detects
+ 0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
+ and let the users handle it by themselves. Defaults: ``False``.
+ bias : bool, optional
+ If set to :obj:`False`, the layer will not learn
+ an additive bias. (default: :obj:`True`)
+ share_weights : bool, optional
+        If set to :obj:`True`, the same matrix is used for :math:`W_{left}` and
+        :math:`W_{right}` in the above equations, i.e. the source and the destination
+        node of every edge share the projection weights. (default: :obj:`False`)
+
+ Note
+ ----
+ Zero in-degree nodes will lead to invalid output value. This is because no message
+ will be passed to those nodes, the aggregation function will be applied on empty input.
+ A common practice to avoid this is to add a self-loop for each node in the graph if
+ it is homogeneous, which can be achieved by:
+
+ >>> g = ... # a DGLGraph
+ >>> g = dgl.add_self_loop(g)
+
+    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous
+    graphs, since the edge type can not be decided for self-loop edges. Set
+    ``allow_zero_in_degree`` to ``True`` for those cases to unblock the code and handle
+    zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree
+    nodes when using the output of this layer.
+
+ Examples
+ --------
+ >>> import dgl
+ >>> import numpy as np
+ >>> import torch as th
+ >>> from dgl.nn import GATv2Conv
+
+ >>> # Case 1: Homogeneous graph
+ >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
+ >>> g = dgl.add_self_loop(g)
+ >>> feat = th.ones(6, 10)
+ >>> gatv2conv = GATv2Conv(10, 2, num_heads=3)
+ >>> res = gatv2conv(g, feat)
+ >>> res
+ tensor([[[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]],
+ [[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]],
+ [[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]],
+ [[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]],
+ [[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]],
+ [[ 1.9599, 1.0239],
+ [ 3.2015, -0.5512],
+ [ 2.3700, -2.2182]]], grad_fn=)
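+
+    >>> # The per-edge attention scores can also be returned; a small sketch reusing the
+    >>> # graph and module above. With 6 edges plus 6 self-loops and 3 heads, the
+    >>> # attention tensor has shape (num_edges, num_heads, 1) = (12, 3, 1).
+    >>> res, attn = gatv2conv(g, feat, get_attention=True)
+    >>> attn.shape
+    torch.Size([12, 3, 1])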
+
+ >>> # Case 2: Unidirectional bipartite graph
+ >>> u = [0, 1, 0, 0, 1]
+ >>> v = [0, 1, 2, 3, 2]
+ >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
+ >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
+ >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
+ >>> gatv2conv = GATv2Conv((5,10), 2, 3)
+ >>> res = gatv2conv(g, (u_feat, v_feat))
+ >>> res
+ tensor([[[-0.0935, -0.4273],
+ [-1.1850, 0.1123],
+ [-0.2002, 0.1155]],
+ [[ 0.1908, -1.2095],
+ [-0.0129, 0.6408],
+ [-0.8135, 0.1157]],
+ [[ 0.0596, -0.8487],
+ [-0.5421, 0.4022],
+ [-0.4805, 0.1156]],
+ [[-0.0935, -0.4273],
+ [-1.1850, 0.1123],
+ [-0.2002, 0.1155]]], grad_fn=)
+ """
+ def __init__(self,
+ in_feats,
+ out_feats,
+ num_heads,
+ feat_drop=0.,
+ attn_drop=0.,
+ negative_slope=0.2,
+ residual=False,
+ activation=None,
+ allow_zero_in_degree=False,
+ bias=True,
+ share_weights=False):
+ super(GATv2Conv, self).__init__()
+ self._num_heads = num_heads
+ self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
+ self._out_feats = out_feats
+ self._allow_zero_in_degree = allow_zero_in_degree
+ if isinstance(in_feats, tuple):
+ self.fc_src = nn.Linear(
+ self._in_src_feats, out_feats * num_heads, bias=bias)
+ self.fc_dst = nn.Linear(
+ self._in_dst_feats, out_feats * num_heads, bias=bias)
+ else:
+ self.fc_src = nn.Linear(
+ self._in_src_feats, out_feats * num_heads, bias=bias)
+ if share_weights:
+ self.fc_dst = self.fc_src
+ else:
+ self.fc_dst = nn.Linear(
+ self._in_src_feats, out_feats * num_heads, bias=bias)
+ self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
+ self.feat_drop = nn.Dropout(feat_drop)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.leaky_relu = nn.LeakyReLU(negative_slope)
+ if residual:
+ if self._in_dst_feats != out_feats:
+ self.res_fc = nn.Linear(
+ self._in_dst_feats, num_heads * out_feats, bias=bias)
+ else:
+ self.res_fc = Identity()
+ else:
+ self.register_buffer('res_fc', None)
+ self.activation = activation
+ self.share_weights = share_weights
+ self.bias = bias
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ """
+ Description
+ -----------
+ Reinitialize learnable parameters.
+
+ Note
+ ----
+        The fc weights :math:`W^{(l)}` and the attention weights are initialized
+        using Xavier normal (Glorot) initialization.
+ """
+ gain = nn.init.calculate_gain('relu')
+ nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
+ if self.bias:
+ nn.init.constant_(self.fc_src.bias, 0)
+ if not self.share_weights:
+ nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
+ if self.bias:
+ nn.init.constant_(self.fc_dst.bias, 0)
+ nn.init.xavier_normal_(self.attn, gain=gain)
+ if isinstance(self.res_fc, nn.Linear):
+ nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
+ if self.bias:
+ nn.init.constant_(self.res_fc.bias, 0)
+
+ def set_allow_zero_in_degree(self, set_value):
+ r"""
+ Description
+ -----------
+ Set allow_zero_in_degree flag.
+
+ Parameters
+ ----------
+ set_value : bool
+ The value to be set to the flag.
+ """
+ self._allow_zero_in_degree = set_value
+
+ def forward(self, graph, feat, get_attention=False):
+ r"""
+ Description
+ -----------
+ Compute graph attention network layer.
+
+ Parameters
+ ----------
+ graph : DGLGraph
+ The graph.
+ feat : torch.Tensor or pair of torch.Tensor
+ If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
+ :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
+ If a pair of torch.Tensor is given, the pair must contain two tensors of shape
+ :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
+ get_attention : bool, optional
+ Whether to return the attention values. Default to False.
+
+ Returns
+ -------
+ torch.Tensor
+ The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
+ is the number of heads, and :math:`D_{out}` is size of output feature.
+ torch.Tensor, optional
+ The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
+ edges. This is returned only when :attr:`get_attention` is ``True``.
+
+ Raises
+ ------
+ DGLError
+ If there are 0-in-degree nodes in the input graph, it will raise DGLError
+ since no message will be passed to those nodes. This will cause invalid output.
+ The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
+ """
+ with graph.local_scope():
+ if not self._allow_zero_in_degree:
+ if (graph.in_degrees() == 0).any():
+ raise DGLError('There are 0-in-degree nodes in the graph, '
+ 'output for those nodes will be invalid. '
+ 'This is harmful for some applications, '
+ 'causing silent performance regression. '
+ 'Adding self-loop on the input graph by '
+ 'calling `g = dgl.add_self_loop(g)` will resolve '
+ 'the issue. Setting ``allow_zero_in_degree`` '
+ 'to be `True` when constructing this module will '
+ 'suppress the check and let the code run.')
+
+ if isinstance(feat, tuple):
+ h_src = self.feat_drop(feat[0])
+ h_dst = self.feat_drop(feat[1])
+ feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
+ feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+ else:
+ h_src = h_dst = self.feat_drop(feat)
+ feat_src = self.fc_src(h_src).view(
+ -1, self._num_heads, self._out_feats)
+ if self.share_weights:
+ feat_dst = feat_src
+ else:
+ feat_dst = self.fc_dst(h_src).view(
+ -1, self._num_heads, self._out_feats)
+ if graph.is_block:
+ feat_dst = feat_src[:graph.number_of_dst_nodes()]
+            graph.srcdata.update({'el': feat_src})  # (num_src_nodes, num_heads, out_feats)
+            graph.dstdata.update({'er': feat_dst})  # (num_dst_nodes, num_heads, out_feats)
+            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
+            e = self.leaky_relu(graph.edata.pop('e'))  # (num_edges, num_heads, out_feats)
+            e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)  # (num_edges, num_heads, 1)
+            # compute softmax over the incoming edges of each destination node
+            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))  # (num_edges, num_heads, 1)
+ # message passing
+ graph.update_all(fn.u_mul_e('el', 'a', 'm'),
+ fn.sum('m', 'ft'))
+ rst = graph.dstdata['ft']
+ # residual
+ if self.res_fc is not None:
+ resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
+ rst = rst + resval
+ # activation
+ if self.activation:
+ rst = self.activation(rst)
+
+ if get_attention:
+ return rst, graph.edata['a']
+ else:
+ return rst
diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py
index 9be94b1aa874..a2309f24f688 100644
--- a/tests/pytorch/test_nn.py
+++ b/tests/pytorch/test_nn.py
@@ -564,6 +564,45 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
_, a = gat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), num_heads, 1)
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_dim', [1, 5])
+@pytest.mark.parametrize('num_heads', [1, 4])
+def test_gatv2_conv(g, idtype, out_dim, num_heads):
+ g = g.astype(idtype).to(F.ctx())
+ ctx = F.ctx()
+ gat = nn.GATv2Conv(5, out_dim, num_heads)
+ feat = F.randn((g.number_of_src_nodes(), 5))
+ gat = gat.to(ctx)
+ h = gat(g, feat)
+
+ # test pickle
+ th.save(gat, tmp_buffer)
+
+ assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
+ _, a = gat(g, feat, get_attention=True)
+ assert a.shape == (g.number_of_edges(), num_heads, 1)
+
+ # test residual connection
+    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
+ gat = gat.to(ctx)
+ h = gat(g, feat)
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_dim', [1, 2])
+@pytest.mark.parametrize('num_heads', [1, 4])
+def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
+ g = g.astype(idtype).to(F.ctx())
+ ctx = F.ctx()
+ gat = nn.GATv2Conv(5, out_dim, num_heads)
+ feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
+ gat = gat.to(ctx)
+ h = gat(g, feat)
+ assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
+ _, a = gat(g, feat, get_attention=True)
+ assert a.shape == (g.number_of_edges(), num_heads, 1)
+
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@@ -1159,6 +1198,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
test_rgcn_sorted()
test_tagconv()
test_gat_conv()
+ test_gatv2_conv()
test_egat_conv()
test_sage_conv()
test_sgc_conv()