[Example] GAT on ogbn-arxiv (dmlc#2210)
* [Example] GCN on ogbn-arxiv dataset
* Add README.md
* Update GCN implementation on ogbn-arxiv
* Update GCN on ogbn-arxiv
* Fix typo
* Use evaluator to get results
* Fix duplicated
* Fix duplicated
* Update GCN on ogbn-arxiv
* Add GAT for ogbn-arxiv
* Update README.md
* Update GAT on ogbn-arxiv
* Update README.md

Co-authored-by: Mufei Li <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent 2fbffd3, commit 47bb059
Showing 4 changed files with 554 additions and 11 deletions.
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import math
import time

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

from models import GAT

device = None
in_feats, n_classes = None, None
epsilon = 1 - math.log(2)


def gen_model(args):
    norm = "both" if args.use_norm else "none"

    if args.use_labels:
        model = GAT(
            in_feats + n_classes,
            n_classes,
            n_hidden=args.n_hidden,
            n_layers=args.n_layers,
            n_heads=args.n_heads,
            activation=F.relu,
            dropout=args.dropout,
            attn_drop=args.attn_drop,
            norm=norm,
        )
    else:
        model = GAT(
            in_feats,
            n_classes,
            n_hidden=args.n_hidden,
            n_layers=args.n_layers,
            n_heads=args.n_heads,
            activation=F.relu,
            dropout=args.dropout,
            attn_drop=args.attn_drop,
            norm=norm,
        )

    return model


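# The loss below is a log-rescaled cross entropy: loss = log(eps + CE) - log(eps).
# It is zero when CE is zero and grows only logarithmically for large CE, which
# down-weights hard or noisy nodes; with eps = 1 - log(2) the transform has unit
# slope at CE = log(2).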
def cross_entropy(x, labels):
    y = F.cross_entropy(x, labels[:, 0], reduction="none")
    y = th.log(epsilon + y) - math.log(epsilon)
    return th.mean(y)


def compute_acc(pred, labels, evaluator):
    return evaluator.eval({"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels})["acc"]


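# One-hot encode the labels of the nodes in `idx` and append them to the node
# features; all other nodes get an all-zero label vector.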
def add_labels(feat, labels, idx):
    onehot = th.zeros([feat.shape[0], n_classes]).to(device)
    onehot[idx, labels[idx, 0]] = 1
    return th.cat([feat, onehot], dim=-1)


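# Linear learning-rate warmup: ramp the rate from lr/50 up to lr over the first
# 50 epochs, then leave it untouched.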
def adjust_learning_rate(optimizer, lr, epoch):
    if epoch <= 50:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr * epoch / 50


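# With --use-labels, a random half of the training nodes have their ground-truth
# labels fed in as extra input features, and the loss is computed only on the
# other half; without it, the loss is still computed on a random half of the
# training nodes each epoch.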
def train(model, graph, labels, train_idx, optimizer, use_labels):
    model.train()

    feat = graph.ndata["feat"]

    if use_labels:
        mask_rate = 0.5
        mask = th.rand(train_idx.shape) < mask_rate

        train_labels_idx = train_idx[mask]
        train_pred_idx = train_idx[~mask]

        feat = add_labels(feat, labels, train_labels_idx)
    else:
        mask_rate = 0.5
        mask = th.rand(train_idx.shape) < mask_rate

        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)
    loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])
    loss.backward()
    optimizer.step()

    return loss, pred


@th.no_grad()
def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator):
    model.eval()

    feat = graph.ndata["feat"]

    if use_labels:
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)
    train_loss = cross_entropy(pred[train_idx], labels[train_idx])
    val_loss = cross_entropy(pred[val_idx], labels[val_idx])
    test_loss = cross_entropy(pred[test_idx], labels[test_idx])

    return (
        compute_acc(pred[train_idx], labels[train_idx], evaluator),
        compute_acc(pred[val_idx], labels[val_idx], evaluator),
        compute_acc(pred[test_idx], labels[test_idx], evaluator),
        train_loss,
        val_loss,
        test_loss,
    )


def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    # define model and optimizer
    model = gen_model(args)
    model = model.to(device)

    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # training loop
    total_time = 0
    best_val_acc, best_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(model, graph, labels, train_idx, optimizer, args.use_labels)
        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate(
            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels, evaluator
        )

        toc = time.time()
        total_time += toc - tic

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}")
            print(
                f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
            )

        for l, e in zip(
            [accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
            [acc, train_acc, val_acc, test_acc, loss.item(), train_loss, val_loss, test_loss],
        ):
            l.append(e)

    print("*" * 50)
    print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")

    if args.plot_curves:
        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.set_yticks(np.linspace(0, 1.0, 101))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip([accs, train_accs, val_accs, test_accs], ["acc", "train acc", "val acc", "test acc"]):
            plt.plot(range(args.n_epochs), y, label=label)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.01))
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_acc_{n_running}.png")

        fig = plt.figure(figsize=(24, 24))
        ax = fig.gca()
        ax.set_xticks(np.arange(0, args.n_epochs, 100))
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(
            [losses, train_losses, val_losses, test_losses], ["loss", "train loss", "val loss", "test loss"]
        ):
            plt.plot(range(args.n_epochs), y, label=label)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        plt.grid(which="major", color="red", linestyle="dotted")
        plt.grid(which="minor", color="orange", linestyle="dotted")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"gat_loss_{n_running}.png")

    return best_val_acc, best_test_acc


def count_parameters(args):
    model = gen_model(args)
    print([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
    return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])


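# Typical invocation (assuming this script is saved as, e.g., gat.py next to the
# models.py that defines GAT):
#     python gat.py --use-labels --use-norm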
def main():
    global device, in_feats, n_classes, epsilon

    argparser = argparse.ArgumentParser("GAT on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--n-runs", type=int, default=10)
    argparser.add_argument("--n-epochs", type=int, default=2000)
    argparser.add_argument(
        "--use-labels", action="store_true", help="Use labels in the training set as input features."
    )
    argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
    argparser.add_argument("--lr", type=float, default=0.002)
    argparser.add_argument("--n-layers", type=int, default=3)
    argparser.add_argument("--n-heads", type=int, default=3)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--dropout", type=float, default=0.75)
    argparser.add_argument("--attn_drop", type=float, default=0.05)
    argparser.add_argument("--wd", type=float, default=0)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--plot-curves", action="store_true")
    args = argparser.parse_args()

    if args.cpu:
        device = th.device("cpu")
    else:
        device = th.device("cuda:%d" % args.gpu)

    # load data
    data = DglNodePropPredDataset(name="ogbn-arxiv")
    evaluator = Evaluator(name="ogbn-arxiv")

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
    graph, labels = data[0]

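    # ogbn-arxiv is a directed citation graph; adding the reversed edges makes
    # message passing run in both directions, and the self-loops added below let
    # every node also aggregate its own features.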
    # add reverse edges
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop
    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    # graph.create_format_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    labels = labels.to(device)
    graph = graph.to(device)

    # run
    val_accs = []
    test_accs = []

    for i in range(1, args.n_runs + 1):
        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

print(f"Runned {args.n_runs} times") | ||
print("Val Accs:", val_accs) | ||
print("Test Accs:", test_accs) | ||
print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}") | ||
print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}") | ||
print(f"Number of params: {count_parameters(args)}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
||
|
||
# Ran 10 times
# Val Accs: [0.7505956575724018, 0.7489177489177489, 0.7502600758414711, 0.7498573777643545, 0.75079700661096, 0.7504278667069365, 0.7505285412262156, 0.7512332628611699, 0.7503271921876573, 0.750729890264774]
# Test Accs: [0.7374853404110857, 0.7357982017570932, 0.7359216509268975, 0.736826944838796, 0.7385140834927885, 0.7370944180400387, 0.7358187766187272, 0.7365183219142851, 0.7343168117194412, 0.7371767174865749]
# Average val accuracy: 0.750367461995369 ± 0.0005934770264509258
# Average test accuracy: 0.7365471267205728 ± 0.0010945826389317434
# Number of params: 1628440