[Model] Add SIGN model to example (dmlc#1571)
* sign model
* mini-batch GPU eval & set default hyper-parameters
* add readme for sign
* minor
Showing 3 changed files with 270 additions and 0 deletions.
@@ -0,0 +1,28 @@ (README for the SIGN example)
SIGN: Scalable Inception Graph Neural Networks
===============

- paper link: [https://arxiv.org/pdf/2004.11198.pdf](https://arxiv.org/pdf/2004.11198.pdf)
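
In brief (a paraphrase of the paper, matching what `sign.py` in this commit implements; the notation is simplified from the paper), SIGN precomputes diffused copies of the node features with a normalized adjacency matrix and then trains only feed-forward networks on them, so no graph operations are needed during training:

$$\hat{Y} = \xi\left(\left[\,f_0(X) \;\|\; f_1(\hat{A}X) \;\|\; \dots \;\|\; f_R(\hat{A}^{R}X)\,\right]\right)$$

where $\hat{A}$ is a row-normalized version of $D^{-1/2}AD^{-1/2}$, each $f_r$ is a small MLP, $\|$ denotes concatenation, and $\xi$ is the final MLP classifier.
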
Requirements
----------------

```bash
pip install requests ogb
```
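
(The example additionally assumes that PyTorch and DGL are already installed; they are not covered by the `pip install` line above.)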

Results
---------------
### [Ogbn-products](https://ogb.stanford.edu/docs/nodeprop/#ogbn-products) (Amazon co-purchase dataset)

```bash
python sign.py --dataset amazon
```

Test accuracy: mean 0.78672, std 0.00059

### Reddit

```bash
python sign.py --dataset reddit
```

Test accuracy: mean 0.96326, std 0.00010
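
A few other flags defined in `sign.py` can be handy: `--gpu -1` runs everything on CPU, `--eval-batch-size -1` evaluates full-batch instead of in chunks, `--R` sets the number of precomputed hops (default 3), and `--dataset cora` is also accepted by `dataset.py`.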
@@ -0,0 +1,41 @@ (dataset.py)
import torch
import numpy as np
import dgl


def load_dataset(name):
    dataset = name.lower()
    if dataset == "amazon":
        from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
        dataset = DglNodePropPredDataset(name="ogbn-products")
        splitted_idx = dataset.get_idx_split()
        train_nid = splitted_idx["train"]
        val_nid = splitted_idx["valid"]
        test_nid = splitted_idx["test"]
        g, labels = dataset[0]
        labels = labels.squeeze()
        n_classes = int(labels.max() - labels.min() + 1)
        features = g.ndata.pop("feat").float()
    elif dataset in ["reddit", "cora"]:
        if dataset == "reddit":
            from dgl.data import RedditDataset
            data = RedditDataset(self_loop=True)
            g = data.graph
        else:
            from dgl.data import CoraDataset
            data = CoraDataset()
            g = dgl.DGLGraph(data.graph)
        train_mask = data.train_mask
        val_mask = data.val_mask
        test_mask = data.test_mask
        features = torch.Tensor(data.features)
        labels = torch.LongTensor(data.labels)
        n_classes = data.num_labels
        train_nid = torch.LongTensor(np.nonzero(train_mask)[0])
        val_nid = torch.LongTensor(np.nonzero(val_mask)[0])
        test_nid = torch.LongTensor(np.nonzero(test_mask)[0])
    else:
        print("Dataset {} is not supported".format(name))
        assert(0)

    return g, features, labels, n_classes, train_nid, val_nid, test_nid
@@ -0,0 +1,201 @@ (sign.py)
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dataset import load_dataset


class FeedForwardNet(nn.Module):
    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
        super(FeedForwardNet, self).__init__()
        self.layers = nn.ModuleList()
        self.n_layers = n_layers
        if n_layers == 1:
            self.layers.append(nn.Linear(in_feats, out_feats))
        else:
            self.layers.append(nn.Linear(in_feats, hidden))
            for i in range(n_layers - 2):
                self.layers.append(nn.Linear(hidden, hidden))
            self.layers.append(nn.Linear(hidden, out_feats))
        if self.n_layers > 1:
            self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        for layer_id, layer in enumerate(self.layers):
            x = layer(x)
            if layer_id < self.n_layers - 1:
                x = self.dropout(self.prelu(x))
        return x


class Model(nn.Module):
    def __init__(self, in_feats, hidden, out_feats, R, n_layers, dropout):
        super(Model, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.prelu = nn.PReLU()
        self.inception_ffs = nn.ModuleList()
        for hop in range(R + 1):
            self.inception_ffs.append(
                FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout))
        # self.linear = nn.Linear(hidden * (R + 1), out_feats)
        self.project = FeedForwardNet((R + 1) * hidden, hidden, out_feats,
                                      n_layers, dropout)

    def forward(self, feats):
        hidden = []
        for feat, ff in zip(feats, self.inception_ffs):
            hidden.append(ff(feat))
        out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1))))
        return out


def calc_weight(g, device):
    """
    Compute row_normalized(D^(-1/2)AD^(-1/2))
    """
    with g.local_scope():
        # per-edge weight 1/sqrt(out_deg(u) * in_deg(v)), i.e. the entries of D^(-1/2)AD^(-1/2)
        g.ndata["in_deg"] = g.in_degrees().float().to(device).pow(-0.5)
        g.ndata["out_deg"] = g.out_degrees().float().to(device).pow(-0.5)
        g.apply_edges(fn.u_mul_v("out_deg", "in_deg", "weight"))
        # row-normalize: divide each edge weight by the total weight entering its destination node
        g.update_all(fn.copy_e("weight", "msg"), fn.sum("msg", "norm"))
        g.apply_edges(fn.e_div_v("weight", "norm", "weight"))
        return g.edata["weight"]


def preprocess(g, features, args, device):
    """
    Pre-compute the average of n-th hop neighbors
    """
    with torch.no_grad():
        g.edata["weight"] = calc_weight(g, device)
        g.ndata["feat_0"] = features.to(device)
        for hop in range(1, args.R + 1):
            g.update_all(fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"),
                         fn.sum("msg", f"feat_{hop}"))
        res = []
        for hop in range(args.R + 1):
            res.append(g.ndata.pop(f"feat_{hop}"))
        return res


def prepare_data(device, args):
    data = load_dataset(args.dataset)
    g, features, labels, n_classes, train_nid, val_nid, test_nid = data
    in_feats = features.shape[1]
    feats = preprocess(g, features, args, device)
    # move to device
    labels = labels.to(device)
    train_nid = train_nid.to(device)
    val_nid = val_nid.to(device)
    test_nid = test_nid.to(device)
    train_feats = [x[train_nid] for x in feats]
    train_labels = labels[train_nid]
    return feats, labels, train_feats, train_labels, in_feats, \
        n_classes, train_nid, val_nid, test_nid


def evaluate(epoch, args, model, feats, labels, train, val, test):
    with torch.no_grad():
        batch_size = args.eval_batch_size
        if batch_size <= 0:
            pred = model(feats)
        else:
            pred = []
            num_nodes = labels.shape[0]
            n_batch = (num_nodes + batch_size - 1) // batch_size
            for i in range(n_batch):
                batch_start = i * batch_size
                batch_end = min((i + 1) * batch_size, num_nodes)
                batch_feats = [feat[batch_start: batch_end] for feat in feats]
                pred.append(model(batch_feats))
            pred = torch.cat(pred)

        pred = torch.argmax(pred, dim=1)
        correct = (pred == labels).float()
        train_acc = correct[train].sum() / len(train)
        val_acc = correct[val].sum() / len(val)
        test_acc = correct[test].sum() / len(test)
        return train_acc, val_acc, test_acc


def main(args):
    if args.gpu < 0:
        device = "cpu"
    else:
        device = "cuda:{}".format(args.gpu)

    data = prepare_data(device, args)
    feats, labels, train_feats, train_labels, in_size, num_classes, \
        train_nid, val_nid, test_nid = data

    model = Model(in_size, args.num_hidden, num_classes, args.R, args.ff_layer,
                  args.dropout)
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    best_epoch = 0
    best_val = 0
    best_test = 0

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        model.train()
        loss = loss_fcn(model(train_feats), train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % args.eval_every == 0:
            model.eval()
            acc = evaluate(epoch, args, model, feats, labels,
                           train_nid, val_nid, test_nid)
            end = time.time()
            log = "Epoch {}, Time(s): {:.4f}".format(epoch, end - start)
            log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}" \
                .format(*acc)
            print(log)
            if acc[1] > best_val:
                best_val = acc[1]
                best_epoch = epoch
                best_test = acc[2]

    print("Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
        best_epoch, best_val, best_test))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SIGN")
    parser.add_argument("--num-epochs", type=int, default=1000)
    parser.add_argument("--num-hidden", type=int, default=256)
    parser.add_argument("--R", type=int, default=3,
                        help="number of hops")
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dataset", type=str, default="amazon")
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--weight-decay", type=float, default=0)
    parser.add_argument("--eval-every", type=int, default=50)
    parser.add_argument("--eval-batch-size", type=int, default=250000,
                        help="evaluation batch size, -1 for full batch")
    parser.add_argument("--ff-layer", type=int, default=2,
                        help="number of feed-forward layers")
    args = parser.parse_args()

    print(args)
    main(args)
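
For reference, here is a minimal sketch (not part of the commit) of driving these pieces from Python rather than from the command line. It assumes the two files above are saved as `dataset.py` and `sign.py` next to each other; the `SimpleNamespace` stands in for the parsed arguments, since `preprocess` only reads `args.R`.

```python
# Minimal usage sketch: precompute SIGN features on Cora and run one forward pass.
from types import SimpleNamespace

from dataset import load_dataset
from sign import Model, preprocess

args = SimpleNamespace(R=3)                      # preprocess() only reads args.R
g, features, labels, n_classes, train_nid, val_nid, test_nid = load_dataset("cora")

# feats is a list of R + 1 tensors: [X, A_hat X, ..., A_hat^R X]
feats = preprocess(g, features, args, "cpu")

model = Model(in_feats=features.shape[1], hidden=256, out_feats=n_classes,
              R=args.R, n_layers=2, dropout=0.5)
logits = model(feats)                            # shape: (num_nodes, n_classes)
print(logits.shape)
```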