diff --git a/README.md b/README.md
index 22611be014fb..242a055fff34 100644
--- a/README.md
+++ b/README.md
@@ -288,6 +288,8 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba
 1. [**GNNLens: A Visual Analytics Approach for Prediction Error Diagnosis of Graph Neural Networks**](https://arxiv.org/abs/2011.11048v5), *Zhihua Jin, Yong Wang, Qianwen Wang, Yao Ming, Tengfei Ma, Huamin Qu*
 
+1. [**How Attentive are Graph Attention Networks?**](https://arxiv.org/pdf/2105.14491.pdf), *Shaked Brody, Uri Alon, Eran Yahav*, [code](https://github.com/tech-srl/how_attentive_are_gats)
+
 ## Contributing
 
diff --git a/docs/source/api/python/nn.pytorch.rst b/docs/source/api/python/nn.pytorch.rst
index a929f2a82f53..ff840448208d 100644
--- a/docs/source/api/python/nn.pytorch.rst
+++ b/docs/source/api/python/nn.pytorch.rst
@@ -45,7 +45,13 @@ GATConv
     :members: forward
     :show-inheritance:
 
-
+GATv2Conv
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: dgl.nn.pytorch.conv.GATv2Conv
+    :members: forward
+    :show-inheritance:
+
 EGATConv
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/examples/README.md b/examples/README.md
index 576151cf543a..1e0fe5607c9f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -23,6 +23,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
 - Guo et al. PCT: Point cloud transformer. [Paper link](http://arxiv.org/abs/2012.09688).
   - Example code: [PyTorch](../examples/pytorch/pointcloud/pct)
   - Tags: point cloud classification, point cloud part-segmentation
+- Brody et al. How Attentive are Graph Attention Networks? [Paper link](https://arxiv.org/abs/2105.14491).
+  - Example code: [PyTorch](../examples/pytorch/gatv2)
+  - Tags: graph attention, gat, gatv2, attention
 
 ## 2020
 - Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
diff --git a/examples/pytorch/gatv2/README.md b/examples/pytorch/gatv2/README.md
new file mode 100644
index 000000000000..6b1687a4ead2
--- /dev/null
+++ b/examples/pytorch/gatv2/README.md
@@ -0,0 +1,40 @@
+Graph Attention Networks v2 (GATv2)
+============
+
+- Paper link: [How Attentive are Graph Attention Networks?](https://arxiv.org/pdf/2105.14491.pdf)
+- Author's code repo: [https://github.com/tech-srl/how_attentive_are_gats](https://github.com/tech-srl/how_attentive_are_gats).
+- Annotated implementation: [https://nn.labml.ai/graphs/gatv2/index.html](https://nn.labml.ai/graphs/gatv2/index.html)
+
+Dependencies
+------------
+- torch
+- requests
+- sklearn
+
+How to run
+----------
+
+Run with the following:
+
+```bash
+python3 train.py --dataset=cora
+```
+
+```bash
+python3 train.py --dataset=citeseer
+```
+
+```bash
+python3 train.py --dataset=pubmed
+```
+
+Results
+-------
+
+| Dataset  | Test Accuracy |
+| -------- | ------------- |
+| Cora     | 82.10         |
+| Citeseer | 70.00         |
+| Pubmed   | 77.20         |
+
+* All the accuracy numbers are obtained after 200 epochs.
\ No newline at end of file
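The results above use DGL's built-in citation datasets. As a quick, hedged sketch of what `train.py` loads and preprocesses (illustrative only, not part of the example; the dataset is downloaded and cached on first use):

```python
import dgl
from dgl.data import CoraGraphDataset

data = CoraGraphDataset()                              # downloaded and cached on first use
g = dgl.add_self_loop(dgl.remove_self_loop(data[0]))   # same preprocessing as train.py
print(g.number_of_nodes(), g.number_of_edges())
print(int(g.ndata['train_mask'].sum()),
      int(g.ndata['val_mask'].sum()),
      int(g.ndata['test_mask'].sum()))
```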
diff --git a/examples/pytorch/gatv2/gatv2.py b/examples/pytorch/gatv2/gatv2.py
new file mode 100644
index 000000000000..ed486d0a8d26
--- /dev/null
+++ b/examples/pytorch/gatv2/gatv2.py
@@ -0,0 +1,51 @@
+"""
+Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
+References
+----------
+Paper: https://arxiv.org/pdf/2105.14491.pdf
+Author's code: https://github.com/tech-srl/how_attentive_are_gats
+"""
+
+import torch
+import torch.nn as nn
+from dgl.nn import GATv2Conv
+
+
+class GATv2(nn.Module):
+    def __init__(self,
+                 num_layers,
+                 in_dim,
+                 num_hidden,
+                 num_classes,
+                 heads,
+                 activation,
+                 feat_drop,
+                 attn_drop,
+                 negative_slope,
+                 residual):
+        super(GATv2, self).__init__()
+        self.num_layers = num_layers
+        self.gatv2_layers = nn.ModuleList()
+        self.activation = activation
+        # input projection (no residual)
+        self.gatv2_layers.append(GATv2Conv(
+            in_dim, num_hidden, heads[0],
+            feat_drop, attn_drop, negative_slope, False, self.activation, bias=False, share_weights=True))
+        # hidden layers
+        for l in range(1, num_layers):
+            # due to multi-head, the in_dim = num_hidden * num_heads
+            self.gatv2_layers.append(GATv2Conv(
+                num_hidden * heads[l-1], num_hidden, heads[l],
+                feat_drop, attn_drop, negative_slope, residual, self.activation, bias=False, share_weights=True))
+        # output projection
+        self.gatv2_layers.append(GATv2Conv(
+            num_hidden * heads[-2], num_classes, heads[-1],
+            feat_drop, attn_drop, negative_slope, residual, None, bias=False, share_weights=True))
+
+    def forward(self, g, inputs):
+        h = inputs
+        for l in range(self.num_layers):
+            h = self.gatv2_layers[l](g, h).flatten(1)
+        # output projection
+        logits = self.gatv2_layers[-1](g, h).mean(1)
+        return logits
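The `heads` argument of the wrapper above is a per-layer list: the first `num_layers` entries set the number of heads for the input/hidden layers and the last entry sets the output heads. A minimal instantiation sketch, assuming it is run next to `gatv2.py` (the feature and class sizes are illustrative, roughly Cora-sized):

```python
import torch.nn.functional as F
from gatv2 import GATv2

heads = [8] * 1 + [1]           # 1 hidden layer with 8 heads, single-head output layer
model = GATv2(num_layers=1,
              in_dim=1433,      # illustrative input feature size
              num_hidden=8,
              num_classes=7,    # illustrative number of classes
              heads=heads,
              activation=F.elu,
              feat_drop=0.7,
              attn_drop=0.7,
              negative_slope=0.2,
              residual=False)
# model(g, features) returns per-node logits of shape (num_nodes, num_classes)
```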
diff --git a/examples/pytorch/gatv2/train.py b/examples/pytorch/gatv2/train.py
new file mode 100644
index 000000000000..b52406018393
--- /dev/null
+++ b/examples/pytorch/gatv2/train.py
@@ -0,0 +1,198 @@
+"""
+Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
+Multiple heads are also batched together for faster training.
+"""
+
+import argparse
+import numpy as np
+import time
+import torch
+import torch.nn.functional as F
+import dgl
+from dgl.data import register_data_args
+from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
+
+from gatv2 import GATv2
+
+
+class EarlyStopping:
+    def __init__(self, patience=10):
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+
+    def step(self, acc, model):
+        score = acc
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(model)
+        elif score < self.best_score:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(model)
+            self.counter = 0
+        return self.early_stop
+
+    def save_checkpoint(self, model):
+        '''Saves the model when the validation score improves.'''
+        torch.save(model.state_dict(), 'es_checkpoint.pt')
+
+
+def accuracy(logits, labels):
+    _, indices = torch.max(logits, dim=1)
+    correct = torch.sum(indices == labels)
+    return correct.item() * 1.0 / len(labels)
+
+
+def evaluate(model, g, features, labels, mask):
+    model.eval()
+    with torch.no_grad():
+        logits = model(g, features)
+        logits = logits[mask]
+        labels = labels[mask]
+        return accuracy(logits, labels)
+
+
+def main(args):
+    # load and preprocess dataset
+    if args.dataset == 'cora':
+        data = CoraGraphDataset()
+    elif args.dataset == 'citeseer':
+        data = CiteseerGraphDataset()
+    elif args.dataset == 'pubmed':
+        data = PubmedGraphDataset()
+    else:
+        raise ValueError('Unknown dataset: {}'.format(args.dataset))
+
+    g = data[0]
+    if args.gpu < 0:
+        cuda = False
+    else:
+        cuda = True
+        g = g.int().to(args.gpu)
+
+    features = g.ndata['feat']
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
+    test_mask = g.ndata['test_mask']
+    num_feats = features.shape[1]
+    n_classes = data.num_labels
+    n_edges = g.number_of_edges()
+    print("""----Data statistics------
+      #Edges %d
+      #Classes %d
+      #Train samples %d
+      #Val samples %d
+      #Test samples %d""" %
+          (n_edges, n_classes,
+           train_mask.int().sum().item(),
+           val_mask.int().sum().item(),
+           test_mask.int().sum().item()))
+
+    # add self loop
+    g = dgl.remove_self_loop(g)
+    g = dgl.add_self_loop(g)
+    n_edges = g.number_of_edges()
+    # create model
+    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
+    model = GATv2(args.num_layers,
+                  num_feats,
+                  args.num_hidden,
+                  n_classes,
+                  heads,
+                  F.elu,
+                  args.in_drop,
+                  args.attn_drop,
+                  args.negative_slope,
+                  args.residual)
+    print(model)
+    if args.early_stop:
+        stopper = EarlyStopping(patience=100)
+    if cuda:
+        model.cuda()
+    loss_fcn = torch.nn.CrossEntropyLoss()
+
+    # use optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+    # initialize graph
+    dur = []
+    for epoch in range(args.epochs):
+        model.train()
+        if epoch >= 3:
+            t0 = time.time()
+        # forward
+        logits = model(g, features)
+        loss = loss_fcn(logits[train_mask], labels[train_mask])
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if epoch >= 3:
+            dur.append(time.time() - t0)
+
+        train_acc = accuracy(logits[train_mask], labels[train_mask])
+
+        if args.fastmode:
+            val_acc = accuracy(logits[val_mask], labels[val_mask])
+        else:
+            val_acc = evaluate(model, g, features, labels, val_mask)
+            if args.early_stop:
+                if stopper.step(val_acc, model):
+                    break
+
+        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
+              " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
+              format(epoch, np.mean(dur), loss.item(), train_acc,
+                     val_acc, n_edges / np.mean(dur) / 1000))
+
+    print()
+    if args.early_stop:
+        model.load_state_dict(torch.load('es_checkpoint.pt'))
+    acc = evaluate(model, g, features, labels, test_mask)
+    print("Test Accuracy {:.4f}".format(acc))
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description='GATv2')
+    register_data_args(parser)
+    parser.add_argument("--gpu", type=int, default=-1,
+                        help="which GPU to use. Set -1 to use CPU.")
+    parser.add_argument("--epochs", type=int, default=200,
+                        help="number of training epochs")
+    parser.add_argument("--num-heads", type=int, default=8,
+                        help="number of hidden attention heads")
+    parser.add_argument("--num-out-heads", type=int, default=1,
+                        help="number of output attention heads")
+    parser.add_argument("--num-layers", type=int, default=1,
+                        help="number of hidden layers")
+    parser.add_argument("--num-hidden", type=int, default=8,
+                        help="number of hidden units")
+    parser.add_argument("--residual", action="store_true", default=False,
+                        help="use residual connection")
+    parser.add_argument("--in-drop", type=float, default=.7,
+                        help="input feature dropout")
+    parser.add_argument("--attn-drop", type=float, default=.7,
+                        help="attention dropout")
+    parser.add_argument("--lr", type=float, default=0.005,
+                        help="learning rate")
+    parser.add_argument('--weight-decay', type=float, default=5e-4,
+                        help="weight decay")
+    parser.add_argument('--negative-slope', type=float, default=0.2,
+                        help="the negative slope of leaky relu")
+    parser.add_argument('--early-stop', action='store_true', default=False,
+                        help="indicates whether to use early stop or not")
+    parser.add_argument('--fastmode', action="store_true", default=False,
+                        help="skip re-evaluation of the validation set")
+    args = parser.parse_args()
+    print(args)
+
+    main(args)
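`train.py` drives the `EarlyStopping` helper manually from the training loop. A small usage sketch, assuming it is run next to `train.py` (the validation-score sequence is made up; `step` writes `es_checkpoint.pt` whenever the score improves):

```python
import torch.nn as nn
from train import EarlyStopping

model = nn.Linear(4, 2)               # stand-in for the real GATv2 model
stopper = EarlyStopping(patience=2)

for epoch, val_acc in enumerate([0.70, 0.72, 0.71, 0.69]):
    if stopper.step(val_acc, model):  # saves a checkpoint whenever val_acc improves
        print(f"early stop at epoch {epoch}")
        break
```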
diff --git a/python/dgl/nn/pytorch/conv/__init__.py b/python/dgl/nn/pytorch/conv/__init__.py
index 65fbe7913903..fe3599fdada8 100644
--- a/python/dgl/nn/pytorch/conv/__init__.py
+++ b/python/dgl/nn/pytorch/conv/__init__.py
@@ -6,6 +6,7 @@
 from .chebconv import ChebConv
 from .edgeconv import EdgeConv
 from .gatconv import GATConv
+from .gatv2conv import GATv2Conv
 from .egatconv import EGATConv
 from .ginconv import GINConv
 from .gmmconv import GMMConv
@@ -25,8 +26,8 @@
 from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
 from .gcn2conv import GCN2Conv
 
-__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
-           'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
-           'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
+__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
+           'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
+           'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
            'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
            'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
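The new `gatv2conv.py` module below documents the GATv2 attention score. For reference, here is a small self-contained sketch (illustrative tensors; `W_left`, `W_right` and `a` mirror the symbols in the docstring, not DGL internals) contrasting it with the original GAT score:

```python
import torch as th
import torch.nn as nn

F_in, F_out = 8, 16
leaky_relu = nn.LeakyReLU(0.2)
h_i, h_j = th.randn(F_in), th.randn(F_in)

# GAT (v1): e_ij = LeakyReLU(a^T [W h_i || W h_j]); the non-linearity comes
# after the dot product with a, so neighbor rankings do not depend on i.
W = nn.Linear(F_in, F_out, bias=False)
a_v1 = th.randn(2 * F_out)
e_gat = leaky_relu(a_v1 @ th.cat([W(h_i), W(h_j)]))

# GATv2: e_ij = a^T LeakyReLU(W_left h_i + W_right h_j); the dot product with a
# comes after the non-linearity, giving the "dynamic" attention described in the paper.
W_left = nn.Linear(F_in, F_out, bias=False)
W_right = nn.Linear(F_in, F_out, bias=False)
a_v2 = th.randn(F_out)
e_gatv2 = a_v2 @ leaky_relu(W_left(h_i) + W_right(h_j))
```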
diff --git a/python/dgl/nn/pytorch/conv/gatv2conv.py b/python/dgl/nn/pytorch/conv/gatv2conv.py
new file mode 100644
index 000000000000..7ae71f8b4fbb
--- /dev/null
+++ b/python/dgl/nn/pytorch/conv/gatv2conv.py
@@ -0,0 +1,312 @@
+"""Torch modules for graph attention networks v2 (GATv2)."""
+# pylint: disable= no-member, arguments-differ, invalid-name
+import torch as th
+from torch import nn
+
+from .... import function as fn
+from ...functional import edge_softmax
+from ....base import DGLError
+from ..utils import Identity
+from ....utils import expand_as_pair
+
+# pylint: enable=W0235
+class GATv2Conv(nn.Module):
+    r"""
+
+    Description
+    -----------
+    Apply GATv2 from
+    `How Attentive are Graph Attention Networks? <https://arxiv.org/pdf/2105.14491.pdf>`__
+    over an input signal.
+
+    .. math::
+        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij} W^{(l)}_{right} h_j^{(l)}
+
+    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
+    node :math:`j`:
+
+    .. math::
+        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})
+
+        e_{ij}^{l} &= \vec{a}^T \mathrm{LeakyReLU}\left(
+            W^{(l)}_{left} h_{i} + W^{(l)}_{right} h_{j}\right)
+
+    Parameters
+    ----------
+    in_feats : int, or pair of ints
+        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
+        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
+        specifies the input feature size on both the source and destination nodes.
+        If a scalar is given, the source and destination node feature size
+        would take the same value.
+    out_feats : int
+        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
+    num_heads : int
+        Number of heads in Multi-Head Attention.
+    feat_drop : float, optional
+        Dropout rate on feature. Defaults: ``0``.
+    attn_drop : float, optional
+        Dropout rate on attention weight. Defaults: ``0``.
+    negative_slope : float, optional
+        LeakyReLU angle of negative slope. Defaults: ``0.2``.
+    residual : bool, optional
+        If True, use residual connection. Defaults: ``False``.
+    activation : callable activation function/layer or None, optional
+        If not None, applies an activation function to the updated node features.
+        Default: ``None``.
+    allow_zero_in_degree : bool, optional
+        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
+        since no message will be passed to those nodes. This is harmful for some applications,
+        causing silent performance regression. This module will raise a DGLError if it detects
+        0-in-degree nodes in the input graph. By setting ``True``, it will suppress the check
+        and let the users handle it by themselves. Defaults: ``False``.
+    bias : bool, optional
+        If set to :obj:`False`, the layer will not learn
+        an additive bias. (default: :obj:`True`)
+    share_weights : bool, optional
+        If set to :obj:`True`, the same matrix is used for :math:`W_{left}` and
+        :math:`W_{right}` in the above equations, i.e. the source and the destination
+        node of every edge share the projection weights. (default: :obj:`False`)
+
+    Note
+    ----
+    Zero in-degree nodes will lead to invalid output value. This is because no message
+    will be passed to those nodes, and the aggregation function will be applied on empty input.
+    A common practice to avoid this is to add a self-loop for each node in the graph if
+    it is homogeneous, which can be achieved by:
+
+    >>> g = ... # a DGLGraph
+    >>> g = dgl.add_self_loop(g)
+
+    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graphs,
+    since the edge type can not be decided for self-loop edges. Set ``allow_zero_in_degree``
+    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
+    A common practice to handle this is to filter out the zero-in-degree nodes when using
+    the output of this module.
+
+    Examples
+    --------
+    >>> import dgl
+    >>> import numpy as np
+    >>> import torch as th
+    >>> from dgl.nn import GATv2Conv
+
+    >>> # Case 1: Homogeneous graph
+    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
+    >>> g = dgl.add_self_loop(g)
+    >>> feat = th.ones(6, 10)
+    >>> gatv2conv = GATv2Conv(10, 2, num_heads=3)
+    >>> res = gatv2conv(g, feat)
+    >>> res
+    tensor([[[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]],
+            [[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]],
+            [[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]],
+            [[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]],
+            [[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]],
+            [[ 1.9599,  1.0239],
+             [ 3.2015, -0.5512],
+             [ 2.3700, -2.2182]]], grad_fn=)
+
+    >>> # Case 2: Unidirectional bipartite graph
+    >>> u = [0, 1, 0, 0, 1]
+    >>> v = [0, 1, 2, 3, 2]
+    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
+    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
+    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
+    >>> gatv2conv = GATv2Conv((5,10), 2, 3)
+    >>> res = gatv2conv(g, (u_feat, v_feat))
+    >>> res
+    tensor([[[-0.0935, -0.4273],
+             [-1.1850,  0.1123],
+             [-0.2002,  0.1155]],
+            [[ 0.1908, -1.2095],
+             [-0.0129,  0.6408],
+             [-0.8135,  0.1157]],
+            [[ 0.0596, -0.8487],
+             [-0.5421,  0.4022],
+             [-0.4805,  0.1156]],
+            [[-0.0935, -0.4273],
+             [-1.1850,  0.1123],
+             [-0.2002,  0.1155]]], grad_fn=)
+    """
+    def __init__(self,
+                 in_feats,
+                 out_feats,
+                 num_heads,
+                 feat_drop=0.,
+                 attn_drop=0.,
+                 negative_slope=0.2,
+                 residual=False,
+                 activation=None,
+                 allow_zero_in_degree=False,
+                 bias=True,
+                 share_weights=False):
+        super(GATv2Conv, self).__init__()
+        self._num_heads = num_heads
+        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
+        self._out_feats = out_feats
+        self._allow_zero_in_degree = allow_zero_in_degree
+        if isinstance(in_feats, tuple):
+            self.fc_src = nn.Linear(
+                self._in_src_feats, out_feats * num_heads, bias=bias)
+            self.fc_dst = nn.Linear(
+                self._in_dst_feats, out_feats * num_heads, bias=bias)
+        else:
+            self.fc_src = nn.Linear(
+                self._in_src_feats, out_feats * num_heads, bias=bias)
+            if share_weights:
+                self.fc_dst = self.fc_src
+            else:
+                self.fc_dst = nn.Linear(
+                    self._in_src_feats, out_feats * num_heads, bias=bias)
+        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
+        self.feat_drop = nn.Dropout(feat_drop)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.leaky_relu = nn.LeakyReLU(negative_slope)
+        if residual:
+            if self._in_dst_feats != out_feats:
+                self.res_fc = nn.Linear(
+                    self._in_dst_feats, num_heads * out_feats, bias=bias)
+            else:
+                self.res_fc = Identity()
+        else:
+            self.register_buffer('res_fc', None)
+        self.activation = activation
+        self.share_weights = share_weights
+        self.bias = bias
+        self.reset_parameters()
+ """ + gain = nn.init.calculate_gain('relu') + nn.init.xavier_normal_(self.fc_src.weight, gain=gain) + if self.bias: + nn.init.constant_(self.fc_src.bias, 0) + if not self.share_weights: + nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + if self.bias: + nn.init.constant_(self.fc_dst.bias, 0) + nn.init.xavier_normal_(self.attn, gain=gain) + if isinstance(self.res_fc, nn.Linear): + nn.init.xavier_normal_(self.res_fc.weight, gain=gain) + if self.bias: + nn.init.constant_(self.res_fc.bias, 0) + + def set_allow_zero_in_degree(self, set_value): + r""" + Description + ----------- + Set allow_zero_in_degree flag. + + Parameters + ---------- + set_value : bool + The value to be set to the flag. + """ + self._allow_zero_in_degree = set_value + + def forward(self, graph, feat, get_attention=False): + r""" + Description + ----------- + Compute graph attention network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : torch.Tensor or pair of torch.Tensor + If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where + :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. + If a pair of torch.Tensor is given, the pair must contain two tensors of shape + :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`. + get_attention : bool, optional + Whether to return the attention values. Default to False. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, H, D_{out})` where :math:`H` + is the number of heads, and :math:`D_{out}` is size of output feature. + torch.Tensor, optional + The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of + edges. This is returned only when :attr:`get_attention` is ``True``. + + Raises + ------ + DGLError + If there are 0-in-degree nodes in the input graph, it will raise DGLError + since no message will be passed to those nodes. This will cause invalid output. + The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``. + """ + with graph.local_scope(): + if not self._allow_zero_in_degree: + if (graph.in_degrees() == 0).any(): + raise DGLError('There are 0-in-degree nodes in the graph, ' + 'output for those nodes will be invalid. ' + 'This is harmful for some applications, ' + 'causing silent performance regression. ' + 'Adding self-loop on the input graph by ' + 'calling `g = dgl.add_self_loop(g)` will resolve ' + 'the issue. 
diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py
index 9be94b1aa874..a2309f24f688 100644
--- a/tests/pytorch/test_nn.py
+++ b/tests/pytorch/test_nn.py
@@ -564,6 +564,45 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
     _, a = gat(g, feat, get_attention=True)
     assert a.shape == (g.number_of_edges(), num_heads, 1)
 
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_dim', [1, 5])
+@pytest.mark.parametrize('num_heads', [1, 4])
+def test_gatv2_conv(g, idtype, out_dim, num_heads):
+    g = g.astype(idtype).to(F.ctx())
+    ctx = F.ctx()
+    gat = nn.GATv2Conv(5, out_dim, num_heads)
+    feat = F.randn((g.number_of_src_nodes(), 5))
+    gat = gat.to(ctx)
+    h = gat(g, feat)
+
+    # test pickle
+    th.save(gat, tmp_buffer)
+
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), num_heads, 1)
+
+    # test residual connection
+    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
+    gat = gat.to(ctx)
+    h = gat(g, feat)
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
+@pytest.mark.parametrize('out_dim', [1, 2])
+@pytest.mark.parametrize('num_heads', [1, 4])
+def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
+    g = g.astype(idtype).to(F.ctx())
+    ctx = F.ctx()
+    gat = nn.GATv2Conv(5, out_dim, num_heads)
+    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
+    gat = gat.to(ctx)
+    h = gat(g, feat)
+    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
+    _, a = gat(g, feat, get_attention=True)
+    assert a.shape == (g.number_of_edges(), num_heads, 1)
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
 @pytest.mark.parametrize('out_node_feats', [1, 5])
@@ -1159,6 +1198,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
     test_rgcn_sorted()
     test_tagconv()
     test_gat_conv()
+    test_gatv2_conv()
     test_egat_conv()
     test_sage_conv()
     test_sgc_conv()
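Outside the pytest harness, the new tests reduce to shape checks like the following standalone sketch (the suite's `F`, `get_cases` and `parametrize_dtype` fixtures are replaced here with a toy graph; this is illustrative, not part of the test suite):

```python
import dgl
import torch as th
from dgl.nn import GATv2Conv

g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
feat = th.randn(g.number_of_nodes(), 5)

for num_heads in (1, 4):
    conv = GATv2Conv(5, 2, num_heads, residual=True)
    h = conv(g, feat)
    _, a = conv(g, feat, get_attention=True)
    assert h.shape == (g.number_of_nodes(), num_heads, 2)
    assert a.shape == (g.number_of_edges(), num_heads, 1)
```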