[Model] Add Pytorch example for Cluster GCN (dmlc#877)
* initial commit of cluster GCN

* update readme

* fix small bugs running ppi

* nearly sota ppi training script & update readme

* rm unused imports & add shebang line to scripts

* minor comments & readme appended

* add rnd seed control & update readme
Zardinality authored and mufeili committed Sep 28, 2019
1 parent c03046a commit 51a7350
Showing 8 changed files with 693 additions and 0 deletions.
55 changes: 55 additions & 0 deletions examples/pytorch/cluster_gcn/README.md
@@ -0,0 +1,55 @@
Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks
============
- Paper link: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953)
- Author's code repo: [https://github.com/google-research/google-research/blob/master/cluster_gcn/](https://github.com/google-research/google-research/blob/master/cluster_gcn/).

This repo reproduces the reported speed and performance on Reddit and PPI as closely as possible. However, the diagonal enhancement technique is not covered, since the GraphSAGE aggregator already achieves a satisfactory F1 score.
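
At its core, Cluster-GCN partitions the graph with METIS and trains a plain GCN on one small cluster-induced subgraph per step. A minimal sketch of that loop, reusing the names from this example's `cluster_gcn.py` (see the full script below for the real code):

```
for epoch in range(n_epochs):
    for cluster in cluster_iterator:   # subgraph induced by a batch of METIS clusters
        pred = model(cluster)          # full GCN forward on the small subgraph
        mask = cluster.ndata['train_mask']
        loss = loss_f(pred[mask], cluster.ndata['labels'][mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```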

Dependencies
------------
- Python 3.7+ (for string formatting features)
- PyTorch 1.1.0+
- metis
- sklearn


* Install the clustering toolkit METIS and its Python interface.

METIS - Serial Graph Partitioning and Fill-reducing Matrix Ordering ([official website](http://glaros.dtc.umn.edu/gkhome/metis/metis/overview)):

```
1) Download metis-5.1.0.tar.gz from http://glaros.dtc.umn.edu/gkhome/metis/metis/download and unpack it
2) cd metis-5.1.0
3) make config shared=1 prefix=~/.local/
4) make install
5) export METIS_DLL=~/.local/lib/libmetis.so
6) pip install metis
```

A quick test to check whether metis is installed correctly:

```
>>> import networkx as nx
>>> import metis
>>> G = metis.example_networkx()
>>> (edgecuts, parts) = metis.part_graph(G, 3)
```


## Run Experiments
* For the Reddit data, run the following script:

```
./run_reddit.sh
```
You should see a final test F1 of around `Test F1-mic0.9612, Test F1-mac0.9399`.
Note that the first run of the provided script is considerably slower than reported in the paper, presumably because the data loader must compute the METIS partition assignment. Once that assignment is cached, the overall speed returns to a normal scale; see the sketch below. On a machine with a 1080Ti and an Intel(R) Xeon(R) Bronze 3104 CPU @ 1.70GHz, training completes within 45s. After the first epoch, the F1-mic on the validation set should be around `0.93`.
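
For reference, the caching amounts to partitioning once with METIS and storing the assignment on disk, roughly as in this sketch (the `get_partition_list` helper and cache path are illustrative; the actual logic lives in `sampler.py`):

```
import os
import numpy as np
import metis

def get_partition_list(nx_graph, psize, cache_path='cluster.npy'):
    # partitioning a graph as large as Reddit dominates the first run,
    # so reuse a cached METIS assignment when one exists
    if os.path.exists(cache_path):
        return np.load(cache_path)
    _, parts = metis.part_graph(nx_graph, psize)
    parts = np.array(parts)
    np.save(cache_path, parts)
    return parts
```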

* For the PPI data, run the following script:

```
./run_ppi.sh
```
You should see a final test F1 of around `Test F1-mic0.9924, Test F1-mac0.9917`. Training finishes in about 10 minutes.
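
The `F1-mic`/`F1-mac` numbers are micro- and macro-averaged F1 scores. A minimal sketch of the two averages with scikit-learn on toy multi-label data (not the script's actual `evaluate` helper):

```
import numpy as np
from sklearn.metrics import f1_score

# toy multi-label predictions (PPI-style), just to show the two averages
y_true = np.array([[1, 0, 1], [0, 1, 1]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])
print('F1-mic', f1_score(y_true, y_pred, average='micro'))
print('F1-mac', f1_score(y_true, y_pred, average='macro'))
```
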
248 changes: 248 additions & 0 deletions examples/pytorch/cluster_gcn/cluster_gcn.py
@@ -0,0 +1,248 @@
import argparse
import os
import time
import random

import numpy as np
import sklearn.preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args
from torch.utils.tensorboard import SummaryWriter

from modules import GCNCluster, GraphSAGE
from sampler import ClusterIter
from utils import Logger, evaluate, save_log_dir, load_data


def main(args):
    torch.manual_seed(args.rnd_seed)
    np.random.seed(args.rnd_seed)
    random.seed(args.rnd_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    multitask_data = set(['ppi', 'amazon', 'amazon-0.1',
                          'amazon-0.3', 'amazon2M', 'amazon2M-47'])

    multitask = args.dataset in multitask_data

    # load and preprocess dataset
    data = load_data(args)

    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)

    # normalize features using statistics of the training nodes only
    if args.normalize:
        train_feats = data.features[train_nid]
        scaler = sklearn.preprocessing.StandardScaler()
        scaler.fit(train_feats)
        features = scaler.transform(data.features)
    else:
        features = data.features

    features = torch.FloatTensor(features)
    if not multitask:
        labels = torch.LongTensor(data.labels)
    else:
        labels = torch.FloatTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask).type(torch.bool)
    val_mask = torch.ByteTensor(data.val_mask).type(torch.bool)
    test_mask = torch.ByteTensor(data.test_mask).type(torch.bool)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().item()
    n_val_samples = val_mask.sum().item()
    n_test_samples = test_mask.sum().item()

    print("""----Data statistics------
    #Edges %d
    #Classes %d
    #Train samples %d
    #Val samples %d
    #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))
    # create GCN model
    g = data.graph
    if args.self_loop and not args.dataset.startswith('reddit'):
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
        print("adding self-loop edges")
    g = DGLGraph(g, readonly=True)

    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        print(torch.cuda.get_device_name(0))

    g.ndata['features'] = features
    g.ndata['labels'] = labels
    g.ndata['train_mask'] = train_mask
    print('labels shape:', labels.shape)

    # the sampler yields subgraphs induced by batches of METIS clusters
    cluster_iterator = ClusterIter(
        args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)

    print("features shape, ", features.shape)

    model_sel = {'GCN': GCNCluster, 'graphsage': GraphSAGE}
    model_class = model_sel[args.model_type]
    print('using model:', model_class)

    model = model_class(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        F.relu,
                        args.dropout, args.use_pp)

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    writer = SummaryWriter(log_dir)
    logger = Logger(os.path.join(log_dir, 'loggings'))
    logger.write(args)

    # Loss function
    if multitask:
        print('Using multi-label loss')
        loss_f = nn.BCEWithLogitsLoss()
    else:
        print('Using multi-class loss')
        loss_f = nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # set train_nids to cuda tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print("current memory after model before training",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
    start_time = time.time()
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, cluster in enumerate(cluster_iterator):
            # sync features/labels/masks with the upper-level training graph
            cluster.copy_from_parent()
            model.train()
            # forward
            pred = model(cluster)
            batch_labels = cluster.ndata['labels']
            batch_train_mask = cluster.ndata['train_mask']
            loss = loss_f(pred[batch_train_mask],
                          batch_labels[batch_train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # In the PPI case, `log_every` is chosen to log once per epoch.
            # Lower it when you want more info within one epoch.
            if j % args.log_every == 0:
                print(f"epoch:{epoch}/{args.n_epochs}, Iteration {j}/{len(cluster_iterator)}: training loss", loss.item())
                writer.add_scalar('train/loss', loss.item(),
                                  global_step=j + epoch * len(cluster_iterator))
        if cuda:
            print("current memory:",
                  torch.cuda.memory_allocated(device=pred.device) / 1024 / 1024)

        # evaluate
        if epoch % args.val_every == 0:
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multitask)
            print(
                "Val F1-mic{:.4f}, Val F1-mac{:.4f}".format(val_f1_mic, val_f1_mac))
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), os.path.join(
                    log_dir, 'best_model.pkl'))
            writer.add_scalar('val/f1-mic', val_f1_mic, global_step=epoch)
            writer.add_scalar('val/f1-mac', val_f1_mac, global_step=epoch)

    end_time = time.time()
    print(f'training used time {end_time - start_time}')

    # test, optionally with the best model on the validation set
    if args.use_val:
        model.load_state_dict(torch.load(os.path.join(
            log_dir, 'best_model.pkl')))
    test_f1_mic, test_f1_mac = evaluate(
        model, g, labels, test_mask, multitask)
    print(
        "Test F1-mic{:.4f}, Test F1-mac{:.4f}".format(test_f1_mic, test_f1_mac))
    writer.add_scalar('test/f1-mic', test_f1_mic)
    writer.add_scalar('test/f1-mac', test_f1_mac)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu id (-1 for cpu)")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--log-every", type=int, default=100,
                        help="number of iterations between training loss logs")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="batch size")
    parser.add_argument("--psize", type=int, default=1500,
                        help="partition number")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--val-every", type=int, default=1,
                        help="number of epochs between validation runs")
    parser.add_argument("--rnd-seed", type=int, default=3,
                        help="random seed")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--use-pp", action='store_true',
                        help="whether to use precomputation")
    parser.add_argument("--normalize", action='store_true',
                        help="whether to use normalized feature")
    parser.add_argument("--use-val", action='store_true',
                        help="whether to test with the best model on the validation set")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="weight for L2 loss")
    parser.add_argument("--model-type", type=str, default='GCN',
                        help="model to be used")
    parser.add_argument("--note", type=str, default='none',
                        help="note for log dir")

    args = parser.parse_args()

    print(args)

    main(args)