* [Model][Core] GATv2
* lint
* gatv2conv.py
* lint
* lint
* style and docs
* lint
* gatv2conv fix

Co-authored-by: Shaked Brody <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
1 parent f510214 · commit e2f33fd · 9 changed files with 657 additions and 4 deletions
`README.md`:
Graph Attention Networks v2 (GATv2)
============

- Paper link: [How Attentive are Graph Attention Networks?](https://arxiv.org/pdf/2105.14491.pdf)
- Author's code repo: [https://github.com/tech-srl/how_attentive_are_gats](https://github.com/tech-srl/how_attentive_are_gats)
- Annotated implementation: [https://nn.labml.ai/graphs/gatv2/index.html](https://nn.labml.ai/graphs/gatv2/index.html)
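
In the paper's notation (W is the shared weight matrix, a the attention vector, ‖ concatenation), the key change GATv2 makes to GAT is the order of operations in the attention score, so the attention becomes "dynamic", i.e. able to depend on the query node:

```latex
% GAT (static attention): a is applied before the nonlinearity
e^{\mathrm{GAT}}(h_i, h_j)   = \mathrm{LeakyReLU}\left(\mathbf{a}^{\top} \left[\mathbf{W} h_i \,\|\, \mathbf{W} h_j\right]\right)
% GATv2 (dynamic attention): a is applied after the nonlinearity
e^{\mathrm{GATv2}}(h_i, h_j) = \mathbf{a}^{\top} \, \mathrm{LeakyReLU}\left(\mathbf{W} \left[h_i \,\|\, h_j\right]\right)
```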

Dependencies
------------
- torch
- requests
- sklearn

How to run
----------

Run with the following:

```bash
python3 train.py --dataset=cora
```

```bash
python3 train.py --dataset=citeseer
```

```bash
python3 train.py --dataset=pubmed
```

Results
-------

| Dataset  | Test Accuracy |
| -------- | ------------- |
| Cora     | 82.10         |
| Citeseer | 70.00         |
| Pubmed   | 77.20         |

* All the accuracy numbers are obtained after 200 epochs.
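
As an aside (not part of the commit), the underlying `GATv2Conv` layer from `dgl.nn` can also be used on its own. A minimal sketch, with a toy graph and feature sizes made up for illustration:

```python
import dgl
import torch
from dgl.nn import GATv2Conv

# toy graph with 6 nodes; self-loops avoid zero-in-degree errors in GATv2Conv
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])))
feat = torch.ones(6, 10)                        # 10-dimensional node features

conv = GATv2Conv(in_feats=10, out_feats=2, num_heads=3)
out = conv(g, feat)                             # shape: (6, 3, 2) = (nodes, heads, out_feats)
print(out.shape)
```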

`gatv2.py`:
""" | ||
Graph Attention Networks in DGL using SPMV optimization. | ||
References | ||
---------- | ||
Paper: https://arxiv.org/pdf/2105.14491.pdf | ||
Author's code: https://github.com/tech-srl/how_attentive_are_gats | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from dgl.nn import GATv2Conv | ||
|
||
|
||
class GATv2(nn.Module): | ||
def __init__(self, | ||
num_layers, | ||
in_dim, | ||
num_hidden, | ||
num_classes, | ||
heads, | ||
activation, | ||
feat_drop, | ||
attn_drop, | ||
negative_slope, | ||
residual): | ||
super(GATv2, self).__init__() | ||
self.num_layers = num_layers | ||
self.gatv2_layers = nn.ModuleList() | ||
self.activation = activation | ||
# input projection (no residual) | ||
self.gatv2_layers.append(GATv2Conv( | ||
in_dim, num_hidden, heads[0], | ||
feat_drop, attn_drop, negative_slope, False, self.activation, bias=False, share_weights=True)) | ||
# hidden layers | ||
for l in range(1, num_layers): | ||
# due to multi-head, the in_dim = num_hidden * num_heads | ||
self.gatv2_layers.append(GATv2Conv( | ||
num_hidden * heads[l-1], num_hidden, heads[l], | ||
feat_drop, attn_drop, negative_slope, residual, self.activation, bias=False, share_weights=True)) | ||
# output projection | ||
self.gatv2_layers.append(GATv2Conv( | ||
num_hidden * heads[-2], num_classes, heads[-1], | ||
feat_drop, attn_drop, negative_slope, residual, None, bias=False, share_weights=True)) | ||
|
||
def forward(self, g, inputs): | ||
h = inputs | ||
for l in range(self.num_layers): | ||
h = self.gatv2_layers[l](h).flatten(1) | ||
# output projection | ||
logits = self.gatv2_layers[-1](h).mean(1) | ||
return logits |
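
A minimal usage sketch of the module above (not part of the commit; the random graph and sizes are made up, mirroring the `train.py` defaults):

```python
import dgl
import torch
import torch.nn.functional as F
from gatv2 import GATv2   # the module defined above

# random toy graph; self-loops avoid zero-in-degree errors in GATv2Conv
g = dgl.add_self_loop(dgl.rand_graph(100, 500))
feat = torch.randn(100, 16)              # 100 nodes, 16-dimensional features

# one hidden layer with 8 heads and a single-head output layer,
# matching the train.py defaults (num_layers=1, num_hidden=8, heads=[8, 1])
model = GATv2(num_layers=1, in_dim=16, num_hidden=8, num_classes=7,
              heads=[8, 1], activation=F.elu, feat_drop=0.7, attn_drop=0.7,
              negative_slope=0.2, residual=False)
logits = model(g, feat)                  # shape: (100, 7)
print(logits.shape)
```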

`train.py`:
""" | ||
Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization. | ||
Multiple heads are also batched together for faster training. | ||
""" | ||
|
||
import argparse | ||
import numpy as np | ||
import time | ||
import torch | ||
import torch.nn.functional as F | ||
import dgl | ||
from dgl.data import register_data_args | ||
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset | ||
|
||
from gatv2 import GATv2 | ||
|
||
|
||
class EarlyStopping: | ||
def __init__(self, patience=10): | ||
self.patience = patience | ||
self.counter = 0 | ||
self.best_score = None | ||
self.early_stop = False | ||
|
||
def step(self, acc, model): | ||
score = acc | ||
if self.best_score is None: | ||
self.best_score = score | ||
self.save_checkpoint(model) | ||
elif score < self.best_score: | ||
self.counter += 1 | ||
print(f'EarlyStopping counter: {self.counter} out of {self.patience}') | ||
if self.counter >= self.patience: | ||
self.early_stop = True | ||
else: | ||
self.best_score = score | ||
self.save_checkpoint(model) | ||
self.counter = 0 | ||
return self.early_stop | ||
|
||
def save_checkpoint(self, model): | ||
'''Saves model when validation loss decrease.''' | ||
torch.save(model.state_dict(), 'es_checkpoint.pt') | ||
|
||
def accuracy(logits, labels): | ||
_, indices = torch.max(logits, dim=1) | ||
correct = torch.sum(indices == labels) | ||
return correct.item() * 1.0 / len(labels) | ||
|
||
|
||
def evaluate(model, g, features, labels, mask): | ||
model.eval() | ||
with torch.no_grad(): | ||
logits = model(g, features) | ||
logits = logits[mask] | ||
labels = labels[mask] | ||
return accuracy(logits, labels) | ||
|
||
|
||
def main(args): | ||
# load and preprocess dataset | ||
if args.dataset == 'cora': | ||
data = CoraGraphDataset() | ||
elif args.dataset == 'citeseer': | ||
data = CiteseerGraphDataset() | ||
elif args.dataset == 'pubmed': | ||
data = PubmedGraphDataset() | ||
else: | ||
raise ValueError('Unknown dataset: {}'.format(args.dataset)) | ||
|
||
g = data[0] | ||
if args.gpu < 0: | ||
cuda = False | ||
else: | ||
cuda = True | ||
g = g.int().to(args.gpu) | ||
|
||
features = g.ndata['feat'] | ||
labels = g.ndata['label'] | ||
train_mask = g.ndata['train_mask'] | ||
val_mask = g.ndata['val_mask'] | ||
test_mask = g.ndata['test_mask'] | ||
num_feats = features.shape[1] | ||
n_classes = data.num_labels | ||
n_edges = data.graph.number_of_edges() | ||
print("""----Data statistics------' | ||
#Edges %d | ||
#Classes %d | ||
#Train samples %d | ||
#Val samples %d | ||
#Test samples %d""" % | ||
(n_edges, n_classes, | ||
train_mask.int().sum().item(), | ||
val_mask.int().sum().item(), | ||
test_mask.int().sum().item())) | ||
|
||
# add self loop | ||
g = dgl.remove_self_loop(g) | ||
g = dgl.add_self_loop(g) | ||
n_edges = g.number_of_edges() | ||
# create model | ||
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] | ||
model = GATv2(args.num_layers, | ||
num_feats, | ||
args.num_hidden, | ||
n_classes, | ||
heads, | ||
F.elu, | ||
args.in_drop, | ||
args.attn_drop, | ||
args.negative_slope, | ||
args.residual) | ||
print(model) | ||
if args.early_stop: | ||
stopper = EarlyStopping(patience=100) | ||
if cuda: | ||
model.cuda() | ||
loss_fcn = torch.nn.CrossEntropyLoss() | ||
|
||
# use optimizer | ||
optimizer = torch.optim.Adam( | ||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay) | ||
|
||
# initialize graph | ||
dur = [] | ||
for epoch in range(args.epochs): | ||
model.train() | ||
if epoch >= 3: | ||
t0 = time.time() | ||
# forward | ||
logits = model(g, features) | ||
loss = loss_fcn(logits[train_mask], labels[train_mask]) | ||
|
||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
if epoch >= 3: | ||
dur.append(time.time() - t0) | ||
|
||
train_acc = accuracy(logits[train_mask], labels[train_mask]) | ||
|
||
if args.fastmode: | ||
val_acc = accuracy(logits[val_mask], labels[val_mask]) | ||
else: | ||
val_acc = evaluate(g, model, features, labels, val_mask) | ||
if args.early_stop: | ||
if stopper.step(val_acc, model): | ||
break | ||
|
||
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" | ||
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". | ||
format(epoch, np.mean(dur), loss.item(), train_acc, | ||
val_acc, n_edges / np.mean(dur) / 1000)) | ||
|
||
print() | ||
if args.early_stop: | ||
model.load_state_dict(torch.load('es_checkpoint.pt')) | ||
acc = evaluate(model, features, labels, test_mask) | ||
print("Test Accuracy {:.4f}".format(acc)) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
parser = argparse.ArgumentParser(description='GAT') | ||
register_data_args(parser) | ||
parser.add_argument("--gpu", type=int, default=-1, | ||
help="which GPU to use. Set -1 to use CPU.") | ||
parser.add_argument("--epochs", type=int, default=200, | ||
help="number of training epochs") | ||
parser.add_argument("--num-heads", type=int, default=8, | ||
help="number of hidden attention heads") | ||
parser.add_argument("--num-out-heads", type=int, default=1, | ||
help="number of output attention heads") | ||
parser.add_argument("--num-layers", type=int, default=1, | ||
help="number of hidden layers") | ||
parser.add_argument("--num-hidden", type=int, default=8, | ||
help="number of hidden units") | ||
parser.add_argument("--residual", action="store_true", default=False, | ||
help="use residual connection") | ||
parser.add_argument("--in-drop", type=float, default=.7, | ||
help="input feature dropout") | ||
parser.add_argument("--attn-drop", type=float, default=.7, | ||
help="attention dropout") | ||
parser.add_argument("--lr", type=float, default=0.005, | ||
help="learning rate") | ||
parser.add_argument('--weight-decay', type=float, default=5e-4, | ||
help="weight decay") | ||
parser.add_argument('--negative-slope', type=float, default=0.2, | ||
help="the negative slope of leaky relu") | ||
parser.add_argument('--early-stop', action='store_true', default=False, | ||
help="indicates whether to use early stop or not") | ||
parser.add_argument('--fastmode', action="store_true", default=False, | ||
help="skip re-evaluate the validation set") | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
main(args) |
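
As a side note (not from the commit), the `es_checkpoint.pt` file written by `EarlyStopping` can be reloaded later for a standalone test-set evaluation. A rough sketch, assuming the Cora dataset and the default hyperparameters above:

```python
import dgl
import torch
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from gatv2 import GATv2
from train import evaluate          # reuse the helper defined above

data = CoraGraphDataset()
g = dgl.add_self_loop(dgl.remove_self_loop(data[0]))
features, labels = g.ndata['feat'], g.ndata['label']
test_mask = g.ndata['test_mask']

# rebuild the model with the same (default) hyperparameters used for training
model = GATv2(1, features.shape[1], 8, data.num_labels,
              [8, 1], F.elu, 0.7, 0.7, 0.2, False)
model.load_state_dict(torch.load('es_checkpoint.pt'))
print("Test Accuracy {:.4f}".format(evaluate(model, g, features, labels, test_mask)))
```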