[NN] RGCN modules (dmlc#744)
* rgcn module
* support id input
* WIP: model codes
* use faster index select
* dropout
* self loop
* WIP: link prediction
* fix lint
* WIP: docs
* docstring
* docstring
* merge two child classes
* mxnet rgcn module
* fix lint
* fix lint
* fix rename bug
* add uniform edge sampler
* fix fn name
* docstring
* fix mxnet rgcn module
* fix mx rgcn
* enable test on cuda
jermainewang authored Aug 23, 2019
1 parent 52d4535 commit 708765f
Showing 18 changed files with 774 additions and 337 deletions.
1 change: 1 addition & 0 deletions docs/README.md
@@ -5,6 +5,7 @@ Requirements
 ------------
 * sphinx
 * sphinx-gallery
+* sphinx_rtd_theme
 * Both pytorch and mxnet installed.
 
 Build documents
4 changes: 4 additions & 0 deletions docs/source/api/python/nn.mxnet.rst
@@ -12,6 +12,10 @@ dgl.nn.mxnet.conv
     :members: weight, bias, forward
     :show-inheritance:
 
+.. autoclass:: dgl.nn.mxnet.conv.RelGraphConv
+    :members: forward
+    :show-inheritance:
+
 dgl.nn.mxnet.glob
 -----------------
4 changes: 4 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -12,6 +12,10 @@ dgl.nn.pytorch.conv
     :members: weight, bias, forward, reset_parameters
     :show-inheritance:
 
+.. autoclass:: dgl.nn.pytorch.conv.RelGraphConv
+    :members: forward
+    :show-inheritance:
+
 dgl.nn.pytorch.glob
 -------------------
 .. automodule:: dgl.nn.pytorch.glob
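For orientation, here is a minimal usage sketch of the newly documented RelGraphConv (PyTorch side). The toy graph, feature sizes, and relation count are invented for illustration, and the snippet targets the DGL API as of this commit (2019), where graphs are built via DGLGraph.add_nodes/add_edges:

```python
import torch
import dgl
from dgl.nn.pytorch import RelGraphConv

# toy graph: 4 nodes, 4 directed edges, 2 relation types
g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
etype = torch.tensor([0, 1, 0, 1])  # one relation id per edge

# basis decomposition: 2 shared bases across the 2 relations
conv = RelGraphConv(8, 16, num_rels=2, regularizer="basis",
                    num_bases=2, activation=torch.relu, self_loop=True)

h = torch.randn(4, 8)      # dense node features, shape (num_nodes, in_feat)
out = conv(g, h, etype)    # -> shape (4, 16)
```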
4 changes: 2 additions & 2 deletions examples/mxnet/rgcn/README.md
@@ -25,12 +25,12 @@ AIFB: accuracy 97.22% (DGL), 95.83% (paper)
 DGLBACKEND=mxnet python3 entity_classify.py -d aifb --testing --gpu 0
 ```
 
-MUTAG: accuracy 76.47% (DGL), 73.23% (paper)
+MUTAG: accuracy 73.53% (DGL), 73.23% (paper)
 ```
 DGLBACKEND=mxnet python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0
 ```
 
-BGS: accuracy 79.31% (DGL, n-bases=20, OOM when >20), 83.10% (paper)
+BGS: accuracy 75.86% (DGL, n-bases=20, OOM when >20), 83.10% (paper)
 ```
 DGLBACKEND=mxnet python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0 --relabel
 ```
40 changes: 20 additions & 20 deletions examples/mxnet/rgcn/entity_classify.py
@@ -15,32 +15,27 @@
 from mxnet import gluon
 import mxnet.ndarray as F
 from dgl import DGLGraph
+from dgl.nn.mxnet import RelGraphConv
 from dgl.contrib.data import load_data
 from functools import partial
 
 from model import BaseRGCN
-from layers import RGCNBasisLayer as RGCNLayer
 
 
 class EntityClassify(BaseRGCN):
-    def create_features(self):
-        features = mx.nd.arange(self.num_nodes)
-        if self.gpu_id >= 0:
-            features = features.as_in_context(mx.gpu(self.gpu_id))
-        return features
-
     def build_input_layer(self):
-        return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
-                         activation=F.relu, is_input_layer=True)
+        return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis",
+                            self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
+                            dropout=self.dropout)
 
     def build_hidden_layer(self, idx):
-        return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
-                         activation=F.relu)
+        return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis",
+                            self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
+                            dropout=self.dropout)
 
     def build_output_layer(self):
-        return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
-                         activation=partial(F.softmax, axis=1))
+        return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
+                            self.num_bases, activation=partial(F.softmax, axis=1),
+                            self_loop=self.use_self_loop)
 
 
 def main(args):
     # load graph data
@@ -60,15 +55,18 @@ def main(args):
         val_idx = train_idx
 
     train_idx = mx.nd.array(train_idx)
+    # since the nodes are featureless, the input feature is then the node id.
+    feats = mx.nd.arange(num_nodes, dtype='int32')
     # edge type and normalization factor
-    edge_type = mx.nd.array(data.edge_type)
+    edge_type = mx.nd.array(data.edge_type, dtype='int32')
     edge_norm = mx.nd.array(data.edge_norm).expand_dims(1)
     labels = mx.nd.array(labels).reshape((-1))
 
     # check cuda
     use_cuda = args.gpu >= 0
     if use_cuda:
         ctx = mx.gpu(args.gpu)
+        feats = feats.as_in_context(ctx)
         edge_type = edge_type.as_in_context(ctx)
         edge_norm = edge_norm.as_in_context(ctx)
         labels = labels.as_in_context(ctx)
@@ -80,7 +78,6 @@ def main(args):
     g = DGLGraph()
     g.add_nodes(num_nodes)
     g.add_edges(data.edge_src, data.edge_dst)
-    g.edata.update({'type': edge_type, 'norm': edge_norm})
 
     # create model
     model = EntityClassify(len(g),
@@ -90,6 +87,7 @@ def main(args):
                            num_bases=args.n_bases,
                            num_hidden_layers=args.n_layers - 2,
                            dropout=args.dropout,
+                           use_self_loop=args.use_self_loop,
                            gpu_id=args.gpu)
     model.initialize(ctx=ctx)
 
@@ -104,7 +102,7 @@ def main(args):
     for epoch in range(args.n_epochs):
         t0 = time.time()
         with mx.autograd.record():
-            pred = model(g)
+            pred = model(g, feats, edge_type, edge_norm)
             loss = loss_fcn(pred[train_idx], labels[train_idx])
         t1 = time.time()
         loss.backward()
@@ -120,7 +118,7 @@ def main(args):
         print("Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(train_acc, val_acc))
         print()
 
-    logits = model(g)
+    logits = model.forward(g, feats, edge_type, edge_norm)
     test_acc = F.sum(logits[test_idx].argmax(axis=1) == labels[test_idx]).asscalar() / len(test_idx)
     print("Test Accuracy: {:.4f}".format(test_acc))
     print()
@@ -151,6 +149,8 @@ def main(args):
                         help="l2 norm coef")
     parser.add_argument("--relabel", default=False, action='store_true',
                         help="remove untouched nodes and relabel")
+    parser.add_argument("--use-self-loop", default=False, action='store_true',
+                        help="include self feature as a special relation")
     fp = parser.add_mutually_exclusive_group(required=False)
     fp.add_argument('--validation', dest='validation', action='store_true')
     fp.add_argument('--testing', dest='validation', action='store_false')
@@ -159,4 +159,4 @@ def main(args):
     args = parser.parse_args()
     print(args)
     args.bfs_level = args.n_layers + 1 # pruning used nodes for memory
-    main(args)
\ No newline at end of file
+    main(args)
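Two details of the migration above are worth calling out. First, the layers now receive features, edge types, and normalizers as explicit forward arguments instead of reading them from g.ndata/g.edata. Second, per the "support id input" commit, featureless nodes can pass a 1-D integer vector of node ids, which RelGraphConv treats as one-hot input (effectively an embedding lookup into its weight). A minimal MXNet sketch of that pattern, with an invented toy graph:

```python
import mxnet as mx
from dgl import DGLGraph
from dgl.nn.mxnet import RelGraphConv

g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
etype = mx.nd.array([0, 1, 0, 1], dtype='int32')

# featureless nodes: in_feat == num_nodes, input is the node-id vector
conv = RelGraphConv(4, 16, 2, "basis", num_bases=2)
conv.initialize()

feats = mx.nd.arange(4, dtype='int32')
h = conv(g, feats, etype)   # -> shape (4, 16)
```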
96 changes: 0 additions & 96 deletions examples/mxnet/rgcn/layers.py

This file was deleted.

20 changes: 6 additions & 14 deletions examples/mxnet/rgcn/model.py
@@ -3,7 +3,8 @@
 
 class BaseRGCN(gluon.Block):
     def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1,
-                 num_hidden_layers=1, dropout=0, gpu_id=-1):
+                 num_hidden_layers=1, dropout=0,
+                 use_self_loop=False, gpu_id=-1):
         super(BaseRGCN, self).__init__()
         self.num_nodes = num_nodes
         self.h_dim = h_dim
@@ -12,14 +13,12 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1,
         self.num_bases = num_bases
         self.num_hidden_layers = num_hidden_layers
         self.dropout = dropout
+        self.use_self_loop = use_self_loop
         self.gpu_id = gpu_id
 
         # create rgcn layers
         self.build_model()
 
-        # create initial features
-        self.features = self.create_features()
-
     def build_model(self):
         self.layers = gluon.nn.Sequential()
         # i2h
@@ -35,10 +34,6 @@ def build_model(self):
         if h2o is not None:
             self.layers.add(h2o)
 
-    # initialize feature for each node
-    def create_features(self):
-        return None
-
     def build_input_layer(self):
         return None

@@ -48,10 +43,7 @@ def build_hidden_layer(self):
     def build_output_layer(self):
         return None
 
-    def forward(self, g):
-        if self.features is not None:
-            g.ndata['id'] = self.features
+    def forward(self, g, h, r, norm):
         for layer in self.layers:
-            layer(g)
-        return g.ndata.pop('h')
-
+            h = layer(g, h, r, norm)
+        return h
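The refactored forward threads the per-edge normalizer norm through every layer; entity_classify.py takes it precomputed from data.edge_norm and reshapes it to (num_edges, 1) with expand_dims(1). One common choice for such a normalizer (a sketch, not code from this commit; the helper name is hypothetical) is the reciprocal in-degree of each edge's destination node:

```python
import numpy as np

def compute_edge_norm(edge_dst, num_nodes):
    # hypothetical helper: 1 / in-degree of each edge's destination node
    in_deg = np.bincount(edge_dst, minlength=num_nodes).astype('float32')
    in_deg = np.maximum(in_deg, 1.0)   # guard against isolated nodes
    return 1.0 / in_deg[edge_dst]      # shape: (num_edges,)

# e.g. edges 0->1, 1->2, 2->3, 3->0: every node has in-degree 1
norm = compute_edge_norm(np.array([1, 2, 3, 0]), num_nodes=4)  # [1., 1., 1., 1.]
```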
(Diffs for the remaining 11 changed files are not shown.)
