[GraphIndex] Immutable graph index using MXNet CSRArray (dmlc#86)
* add immutable graph index.
* update ImmutableGraphIndex.
* add benchmarks of subgraph generation.
* fix node_subgraphs.
* use scipy synthetic data in immutable graph.
* create immutable_graph.
* use spmv in gcn.
* fix immutable graph index.
* update graph index.
* create ImmutableSubgraphIndex.
* return subgraph node and edge mapping.
* fix benchmark.
* add mapping to subgraph nid.
* test in_edges and out_edges.
* update benchmark.
* Fix immutable subgraph.
* complete in_edges and out_edges.
* fix benchmarks.
* add test on node_subgraphs
* add sse.
* construct subgraphs in parallel.
* add in/out degree(s).
* make immutable graph index backend-specific.
* implement has_edge(s) and edge_id(s).
* Revert "use spmv in gcn." This reverts commit 9cfed5f.
* implement node_subgraphs in GraphIndex.
* fix sse.
* address comments.
* address comments.
* Update mxnet docker.
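For context, a minimal sketch of the immutable-graph workflow this commit introduces, based only on the calls that appear in the diff below (`graph` and the seed ids are placeholders):

    import mxnet as mx
    import numpy as np
    from dgl import DGLGraph, utils

    g = DGLGraph(graph, readonly=True)   # readonly=True selects the immutable index
    seeds = mx.nd.array([0, 1], dtype=np.int64)
    subgs = g.subgraphs([seeds])         # construct several subgraphs in one call
    subgs[0].copy_from_parent()          # pull node/edge data from the parent graph
    nids = subgs[0].map_to_subgraph_nid(utils.toindex(seeds))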
Showing 12 changed files with 1,424 additions and 9 deletions.
@@ -0,0 +1,214 @@
""" | ||
Learning Steady-States of Iterative Algorithms over Graphs | ||
Paper: http://proceedings.mlr.press/v80/dai18a.html | ||
""" | ||
import argparse | ||
import numpy as np | ||
import time | ||
import mxnet as mx | ||
from mxnet import gluon | ||
import dgl | ||
import dgl.function as fn | ||
from dgl import DGLGraph, utils | ||
from dgl.data import register_data_args, load_data | ||

def gcn_msg(src, edge):
    # TODO should we use concat?
    # Message: each source vertex ships its input feature and current
    # hidden state to its out-neighbors.
    return {'m': mx.nd.concat(src['in'], src['h'], dim=1)}


def gcn_reduce(node, msgs):
    # Reduce: sum the incoming messages over the neighbor axis.
    return {'accum': mx.nd.sum(msgs['m'], 1)}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.9):
        super(NodeUpdate, self).__init__()
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

    def forward(self, node):
        # SSE steady-state update:
        # h <- (1 - alpha) * h + alpha * W2(W1([input, accum]))
        tmp = mx.nd.concat(node['in'], node['accum'], dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return {'h': node['h'] * (1 - self.alpha) + self.alpha * hidden}

class SSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv):
        super(SSEUpdateHidden, self).__init__()
        self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout
        self.use_spmv = use_spmv

    def forward(self, g, vertices):
        if self.use_spmv:
            # Concatenate the input and hidden features once up front so the
            # message phase becomes a plain field copy, which DGL can execute
            # as a sparse matrix-vector multiplication.
            feat = g.get_n_repr()['in']
            h = g.get_n_repr()['h']
            g.set_n_repr({'cat': mx.nd.concat(feat, h, dim=1)})

            msg_func = fn.copy_src(src='cat', out='tmp')
            reduce_func = fn.sum(msg='tmp', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        if vertices is None:
            # Update the embeddings of all vertices.
            g.update_all(msg_func, reduce_func, self.layer)
            ret = g.get_n_repr()['h']
        else:
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout on all vertex representations.
                val = mx.nd.Dropout(g.get_n_repr()['h'], p=self.dropout)
                g.set_n_repr({'h': val})
            # Only pull messages into the requested vertices.
            g.pull(vertices, msg_func, reduce_func, self.layer)
            ctx = g.get_n_repr()['h'].context
            ret = mx.nd.take(g.get_n_repr()['h'],
                             vertices.tousertensor().as_in_context(ctx))
        return ret
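
# A minimal usage sketch (hypothetical vertex ids; `g` is the DGLGraph
# built in main() below):
#
#   update_hidden = SSEUpdateHidden(16, 'relu', 0.5, use_spmv=False)
#   update_hidden(g, None)   # refresh the embeddings of every vertex
#   vs = utils.toindex(mx.nd.array([0, 1], dtype=np.int64))
#   h = update_hidden(g, vs) # update (and return) two vertices only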

class SSEPredict(gluon.Block):
    """Two-layer predictor on top of the steady-state hidden representations."""
    def __init__(self, update_hidden, out_feats, dropout):
        super(SSEPredict, self).__init__()
        self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
        self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

def subgraph_gen(g, seed_vertices):
    # For each batch of seeds, take the seeds plus all of their
    # in-neighbors as the vertex set of one subgraph.
    vertices = []
    for seed in seed_vertices:
        src, _, _ = g.in_edges(seed)
        vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0)
        vs = mx.nd.array(np.unique(vs), dtype=np.int64)
        vertices.append(vs)
    # A single call constructs all subgraphs, so they can be built in parallel.
    subgs = g.subgraphs(vertices)
    nids = []
    for i, subg in enumerate(subgs):
        subg.copy_from_parent()
        # Remember where each seed vertex lands inside its subgraph.
        nids.append(subg.map_to_subgraph_nid(utils.toindex(seed_vertices[i])))
    return subgs, nids
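
# Example (hypothetical seed batches): two batches of seeds yield two
# subgraphs plus the subgraph-local ids of the seeds, which is exactly
# how the training loop below feeds the model:
#
#   seeds = [mx.nd.array([0, 1], dtype='int64'),
#            mx.nd.array([2, 3], dtype='int64')]
#   subgs, nids = subgraph_gen(g, seeds)
#   logits = model(subgs[0], nids[0])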

def main(args, data):
    features = mx.nd.array(data.features)
    labels = mx.nd.array(data.labels)
    # Cast to int so the train/eval split is a valid array length.
    train_size = int(len(labels) * args.train_percent)
    train_vs = np.arange(train_size, dtype='int64')
    eval_vs = np.arange(train_size, len(labels), dtype='int64')
    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    train_labels = mx.nd.array(data.labels[train_vs])
    eval_labels = mx.nd.array(data.labels[eval_vs])
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    # gpu id 0 is a valid device, so only negative ids mean CPU.
    if args.gpu < 0:
        cuda = False
        ctx = mx.cpu(0)
    else:
        cuda = True
        features = features.as_in_context(mx.gpu(0))
        train_labels = train_labels.as_in_context(mx.gpu(0))
        eval_labels = eval_labels.as_in_context(mx.gpu(0))
        ctx = mx.gpu(0)

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    # readonly=True selects the new immutable graph index.
    g = DGLGraph(graph, readonly=True)
    g.set_n_repr({'in': features,
                  'h': mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
                                           ctx=ctx)})

    update_hidden = SSEUpdateHidden(args.n_hidden, 'relu', args.update_dropout, args.use_spmv)
    model = SSEPredict(update_hidden, args.n_hidden, args.predict_dropout)
    model.initialize(ctx=ctx)

    # use optimizer
    num_batches = int(g.number_of_nodes() / args.batch_size)
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
                                                args.lr * 10, 0, 0, args.lr / 5)
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'lr_scheduler': scheduler})
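    # Note: given MXNet's CosineScheduler(max_update, base_lr, final_lr,
    # warmup_steps, warmup_begin_lr) signature, this decays the learning
    # rate along a cosine curve from args.lr * 10 down to 0 over
    # n_epochs * num_batches updates; with warmup_steps=0 the
    # warmup_begin_lr of args.lr / 5 has no effect.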

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        # compute vertex embedding.
        update_hidden(g, None)

        t0 = time.time()
        permute = np.random.permutation(len(train_vs))
        randv = train_vs[permute]
        rand_labels = train_labels[permute]
        data_iter = mx.io.NDArrayIter(data=mx.nd.array(randv, dtype='int64'),
                                      label=rand_labels,
                                      batch_size=args.batch_size)
        train_loss = 0
        # Buffer several batches so subgraph_gen can construct their
        # subgraphs in parallel.
        seed_batches = []
        label_batches = []
        for batch in data_iter:
            seed_batches.append(batch.data[0])
            label_batches.append(batch.label[0])
            if len(seed_batches) < args.num_parallel_subgraphs:
                continue

            subgs, seed_ids = subgraph_gen(g, seed_batches)
            for subg, seed_id, label, d in zip(subgs, seed_ids,
                                               label_batches, seed_batches):
                with mx.autograd.record():
                    logits = model(subg, seed_id)
                    loss = mx.nd.softmax_cross_entropy(logits, label)
                loss.backward()
                trainer.step(d.shape[0])
                train_loss += loss.asnumpy()[0]
            seed_batches = []
            label_batches = []

        #logits = model(eval_vs)
        #eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
        #eval_loss = eval_loss.asnumpy()[0]
        eval_loss = 0

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu id; -1 means CPU")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0.5,
            help="the dropout rate for updating vertex embeddings")
    parser.add_argument("--predict-dropout", type=float, default=0.5,
            help="the dropout rate for prediction")
    parser.add_argument("--train_percent", type=float, default=0.5,
            help="the fraction of data used for training")
    # action='store_true' instead of type=bool: bool('False') is True,
    # so type=bool cannot turn the flag off from the command line.
    parser.add_argument("--use-spmv", action='store_true',
            help="use SpMV for faster speed.")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    args = parser.parse_args()

    # load and preprocess dataset
    data = load_data(args)
    main(args, data)
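A typical invocation, assuming the script is saved as sse_batch.py (the filename is not visible in this view) and that register_data_args adds DGL's usual --dataset flag:

    python sse_batch.py --dataset cora --gpu 0 --use-spmv --num-parallel-subgraphs 4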