
Commit

[WIP] [NN] Refactor NN package (dmlc#406)
* refactor graph conv

* docs & tests

* fix lint

* fix lint

* fix lint

* fix lint script

* fix lint

* Update

* Style fix

* Fix style

* Fix style

* Fix gpu case

* Fix for gpu case

* Hotfix edgesoftmax docs

* Handle repeated features

* Add docstring

* Set default arguments

* Remove dropout from nn.conv

* Fix

* add util fn for renaming

* revert gcn_spmv.py

* mx folder

* fix weird bug

* fix mx

* fix lint
jermainewang authored Feb 25, 2019
1 parent 8c75017 commit 565f0c8
Showing 30 changed files with 1,258 additions and 780 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/index.rst
@@ -14,3 +14,4 @@ API Reference
sampler
data
transform
nn
13 changes: 13 additions & 0 deletions docs/source/api/python/nn.mxnet.rst
@@ -0,0 +1,13 @@
.. _apinn-mxnet:

dgl.nn.mxnet
============

dgl.nn.mxnet.conv
-----------------

.. automodule:: dgl.nn.mxnet.conv

.. autoclass:: dgl.nn.mxnet.conv.GraphConv
:members: weight, bias, forward
:show-inheritance:
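
For orientation, here is a minimal usage sketch of the MXNet `GraphConv` module documented above. It follows the call convention used by the refactored `examples/mxnet/gcn/gcn.py` in this commit (`GraphConv(in_feats, out_feats, activation=...)` built as a Gluon block and called as `layer(features, graph)`); the toy cycle graph and feature sizes are made up for illustration.

```python
import mxnet as mx
from dgl import DGLGraph
from dgl.nn.mxnet import GraphConv

# Toy graph: a 4-node directed cycle (illustrative only).
g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = mx.nd.random.uniform(shape=(4, 5))    # 4 nodes, 5 input features

conv = GraphConv(5, 2, activation=mx.nd.relu)
conv.initialize()                            # Gluon blocks need explicit init
out = conv(feat, g)                          # expected output shape: (4, 2)
```
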
22 changes: 22 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -0,0 +1,22 @@
.. _apinn-pytorch:

dgl.nn.pytorch
==============

dgl.nn.pytorch.conv
-------------------

.. automodule:: dgl.nn.pytorch.conv

.. autoclass:: dgl.nn.pytorch.conv.GraphConv
:members: weight, bias, forward, reset_parameters
:show-inheritance:

dgl.nn.pytorch.softmax
----------------------

.. automodule:: dgl.nn.pytorch.softmax

.. autoclass:: dgl.nn.pytorch.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
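
Likewise, a minimal sketch of the PyTorch `GraphConv` documented above, assuming the same two-argument call convention (`layer(features, graph)`) as the MXNet example elsewhere in this commit; the toy graph and sizes are illustrative.

```python
import torch
from dgl import DGLGraph
from dgl.nn.pytorch.conv import GraphConv

# Toy graph: a 4-node directed cycle (illustrative only).
g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = torch.randn(4, 5)         # 4 nodes, 5 input features

conv = GraphConv(5, 2)
conv.reset_parameters()          # documented above; normally done in __init__
out = conv(feat, g)              # expected output shape: (4, 2)
```
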
9 changes: 9 additions & 0 deletions docs/source/api/python/nn.rst
@@ -0,0 +1,9 @@
.. _apinn:

dgl.nn
======

.. toctree::

nn.pytorch
nn.mxnet
18 changes: 10 additions & 8 deletions examples/mxnet/gcn/README.md
@@ -16,9 +16,11 @@ pip install requests

Codes
-----
The folder contains two implementations of GCN. `gcn.py` uses user-defined
message and reduce functions. `gcn_spmv.py` uses DGL's builtin functions so
SPMV optimization could be applied.
The folder contains three implementations of GCN:
- `gcn.py` uses DGL's predefined graph convolution module.
- `gcn_mp.py` uses user-defined message and reduce functions.
- `gcn_spmv.py` improves on `gcn_mp.py` by using DGL's builtin functions
so that the SPMV optimization can be applied.
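
For reference, the user-defined message and reduce functions that `gcn_mp.py` refers to look roughly like the sketch below, adapted from the code this commit removes from `gcn.py` (the `'h'`, `'m'`, and `'norm'` field names follow that code; the exact details in `gcn_mp.py` may differ).

```python
import mxnet as mx

def gcn_msg(edge):
    # Each edge forwards the normalized source feature.
    return {'m': edge.src['h'] * edge.src['norm']}

def gcn_reduce(node):
    # Sum incoming messages and apply the destination-side normalization.
    return {'h': mx.nd.sum(node.mailbox['m'], 1) * node.data['norm']}

# Typical use on a DGLGraph g with features and norms stored as node data:
#   g.ndata['h'] = features
#   g.update_all(gcn_msg, gcn_reduce)
#   h = g.ndata.pop('h')
```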

The provided implementation in `gcn_concat.py` is a bit different from the
original paper for better performance, credit to @yifeim and @ZiyueHuang.
@@ -27,15 +29,15 @@ Results
-------
Run with the following (available datasets: "cora", "citeseer", "pubmed"):
```bash
DGLBACKEND=mxnet python3 gcn_spmv.py --dataset cora --gpu 0
DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0
```

* cora: ~0.810 (paper: 0.815)
* citeseer: ~0.702 (paper: 0.703)
* pubmed: ~0.780 (paper: 0.790)

Results (`gcn_concat.py vs. gcn_spmv.py`)
-------------------------
Results (`gcn_concat.py vs. gcn.py`)
------------------------------------
`gcn_concat.py` uses concatenation of hidden units to account for multi-hop
skip-connections, while `gcn_spmv.py` uses simple additions (the original paper
omitted this detail). We feel concatenation is superior
@@ -90,10 +92,10 @@ DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "cora" --n-e
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "pubmed" --n-epochs 200 --n-layers 0
# Final accuracy 77.40% with 2-layer GCN
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_spmv.py --dataset "cora" --n-epochs 200 --n-layers 1
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_spmv.py --dataset "pubmed" --n-epochs 200 --n-layers 1
# Final accuracy 36.20% with 10-layer GCN
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_spmv.py --dataset "cora" --n-epochs 200 --n-layers 9
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_spmv.py --dataset "pubmed" --n-epochs 200 --n-layers 9
# Final accuracy 78.30% with 2-layer GCN with skip connection
DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "pubmed" --n-epochs 200 --n-layers 2 --normalization 'sym' --self-loop
213 changes: 17 additions & 196 deletions examples/mxnet/gcn/gcn.py
@@ -1,72 +1,14 @@
"""GCN using DGL nn package
References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV optimization
"""
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


def gcn_msg(edge):
msg = edge.src['h'] * edge.src['norm']
return {'m': msg}


def gcn_reduce(node):
accum = mx.nd.sum(node.mailbox['m'], 1) * node.data['norm']
return {'h': accum}


class NodeUpdate(gluon.Block):
def __init__(self, out_feats, activation=None, bias=True):
super(NodeUpdate, self).__init__()
with self.name_scope():
if bias:
self.bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())
else:
self.bias = None
self.activation = activation

def forward(self, node):
h = node.data['h']
if self.bias is not None:
h = h + self.bias.data(h.context)
if self.activation:
h = self.activation(h)
return {'h': h}

class GCNLayer(gluon.Block):
def __init__(self,
g,
in_feats,
out_feats,
activation,
dropout,
bias=True):
super(GCNLayer, self).__init__()
self.g = g
self.dropout = dropout
with self.name_scope():
self.weight = self.params.get('weight', shape=(in_feats, out_feats),
init=mx.init.Xavier())
self.node_update = NodeUpdate(out_feats, activation, bias)

def forward(self, h):
if self.dropout:
h = mx.nd.Dropout(h, p=self.dropout)
h = mx.nd.dot(h, self.weight.data(h.context))
self.g.ndata['h'] = h
self.g.update_all(gcn_msg, gcn_reduce, self.node_update)
h = self.g.ndata.pop('h')
return h

from dgl.nn.mxnet import GraphConv

class GCN(gluon.Block):
def __init__(self,
@@ -76,144 +18,23 @@ def __init__(self,
n_classes,
n_layers,
activation,
dropout,
normalization):
dropout):
super(GCN, self).__init__()
self.g = g
self.layers = gluon.nn.Sequential()
# input layer
self.layers.add(GCNLayer(g, in_feats, n_hidden, activation, 0))
self.layers.add(GraphConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(GCNLayer(g, n_hidden, n_hidden, activation, dropout))
self.layers.add(GraphConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.add(GCNLayer(g, n_hidden, n_classes, None, dropout))

self.layers.add(GraphConv(n_hidden, n_classes))
self.dropout = gluon.nn.Dropout(rate=dropout)

def forward(self, features):
h = features
for layer in self.layers:
h = layer(h)
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
return h

def evaluate(model, features, labels, mask):
pred = model(features).argmax(axis=1)
accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()
return accuracy.asscalar()

def main(args):
# load and preprocess dataset
data = load_data(args)

if args.self_loop:
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
train_mask = mx.nd.array(data.train_mask)
val_mask = mx.nd.array(data.val_mask)
test_mask = mx.nd.array(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().asscalar(),
val_mask.sum().asscalar(),
test_mask.sum().asscalar()))

if args.gpu < 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
ctx = mx.gpu(args.gpu)

features = features.as_in_context(ctx)
labels = labels.as_in_context(ctx)
train_mask = train_mask.as_in_context(ctx)
val_mask = val_mask.as_in_context(ctx)
test_mask = test_mask.as_in_context(ctx)

# create GCN model
g = DGLGraph(data.graph)
# normalization
degs = g.in_degrees().astype('float32')
norm = mx.nd.power(degs, -0.5)
if cuda:
norm = norm.as_in_context(ctx)
g.ndata['norm'] = mx.nd.expand_dims(norm, 1)

model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
mx.nd.relu,
args.dropout,
args.normalization)
model.initialize(ctx=ctx)
n_train_samples = train_mask.sum().asscalar()
loss_fcn = gluon.loss.SoftmaxCELoss()

# use optimizer
print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam',
{'learning_rate': args.lr, 'wd': args.weight_decay})

# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
with mx.autograd.record():
pred = model(features)
loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
loss = loss.sum() / n_train_samples

loss.backward()
trainer.step(batch_size=1)

if epoch >= 3:
dur.append(time.time() - t0)
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(
epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000))

# test set accuracy
acc = evaluate(model, features, labels, test_mask)
print("Test accuracy {:.2%}".format(acc))

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=3e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--normalization",
choices=['sym','left'], default=None,
help="graph normalization types (default=None)")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
args = parser.parse_args()

print(args)

main(args)