From 708765f0a11d29b3d3f27857aeebfd10498f1937 Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Fri, 23 Aug 2019 16:38:48 -0400 Subject: [PATCH] [NN] RGCN modules (#744) * rgcn module * support id input * WIP: model codes * use faster index select * dropout * self loop * WIP: link prediction * fix lint * WIP: docs * docstring * docstring * merge two child classes * mxnet rgcn module * fix lint * fix lint * fix rename bug * add uniform edge sampler * fix fn name * docstring * fix mxnet rgcn module * fix mx rgcn * enable test on cuda --- docs/README.md | 1 + docs/source/api/python/nn.mxnet.rst | 4 + docs/source/api/python/nn.pytorch.rst | 4 + examples/mxnet/rgcn/README.md | 4 +- examples/mxnet/rgcn/entity_classify.py | 40 ++--- examples/mxnet/rgcn/layers.py | 96 ----------- examples/mxnet/rgcn/model.py | 20 +-- examples/pytorch/rgcn/entity_classify.py | 31 ++-- examples/pytorch/rgcn/layers.py | 133 --------------- examples/pytorch/rgcn/link_predict.py | 69 ++++---- examples/pytorch/rgcn/model.py | 24 +-- examples/pytorch/rgcn/utils.py | 26 ++- python/dgl/nn/mxnet/conv.py | 196 ++++++++++++++++++++++- python/dgl/nn/mxnet/utils.py | 86 ++++++++++ python/dgl/nn/pytorch/conv.py | 188 +++++++++++++++++++++- python/dgl/nn/pytorch/utils.py | 88 ++++++++++ tests/mxnet/test_nn.py | 55 ++++++- tests/pytorch/test_nn.py | 46 ++++++ 18 files changed, 774 insertions(+), 337 deletions(-) delete mode 100644 examples/mxnet/rgcn/layers.py delete mode 100644 examples/pytorch/rgcn/layers.py create mode 100644 python/dgl/nn/mxnet/utils.py create mode 100644 python/dgl/nn/pytorch/utils.py diff --git a/docs/README.md b/docs/README.md index d928fe266925..b26777e2c257 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,6 +5,7 @@ Requirements ------------ * sphinx * sphinx-gallery +* sphinx_rtd_theme * Both pytorch and mxnet installed. Build documents diff --git a/docs/source/api/python/nn.mxnet.rst b/docs/source/api/python/nn.mxnet.rst index 162be4f61ba8..fba29c80c025 100644 --- a/docs/source/api/python/nn.mxnet.rst +++ b/docs/source/api/python/nn.mxnet.rst @@ -12,6 +12,10 @@ dgl.nn.mxnet.conv :members: weight, bias, forward :show-inheritance: +.. autoclass:: dgl.nn.mxnet.conv.RelGraphConv + :members: forward + :show-inheritance: + dgl.nn.mxnet.glob ----------------- diff --git a/docs/source/api/python/nn.pytorch.rst b/docs/source/api/python/nn.pytorch.rst index 24ba02d41530..cf7535df8ca8 100644 --- a/docs/source/api/python/nn.pytorch.rst +++ b/docs/source/api/python/nn.pytorch.rst @@ -12,6 +12,10 @@ dgl.nn.pytorch.conv :members: weight, bias, forward, reset_parameters :show-inheritance: +.. autoclass:: dgl.nn.pytorch.conv.RelGraphConv + :members: forward + :show-inheritance: + dgl.nn.pytorch.glob ------------------- .. 
automodule:: dgl.nn.pytorch.glob diff --git a/examples/mxnet/rgcn/README.md b/examples/mxnet/rgcn/README.md index 3c8c772d1c3e..9eee18ab73d3 100644 --- a/examples/mxnet/rgcn/README.md +++ b/examples/mxnet/rgcn/README.md @@ -25,12 +25,12 @@ AIFB: accuracy 97.22% (DGL), 95.83% (paper) DGLBACKEND=mxnet python3 entity_classify.py -d aifb --testing --gpu 0 ``` -MUTAG: accuracy 76.47% (DGL), 73.23% (paper) +MUTAG: accuracy 73.53% (DGL), 73.23% (paper) ``` DGLBACKEND=mxnet python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0 ``` -BGS: accuracy 79.31% (DGL, n-basese=20, OOM when >20), 83.10% (paper) +BGS: accuracy 75.86% (DGL, n-basese=20, OOM when >20), 83.10% (paper) ``` DGLBACKEND=mxnet python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0 --relabel ``` diff --git a/examples/mxnet/rgcn/entity_classify.py b/examples/mxnet/rgcn/entity_classify.py index a5374c5ebec9..79ba10c725f6 100644 --- a/examples/mxnet/rgcn/entity_classify.py +++ b/examples/mxnet/rgcn/entity_classify.py @@ -15,32 +15,27 @@ from mxnet import gluon import mxnet.ndarray as F from dgl import DGLGraph +from dgl.nn.mxnet import RelGraphConv from dgl.contrib.data import load_data from functools import partial from model import BaseRGCN -from layers import RGCNBasisLayer as RGCNLayer - class EntityClassify(BaseRGCN): - def create_features(self): - features = mx.nd.arange(self.num_nodes) - if self.gpu_id >= 0: - features = features.as_in_context(mx.gpu(self.gpu_id)) - return features - def build_input_layer(self): - return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases, - activation=F.relu, is_input_layer=True) + return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=F.relu, self_loop=self.use_self_loop, + dropout=self.dropout) def build_hidden_layer(self, idx): - return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, - activation=F.relu) + return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=F.relu, self_loop=self.use_self_loop, + dropout=self.dropout) def build_output_layer(self): - return RGCNLayer(self.h_dim, self.out_dim, self.num_rels,self.num_bases, - activation=partial(F.softmax, axis=1)) - + return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", + self.num_bases, activation=partial(F.softmax, axis=1), + self_loop=self.use_self_loop) def main(args): # load graph data @@ -60,8 +55,10 @@ def main(args): val_idx = train_idx train_idx = mx.nd.array(train_idx) + # since the nodes are featureless, the input feature is then the node id. 
+ feats = mx.nd.arange(num_nodes, dtype='int32') # edge type and normalization factor - edge_type = mx.nd.array(data.edge_type) + edge_type = mx.nd.array(data.edge_type, dtype='int32') edge_norm = mx.nd.array(data.edge_norm).expand_dims(1) labels = mx.nd.array(labels).reshape((-1)) @@ -69,6 +66,7 @@ def main(args): use_cuda = args.gpu >= 0 if use_cuda: ctx = mx.gpu(args.gpu) + feats = feats.as_in_context(ctx) edge_type = edge_type.as_in_context(ctx) edge_norm = edge_norm.as_in_context(ctx) labels = labels.as_in_context(ctx) @@ -80,7 +78,6 @@ def main(args): g = DGLGraph() g.add_nodes(num_nodes) g.add_edges(data.edge_src, data.edge_dst) - g.edata.update({'type': edge_type, 'norm': edge_norm}) # create model model = EntityClassify(len(g), @@ -90,6 +87,7 @@ def main(args): num_bases=args.n_bases, num_hidden_layers=args.n_layers - 2, dropout=args.dropout, + use_self_loop=args.use_self_loop, gpu_id=args.gpu) model.initialize(ctx=ctx) @@ -104,7 +102,7 @@ def main(args): for epoch in range(args.n_epochs): t0 = time.time() with mx.autograd.record(): - pred = model(g) + pred = model(g, feats, edge_type, edge_norm) loss = loss_fcn(pred[train_idx], labels[train_idx]) t1 = time.time() loss.backward() @@ -120,7 +118,7 @@ def main(args): print("Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(train_acc, val_acc)) print() - logits = model(g) + logits = model.forward(g, feats, edge_type, edge_norm) test_acc = F.sum(logits[test_idx].argmax(axis=1) == labels[test_idx]).asscalar() / len(test_idx) print("Test Accuracy: {:.4f}".format(test_acc)) print() @@ -151,6 +149,8 @@ def main(args): help="l2 norm coef") parser.add_argument("--relabel", default=False, action='store_true', help="remove untouched nodes and relabel") + parser.add_argument("--use-self-loop", default=False, action='store_true', + help="include self feature as a special relation") fp = parser.add_mutually_exclusive_group(required=False) fp.add_argument('--validation', dest='validation', action='store_true') fp.add_argument('--testing', dest='validation', action='store_false') @@ -159,4 +159,4 @@ def main(args): args = parser.parse_args() print(args) args.bfs_level = args.n_layers + 1 # pruning used nodes for memory - main(args) \ No newline at end of file + main(args) diff --git a/examples/mxnet/rgcn/layers.py b/examples/mxnet/rgcn/layers.py deleted file mode 100644 index 27ae7a9e8f20..000000000000 --- a/examples/mxnet/rgcn/layers.py +++ /dev/null @@ -1,96 +0,0 @@ -import math - -import mxnet as mx -from mxnet import gluon -import mxnet.ndarray as F -import dgl.function as fn - -class RGCNLayer(gluon.Block): - def __init__(self, in_feat, out_feat, bias=None, activation=None, - self_loop=False, dropout=0.0): - super(RGCNLayer, self).__init__() - self.bias = bias - self.activation = activation - self.self_loop = self_loop - - if self.bias == True: - self.bias = self.params.get('bias', shape=(out_feat,), - init=mx.init.Xavier(magnitude=math.sqrt(2.0))) - - # weight for self loop - if self.self_loop: - self.loop_weight = self.params.get('loop_weight', shape=(in_feat, out_feat), - init=mx.init.Xavier(magnitude=math.sqrt(2.0))) - if dropout: - self.dropout = gluon.nn.Dropout(dropout) - else: - self.dropout = None - - # define how propagation is done in subclass - def propagate(self, g): - raise NotImplementedError - - def forward(self, g): - if self.self_loop: - loop_message = F.dot(g.ndata['h'], self.loop_weight) - if self.dropout is not None: - loop_message = self.dropout(loop_message) - - self.propagate(g) - - # apply bias and 
activation - node_repr = g.ndata['h'] - if self.bias: - node_repr = node_repr + self.bias - if self.self_loop: - node_repr = node_repr + loop_message - if self.activation: - node_repr = self.activation(node_repr) - - g.ndata['h'] = node_repr - - -class RGCNBasisLayer(RGCNLayer): - def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None, - activation=None, is_input_layer=False): - super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation) - self.in_feat = in_feat - self.out_feat = out_feat - self.num_rels = num_rels - self.num_bases = num_bases - self.is_input_layer = is_input_layer - if self.num_bases <= 0 or self.num_bases > self.num_rels: - self.num_bases = self.num_rels - - # add basis weights - if self.num_bases < self.num_rels: - # linear combination coefficients - self.weight = self.params.get('weight', shape=(self.num_bases, self.in_feat * self.out_feat)) - self.w_comp = self.params.get('w_comp', shape=(self.num_rels, self.num_bases), - init=mx.init.Xavier(magnitude=math.sqrt(2.0))) - else: - self.weight = self.params.get('weight', shape=(self.num_bases, self.in_feat, self.out_feat), - init=mx.init.Xavier(magnitude=math.sqrt(2.0))) - - def propagate(self, g): - if self.num_bases < self.num_rels: - # generate all weights from bases - weight = F.dot(self.w_comp.data(), self.weight.data()).reshape((self.num_rels, self.in_feat, self.out_feat)) - else: - weight = self.weight.data() - - if self.is_input_layer: - def msg_func(edges): - # for input layer, matrix multiply can be converted to be - # an embedding lookup using source node id - embed = F.reshape(weight, (-1, self.out_feat)) - index = edges.data['type'] * self.in_feat + edges.src['id'] - return {'msg': embed[index] * edges.data['norm']} - else: - def msg_func(edges): - w = weight[edges.data['type']] - msg = F.batch_dot(edges.src['h'].expand_dims(1), w).reshape(-1, self.out_feat) - msg = msg * edges.data['norm'] - return {'msg': msg} - - g.update_all(msg_func, fn.sum(msg='msg', out='h'), None) \ No newline at end of file diff --git a/examples/mxnet/rgcn/model.py b/examples/mxnet/rgcn/model.py index 14dd1a4cb853..77c211bbced3 100644 --- a/examples/mxnet/rgcn/model.py +++ b/examples/mxnet/rgcn/model.py @@ -3,7 +3,8 @@ class BaseRGCN(gluon.Block): def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, - num_hidden_layers=1, dropout=0, gpu_id=-1): + num_hidden_layers=1, dropout=0, + use_self_loop=False, gpu_id=-1): super(BaseRGCN, self).__init__() self.num_nodes = num_nodes self.h_dim = h_dim @@ -12,14 +13,12 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, self.num_bases = num_bases self.num_hidden_layers = num_hidden_layers self.dropout = dropout + self.use_self_loop = use_self_loop self.gpu_id = gpu_id # create rgcn layers self.build_model() - # create initial features - self.features = self.create_features() - def build_model(self): self.layers = gluon.nn.Sequential() # i2h @@ -35,10 +34,6 @@ def build_model(self): if h2o is not None: self.layers.add(h2o) - # initialize feature for each node - def create_features(self): - return None - def build_input_layer(self): return None @@ -48,10 +43,7 @@ def build_hidden_layer(self): def build_output_layer(self): return None - def forward(self, g): - if self.features is not None: - g.ndata['id'] = self.features + def forward(self, g, h, r, norm): for layer in self.layers: - layer(g) - return g.ndata.pop('h') - + h = layer(g, h, r, norm) + return h diff --git a/examples/pytorch/rgcn/entity_classify.py 
b/examples/pytorch/rgcn/entity_classify.py index a22388af62aa..795dabe1b85d 100644 --- a/examples/pytorch/rgcn/entity_classify.py +++ b/examples/pytorch/rgcn/entity_classify.py @@ -14,11 +14,10 @@ import torch import torch.nn.functional as F from dgl import DGLGraph +from dgl.nn.pytorch import RelGraphConv from dgl.contrib.data import load_data -import dgl.function as fn from functools import partial -from layers import RGCNBasisLayer as RGCNLayer from model import BaseRGCN class EntityClassify(BaseRGCN): @@ -29,16 +28,19 @@ def create_features(self): return features def build_input_layer(self): - return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases, - activation=F.relu, is_input_layer=True) + return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=F.relu, self_loop=self.use_self_loop, + dropout=self.dropout) def build_hidden_layer(self, idx): - return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, - activation=F.relu) + return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=F.relu, self_loop=self.use_self_loop, + dropout=self.dropout) def build_output_layer(self): - return RGCNLayer(self.h_dim, self.out_dim, self.num_rels,self.num_bases, - activation=partial(F.softmax, dim=1)) + return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", + self.num_bases, activation=partial(F.softmax, dim=1), + self_loop=self.use_self_loop) def main(args): # load graph data @@ -57,6 +59,9 @@ def main(args): else: val_idx = train_idx + # since the nodes are featureless, the input feature is then the node id. + feats = torch.arange(num_nodes) + # edge type and normalization factor edge_type = torch.from_numpy(data.edge_type) edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1) @@ -66,6 +71,7 @@ def main(args): use_cuda = args.gpu >= 0 and torch.cuda.is_available() if use_cuda: torch.cuda.set_device(args.gpu) + feats = feats.cuda() edge_type = edge_type.cuda() edge_norm = edge_norm.cuda() labels = labels.cuda() @@ -74,7 +80,6 @@ def main(args): g = DGLGraph() g.add_nodes(num_nodes) g.add_edges(data.edge_src, data.edge_dst) - g.edata.update({'type': edge_type, 'norm': edge_norm}) # create model model = EntityClassify(len(g), @@ -84,6 +89,7 @@ def main(args): num_bases=args.n_bases, num_hidden_layers=args.n_layers - 2, dropout=args.dropout, + use_self_loop=args.use_self_loop, use_cuda=use_cuda) if use_cuda: @@ -100,7 +106,7 @@ def main(args): for epoch in range(args.n_epochs): optimizer.zero_grad() t0 = time.time() - logits = model.forward(g) + logits = model(g, feats, edge_type, edge_norm) loss = F.cross_entropy(logits[train_idx], labels[train_idx]) t1 = time.time() loss.backward() @@ -119,7 +125,7 @@ def main(args): print() model.eval() - logits = model.forward(g) + logits = model.forward(g, feats, edge_type, edge_norm) test_loss = F.cross_entropy(logits[test_idx], labels[test_idx]) test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx) print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item())) @@ -151,6 +157,8 @@ def main(args): help="l2 norm coef") parser.add_argument("--relabel", default=False, action='store_true', help="remove untouched nodes and relabel") + parser.add_argument("--use-self-loop", default=False, action='store_true', + help="include self feature as a special relation") fp = parser.add_mutually_exclusive_group(required=False) fp.add_argument('--validation', dest='validation', 
action='store_true') fp.add_argument('--testing', dest='validation', action='store_false') @@ -160,4 +168,3 @@ def main(args): print(args) args.bfs_level = args.n_layers + 1 # pruning used nodes for memory main(args) - diff --git a/examples/pytorch/rgcn/layers.py b/examples/pytorch/rgcn/layers.py deleted file mode 100644 index f96a16d9a745..000000000000 --- a/examples/pytorch/rgcn/layers.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch -import torch.nn as nn -import dgl.function as fn - -class RGCNLayer(nn.Module): - def __init__(self, in_feat, out_feat, bias=None, activation=None, - self_loop=False, dropout=0.0): - super(RGCNLayer, self).__init__() - self.bias = bias - self.activation = activation - self.self_loop = self_loop - - if self.bias == True: - self.bias = nn.Parameter(torch.Tensor(out_feat)) - nn.init.xavier_uniform_(self.bias, - gain=nn.init.calculate_gain('relu')) - - # weight for self loop - if self.self_loop: - self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat)) - nn.init.xavier_uniform_(self.loop_weight, - gain=nn.init.calculate_gain('relu')) - - if dropout: - self.dropout = nn.Dropout(dropout) - else: - self.dropout = None - - # define how propagation is done in subclass - def propagate(self, g): - raise NotImplementedError - - def forward(self, g): - if self.self_loop: - loop_message = torch.mm(g.ndata['h'], self.loop_weight) - if self.dropout is not None: - loop_message = self.dropout(loop_message) - - self.propagate(g) - - # apply bias and activation - node_repr = g.ndata['h'] - if self.bias: - node_repr = node_repr + self.bias - if self.self_loop: - node_repr = node_repr + loop_message - if self.activation: - node_repr = self.activation(node_repr) - - g.ndata['h'] = node_repr - -class RGCNBasisLayer(RGCNLayer): - def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None, - activation=None, is_input_layer=False): - super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation) - self.in_feat = in_feat - self.out_feat = out_feat - self.num_rels = num_rels - self.num_bases = num_bases - self.is_input_layer = is_input_layer - if self.num_bases <= 0 or self.num_bases > self.num_rels: - self.num_bases = self.num_rels - - # add basis weights - self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat, - self.out_feat)) - if self.num_bases < self.num_rels: - # linear combination coefficients - self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, - self.num_bases)) - nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) - if self.num_bases < self.num_rels: - nn.init.xavier_uniform_(self.w_comp, - gain=nn.init.calculate_gain('relu')) - - def propagate(self, g): - if self.num_bases < self.num_rels: - # generate all weights from bases - weight = self.weight.view(self.num_bases, - self.in_feat * self.out_feat) - weight = torch.matmul(self.w_comp, weight).view( - self.num_rels, self.in_feat, self.out_feat) - else: - weight = self.weight - - if self.is_input_layer: - def msg_func(edges): - # for input layer, matrix multiply can be converted to be - # an embedding lookup using source node id - embed = weight.view(-1, self.out_feat) - index = edges.data['type'] * self.in_feat + edges.src['id'] - return {'msg': embed.index_select(0, index) * edges.data['norm']} - else: - def msg_func(edges): - w = weight.index_select(0, edges.data['type']) - msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze() - msg = msg * edges.data['norm'] - return {'msg': msg} - - g.update_all(msg_func, fn.sum(msg='msg', out='h'), None) - 
-class RGCNBlockLayer(RGCNLayer): - def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None, - activation=None, self_loop=False, dropout=0.0): - super(RGCNBlockLayer, self).__init__(in_feat, out_feat, bias, - activation, self_loop=self_loop, - dropout=dropout) - self.num_rels = num_rels - self.num_bases = num_bases - assert self.num_bases > 0 - - self.out_feat = out_feat - self.submat_in = in_feat // self.num_bases - self.submat_out = out_feat // self.num_bases - - # assuming in_feat and out_feat are both divisible by num_bases - self.weight = nn.Parameter(torch.Tensor( - self.num_rels, self.num_bases * self.submat_in * self.submat_out)) - nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) - - def msg_func(self, edges): - weight = self.weight.index_select(0, edges.data['type']).view( - -1, self.submat_in, self.submat_out) - node = edges.src['h'].view(-1, 1, self.submat_in) - msg = torch.bmm(node, weight).view(-1, self.out_feat) - return {'msg': msg} - - def propagate(self, g): - g.update_all(self.msg_func, fn.sum(msg='msg', out='h'), self.apply_func) - - def apply_func(self, nodes): - return {'h': nodes.data['h'] * nodes.data['norm']} - diff --git a/examples/pytorch/rgcn/link_predict.py b/examples/pytorch/rgcn/link_predict.py index c2fa19ab7838..7925233fdae6 100644 --- a/examples/pytorch/rgcn/link_predict.py +++ b/examples/pytorch/rgcn/link_predict.py @@ -4,7 +4,12 @@ Code: https://github.com/MichSchli/RelationPrediction Difference compared to MichSchli/RelationPrediction -* report raw metrics instead of filtered metrics +* Report raw metrics instead of filtered metrics. +* By default, we use uniform edge sampling instead of neighbor-based edge + sampling used in author's code. In practice, we find it achieves similar MRR + probably because the model only uses one GNN layer so messages are propagated + among immediate neighbors. User could specify "--edge-sampler=neighbor" to switch + to neighbor-based edge sampling. 
""" import argparse @@ -15,8 +20,8 @@ import torch.nn.functional as F import random from dgl.contrib.data import load_data +from dgl.nn.pytorch import RelGraphConv -from layers import RGCNBlockLayer as RGCNLayer from model import BaseRGCN import utils @@ -26,9 +31,8 @@ def __init__(self, num_nodes, h_dim): super(EmbeddingLayer, self).__init__() self.embedding = torch.nn.Embedding(num_nodes, h_dim) - def forward(self, g): - node_id = g.ndata['id'].squeeze() - g.ndata['h'] = self.embedding(node_id) + def forward(self, g, h, r, norm): + return self.embedding(h.squeeze()) class RGCN(BaseRGCN): def build_input_layer(self): @@ -36,8 +40,9 @@ def build_input_layer(self): def build_hidden_layer(self, idx): act = F.relu if idx < self.num_hidden_layers - 1 else None - return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, - activation=act, self_loop=True, dropout=self.dropout) + return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "bdd", + self.num_bases, activation=act, self_loop=True, + dropout=self.dropout) class LinkPredict(nn.Module): def __init__(self, in_dim, h_dim, num_rels, num_bases=-1, @@ -58,26 +63,26 @@ def calc_score(self, embedding, triplets): score = torch.sum(s * r * o, dim=1) return score - def forward(self, g): - return self.rgcn.forward(g) - - def evaluate(self, g): - # get embedding and relation weight without grad - embedding = self.forward(g) - return embedding, self.w_relation + def forward(self, g, h, r, norm): + return self.rgcn.forward(g, h, r, norm) def regularization_loss(self, embedding): return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2)) - def get_loss(self, g, triplets, labels): + def get_loss(self, g, embed, triplets, labels): # triplets is a list of data samples (positive and negative) # each row in the triplets is a 3-tuple of (source, relation, destination) - embedding = self.forward(g) - score = self.calc_score(embedding, triplets) + score = self.calc_score(embed, triplets) predict_loss = F.binary_cross_entropy_with_logits(score, labels) - reg_loss = self.regularization_loss(embedding) + reg_loss = self.regularization_loss(embed) return predict_loss + self.reg_param * reg_loss +def node_norm_to_edge_norm(g, node_norm): + g = g.local_var() + # convert to edge norm + g.ndata['norm'] = node_norm + g.apply_edges(lambda edges : {'norm' : edges.dst['norm']}) + return g.edata['norm'] def main(args): # load graph data @@ -114,9 +119,7 @@ def main(args): range(test_graph.number_of_nodes())).float().view(-1,1) test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1) test_rel = torch.from_numpy(test_rel) - test_norm = torch.from_numpy(test_norm).view(-1, 1) - test_graph.ndata.update({'id': test_node_id, 'norm': test_norm}) - test_graph.edata['type'] = test_rel + test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1)) if use_cuda: model.cuda() @@ -144,24 +147,24 @@ def main(args): g, node_id, edge_type, node_norm, data, labels = \ utils.generate_sampled_graph_and_labels( train_data, args.graph_batch_size, args.graph_split_size, - num_rels, adj_list, degrees, args.negative_sample) + num_rels, adj_list, degrees, args.negative_sample, + args.edge_sampler) print("Done edge sampling") # set node/edge feature node_id = torch.from_numpy(node_id).view(-1, 1).long() edge_type = torch.from_numpy(edge_type) - node_norm = torch.from_numpy(node_norm).view(-1, 1) + edge_norm = node_norm_to_edge_norm(g, torch.from_numpy(node_norm).view(-1, 1)) data, labels = torch.from_numpy(data), 
torch.from_numpy(labels) deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1) if use_cuda: node_id, deg = node_id.cuda(), deg.cuda() - edge_type, node_norm = edge_type.cuda(), node_norm.cuda() + edge_type, edge_norm = edge_type.cuda(), edge_norm.cuda() data, labels = data.cuda(), labels.cuda() - g.ndata.update({'id': node_id, 'norm': node_norm}) - g.edata['type'] = edge_type t0 = time.time() - loss = model.get_loss(g, data, labels) + embed = model(g, node_id, edge_type, edge_norm) + loss = model.get_loss(g, embed, data, labels) t1 = time.time() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients @@ -182,7 +185,8 @@ def main(args): model.cpu() model.eval() print("start eval") - mrr = utils.evaluate(test_graph, model, valid_data, + embed = model(test_graph, test_node_id, test_rel, test_norm) + mrr = utils.calc_mrr(embed, model.w_relation, valid_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size) # save best model if mrr < best_mrr: @@ -207,9 +211,9 @@ def main(args): model.eval() model.load_state_dict(checkpoint['state_dict']) print("Using best epoch: {}".format(checkpoint['epoch'])) - utils.evaluate(test_graph, model, test_data, hits=[1, 3, 10], - eval_bz=args.eval_batch_size) - + embed = model(test_graph, test_node_id, test_rel, test_norm) + utils.calc_mrr(embed, model.w_relation, test_data, + hits=[1, 3, 10], eval_bz=args.eval_batch_size) if __name__ == '__main__': parser = argparse.ArgumentParser(description='RGCN') @@ -243,8 +247,9 @@ def main(args): help="number of negative samples per positive sample") parser.add_argument("--evaluate-every", type=int, default=500, help="perform evaluation every n epochs") + parser.add_argument("--edge-sampler", type=str, default="uniform", + help="type of edge sampler: 'uniform' or 'neighbor'") args = parser.parse_args() print(args) main(args) - diff --git a/examples/pytorch/rgcn/model.py b/examples/pytorch/rgcn/model.py index 2b6f8b741c6a..84631ed65784 100644 --- a/examples/pytorch/rgcn/model.py +++ b/examples/pytorch/rgcn/model.py @@ -1,24 +1,23 @@ import torch.nn as nn class BaseRGCN(nn.Module): - def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, - num_hidden_layers=1, dropout=0, use_cuda=False): + def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, + num_hidden_layers=1, dropout=0, + use_self_loop=False, use_cuda=False): super(BaseRGCN, self).__init__() self.num_nodes = num_nodes self.h_dim = h_dim self.out_dim = out_dim self.num_rels = num_rels - self.num_bases = num_bases + self.num_bases = None if num_bases < 0 else num_bases self.num_hidden_layers = num_hidden_layers self.dropout = dropout + self.use_self_loop = use_self_loop self.use_cuda = use_cuda # create rgcn layers self.build_model() - # create initial features - self.features = self.create_features() - def build_model(self): self.layers = nn.ModuleList() # i2h @@ -34,10 +33,6 @@ def build_model(self): if h2o is not None: self.layers.append(h2o) - # initialize feature for each node - def create_features(self): - return None - def build_input_layer(self): return None @@ -47,10 +42,7 @@ def build_hidden_layer(self, idx): def build_output_layer(self): return None - def forward(self, g): - if self.features is not None: - g.ndata['id'] = self.features + def forward(self, g, h, r, norm): for layer in self.layers: - layer(g) - return g.ndata.pop('h') - + h = layer(g, h, r, norm) + return h diff --git a/examples/pytorch/rgcn/utils.py b/examples/pytorch/rgcn/utils.py index 15f21aaf635c..847cfa5e4684 
100644 --- a/examples/pytorch/rgcn/utils.py +++ b/examples/pytorch/rgcn/utils.py @@ -28,9 +28,11 @@ def get_adj_and_degrees(num_nodes, triplets): return adj_list, degrees def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size): - """ Edge neighborhood sampling to reduce training graph size - """ + """Sample edges by neighborhool expansion. + This guarantees that the sampled edges form a connected graph, which + may help deeper GNNs that require information from more than one hop. + """ edges = np.zeros((sample_size), dtype=np.int32) #initialize @@ -69,16 +71,25 @@ def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size): return edges +def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size): + """Sample edges uniformly from all the edges.""" + all_edges = np.arange(n_triplets) + return np.random.choice(all_edges, sample_size, replace=False) + def generate_sampled_graph_and_labels(triplets, sample_size, split_size, num_rels, adj_list, degrees, - negative_rate): + negative_rate, sampler="uniform"): """Get training graph and signals First perform edge neighborhood sampling on graph, then perform negative sampling to generate negative samples """ # perform edge neighbor sampling - edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), - sample_size) + if sampler == "uniform": + edges = sample_edge_uniform(adj_list, degrees, len(triplets), sample_size) + elif sampler == "neighbor": + edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), sample_size) + else: + raise ValueError("Sampler type must be either 'uniform' or 'neighbor'.") # relabel nodes to have consecutive node ids edges = triplets[edges] @@ -108,6 +119,7 @@ def generate_sampled_graph_and_labels(triplets, sample_size, split_size, return g, uniq_v, rel, norm, samples, labels def comp_deg_norm(g): + g = g.local_var() in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy() norm = 1.0 / in_deg norm[np.isinf(norm)] = 0 @@ -187,9 +199,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): # TODO (lingfan): implement filtered metrics # return MRR (raw), and Hits @ (1, 3, 10) -def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100): +def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): with torch.no_grad(): - embedding, w = model.evaluate(test_graph) s = test_triplets[:, 0] r = test_triplets[:, 1] o = test_triplets[:, 2] @@ -210,4 +221,3 @@ def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100): avg_count = torch.mean((ranks <= hit).float()) print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item())) return mrr.item() - diff --git a/python/dgl/nn/mxnet/conv.py b/python/dgl/nn/mxnet/conv.py index 72cc92429aae..641322676a68 100644 --- a/python/dgl/nn/mxnet/conv.py +++ b/python/dgl/nn/mxnet/conv.py @@ -1,11 +1,15 @@ """MXNet modules for graph convolutions.""" # pylint: disable= no-member, arguments-differ +import math import mxnet as mx -from mxnet import gluon +from mxnet import gluon, nd +from mxnet.gluon import nn +import numpy as np +from . import utils from ... import function as fn -__all__ = ['GraphConv'] +__all__ = ['GraphConv', 'RelGraphConv'] class GraphConv(gluon.Block): r"""Apply graph convolution over an input signal. @@ -142,3 +146,191 @@ def __repr__(self): self._norm, self._activation) summary += '\n)' return summary + +class RelGraphConv(gluon.Block): + r"""Relational graph convolution layer. 
+ + Relational graph convolution is introduced in "`Modeling Relational Data with Graph + Convolutional Networks `__" + and can be described as below: + + .. math:: + + h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} + \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)}) + + where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation + :math:`r`. :math:`c_{i,r}` is the normalizer equal + to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0` + is the self-loop weight. + + The basis regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)} + + where :math:`B` is the number of bases. + + The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B` + number of block diagonal matrices. We refer :math:`B` as the number of bases. + + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + num_rels : int + Number of relations. + regularizer : str + Which weight regularizer to use "basis" or "bdd" + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: None. + bias : bool, optional + True if bias is added. Default: True + activation : callable, optional + Activation function. Default: None + self_loop : bool, optional + True to include self loop message. Default: False + dropout : float, optional + Dropout rate. Default: 0.0 + """ + def __init__(self, + in_feat, + out_feat, + num_rels, + regularizer="basis", + num_bases=None, + bias=True, + activation=None, + self_loop=False, + dropout=0.0): + super(RelGraphConv, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.regularizer = regularizer + self.num_bases = num_bases + if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0: + self.num_bases = self.num_rels + self.bias = bias + self.activation = activation + self.self_loop = self_loop + + if regularizer == "basis": + # add basis weights + self.weight = self.params.get( + 'weight', shape=(self.num_bases, self.in_feat, self.out_feat), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + if self.num_bases < self.num_rels: + # linear combination coefficients + self.w_comp = self.params.get( + 'w_comp', shape=(self.num_rels, self.num_bases), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + # message func + self.message_func = self.basis_message_func + elif regularizer == "bdd": + if in_feat % num_bases != 0 or out_feat % num_bases != 0: + raise ValueError('Feature size must be a multiplier of num_bases.') + # add block diagonal weights + self.submat_in = in_feat // self.num_bases + self.submat_out = out_feat // self.num_bases + + # assuming in_feat and out_feat are both divisible by num_bases + self.weight = self.params.get( + 'weight', + shape=(self.num_rels, self.num_bases * self.submat_in * self.submat_out), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + # message func + self.message_func = self.bdd_message_func + else: + raise ValueError("Regularizer must be either 'basis' or 'bdd'") + + # bias + if self.bias: + self.h_bias = self.params.get('bias', shape=(out_feat,), + init=mx.init.Zero()) + + # weight for self loop + if self.self_loop: + self.loop_weight = self.params.get( + 'W_0', shape=(in_feat, out_feat), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + + self.dropout = nn.Dropout(dropout) + + def basis_message_func(self, edges): + """Message function for basis regularizer""" 
+ ctx = edges.src['h'].context + if self.num_bases < self.num_rels: + # generate all weights from bases + weight = self.weight.data(ctx).reshape( + self.num_bases, self.in_feat * self.out_feat) + weight = nd.dot(self.w_comp.data(ctx), weight).reshape( + self.num_rels, self.in_feat, self.out_feat) + else: + weight = self.weight.data(ctx) + + msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type']) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def bdd_message_func(self, edges): + """Message function for block-diagonal-decomposition regularizer""" + ctx = edges.src['h'].context + if edges.src['h'].dtype in (np.int32, np.int64) and len(edges.src['h'].shape) == 1: + raise TypeError('Block decomposition does not allow integer ID feature.') + weight = self.weight.data(ctx)[edges.data['type'], :].reshape( + -1, self.submat_in, self.submat_out) + node = edges.src['h'].reshape(-1, 1, self.submat_in) + msg = nd.batch_dot(node, weight).reshape(-1, self.out_feat) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def forward(self, g, x, etypes, norm=None): + """Forward computation + + Parameters + ---------- + g : DGLGraph + The graph. + x : mx.ndarray.NDArray + Input node features. Could be either + - (|V|, D) dense tensor + - (|V|,) int64 vector, representing the categorical values of each + node. We then treat the input feature as an one-hot encoding feature. + etypes : mx.ndarray.NDArray + Edge type tensor. Shape: (|E|,) + norm : mx.ndarray.NDArray + Optional edge normalizer tensor. Shape: (|E|, 1) + + Returns + ------- + mx.ndarray.NDArray + New node features. + """ + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = etypes + if norm is not None: + g.edata['norm'] = norm + if self.self_loop: + loop_message = utils.matmul_maybe_select(x, self.loop_weight.data(x.context)) + + # message passing + g.update_all(self.message_func, fn.sum(msg='msg', out='h')) + + # apply bias and activation + node_repr = g.ndata['h'] + if self.bias: + node_repr = node_repr + self.h_bias.data(x.context) + if self.self_loop: + node_repr = node_repr + loop_message + if self.activation: + node_repr = self.activation(node_repr) + node_repr = self.dropout(node_repr) + + return node_repr diff --git a/python/dgl/nn/mxnet/utils.py b/python/dgl/nn/mxnet/utils.py new file mode 100644 index 000000000000..f9446a97b872 --- /dev/null +++ b/python/dgl/nn/mxnet/utils.py @@ -0,0 +1,86 @@ +"""Utilities for pytorch NN package""" +#pylint: disable=no-member, invalid-name + +from mxnet import nd +import numpy as np + +def matmul_maybe_select(A, B): + """Perform Matrix multiplication C = A * B but A could be an integer id vector. + + If A is an integer vector, we treat it as multiplying a one-hot encoded tensor. + In this case, the expensive dense matrix multiply can be replaced by a much + cheaper index lookup. + + For example, + :: + + A = [2, 0, 1], + B = [[0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6]] + + then matmul_maybe_select(A, B) is equivalent to + :: + + [[0, 0, 1], [[0.1, 0.2], + [1, 0, 0], * [0.3, 0.4], + [0, 1, 0]] [0.5, 0.6]] + + In all other cases, perform a normal matmul. 
+ + Parameters + ---------- + A : torch.Tensor + lhs tensor + B : torch.Tensor + rhs tensor + + Returns + ------- + C : torch.Tensor + result tensor + """ + if A.dtype in (np.int32, np.int64) and len(A.shape) == 1: + return nd.take(B, A, axis=0) + else: + return nd.dot(A, B) + +def bmm_maybe_select(A, B, index): + """Slice submatrices of A by the given index and perform bmm. + + B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of + N matrices of shape (D1, D2). The input index is an integer vector of length M. + A could be either: + (1) a dense tensor of shape (M, D1), + (2) an integer vector of length M. + The result C is a 2D matrix of shape (M, D2) + + For case (1), C is computed by bmm: + :: + + C[i, :] = matmul(A[i, :], B[index[i], :, :]) + + For case (2), C is computed by index select: + :: + + C[i, :] = B[index[i], A[i], :] + + Parameters + ---------- + A : torch.Tensor + lhs tensor + B : torch.Tensor + rhs tensor + index : torch.Tensor + index tensor + + Returns + ------- + C : torch.Tensor + return tensor + """ + if A.dtype in (np.int32, np.int64) and len(A.shape) == 1: + return B[index, A, :] + else: + BB = nd.take(B, index, axis=0) + return nd.batch_dot(A.expand_dims(1), BB).squeeze() diff --git a/python/dgl/nn/pytorch/conv.py b/python/dgl/nn/pytorch/conv.py index b31b2ed2049b..cf1b4413e3fb 100644 --- a/python/dgl/nn/pytorch/conv.py +++ b/python/dgl/nn/pytorch/conv.py @@ -4,9 +4,10 @@ from torch import nn from torch.nn import init +from . import utils from ... import function as fn -__all__ = ['GraphConv'] +__all__ = ['GraphConv', 'RelGraphConv'] class GraphConv(nn.Module): r"""Apply graph convolution over an input signal. @@ -148,3 +149,188 @@ def extra_repr(self): if '_activation' in self.__dict__: summary += ', activation={_activation}' return summary.format(**self.__dict__) + +class RelGraphConv(nn.Module): + r"""Relational graph convolution layer. + + Relational graph convolution is introduced in "`Modeling Relational Data with Graph + Convolutional Networks `__" + and can be described as below: + + .. math:: + + h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} + \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)}) + + where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation + :math:`r`. :math:`c_{i,r}` is the normalizer equal + to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0` + is the self-loop weight. + + The basis regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)} + + where :math:`B` is the number of bases. + + The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B` + number of block diagonal matrices. We refer :math:`B` as the number of bases. + + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + num_rels : int + Number of relations. + regularizer : str + Which weight regularizer to use "basis" or "bdd" + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: None. + bias : bool, optional + True if bias is added. Default: True + activation : callable, optional + Activation function. Default: None + self_loop : bool, optional + True to include self loop message. Default: False + dropout : float, optional + Dropout rate. 
Default: 0.0 + """ + def __init__(self, + in_feat, + out_feat, + num_rels, + regularizer="basis", + num_bases=None, + bias=True, + activation=None, + self_loop=False, + dropout=0.0): + super(RelGraphConv, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.regularizer = regularizer + self.num_bases = num_bases + if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0: + self.num_bases = self.num_rels + self.bias = bias + self.activation = activation + self.self_loop = self_loop + + if regularizer == "basis": + # add basis weights + self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat)) + if self.num_bases < self.num_rels: + # linear combination coefficients + self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases)) + nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) + if self.num_bases < self.num_rels: + nn.init.xavier_uniform_(self.w_comp, + gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.basis_message_func + elif regularizer == "bdd": + if in_feat % num_bases != 0 or out_feat % num_bases != 0: + raise ValueError('Feature size must be a multiplier of num_bases.') + # add block diagonal weights + self.submat_in = in_feat // self.num_bases + self.submat_out = out_feat // self.num_bases + + # assuming in_feat and out_feat are both divisible by num_bases + self.weight = nn.Parameter(th.Tensor( + self.num_rels, self.num_bases * self.submat_in * self.submat_out)) + nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.bdd_message_func + else: + raise ValueError("Regularizer must be either 'basis' or 'bdd'") + + # bias + if self.bias: + self.h_bias = nn.Parameter(th.Tensor(out_feat)) + nn.init.zeros_(self.h_bias) + + # weight for self loop + if self.self_loop: + self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) + nn.init.xavier_uniform_(self.loop_weight, + gain=nn.init.calculate_gain('relu')) + + self.dropout = nn.Dropout(dropout) + + def basis_message_func(self, edges): + """Message function for basis regularizer""" + if self.num_bases < self.num_rels: + # generate all weights from bases + weight = self.weight.view(self.num_bases, + self.in_feat * self.out_feat) + weight = th.matmul(self.w_comp, weight).view( + self.num_rels, self.in_feat, self.out_feat) + else: + weight = self.weight + + msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type']) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def bdd_message_func(self, edges): + """Message function for block-diagonal-decomposition regularizer""" + if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1: + raise TypeError('Block decomposition does not allow integer ID feature.') + weight = self.weight.index_select(0, edges.data['type']).view( + -1, self.submat_in, self.submat_out) + node = edges.src['h'].view(-1, 1, self.submat_in) + msg = th.bmm(node, weight).view(-1, self.out_feat) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def forward(self, g, x, etypes, norm=None): + """Forward computation + + Parameters + ---------- + g : DGLGraph + The graph. + x : torch.Tensor + Input node features. Could be either + - (|V|, D) dense tensor + - (|V|,) int64 vector, representing the categorical values of each + node. We then treat the input feature as an one-hot encoding feature. 
+ etypes : torch.Tensor + Edge type tensor. Shape: (|E|,) + norm : torch.Tensor + Optional edge normalizer tensor. Shape: (|E|, 1) + + Returns + ------- + torch.Tensor + New node features. + """ + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = etypes + if norm is not None: + g.edata['norm'] = norm + if self.self_loop: + loop_message = utils.matmul_maybe_select(x, self.loop_weight) + + # message passing + g.update_all(self.message_func, fn.sum(msg='msg', out='h')) + + # apply bias and activation + node_repr = g.ndata['h'] + if self.bias: + node_repr = node_repr + self.h_bias + if self.self_loop: + node_repr = node_repr + loop_message + if self.activation: + node_repr = self.activation(node_repr) + node_repr = self.dropout(node_repr) + + return node_repr diff --git a/python/dgl/nn/pytorch/utils.py b/python/dgl/nn/pytorch/utils.py new file mode 100644 index 000000000000..81bad75610d7 --- /dev/null +++ b/python/dgl/nn/pytorch/utils.py @@ -0,0 +1,88 @@ +"""Utilities for pytorch NN package""" +#pylint: disable=no-member, invalid-name + +import torch as th + +def matmul_maybe_select(A, B): + """Perform Matrix multiplication C = A * B but A could be an integer id vector. + + If A is an integer vector, we treat it as multiplying a one-hot encoded tensor. + In this case, the expensive dense matrix multiply can be replaced by a much + cheaper index lookup. + + For example, + :: + + A = [2, 0, 1], + B = [[0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6]] + + then matmul_maybe_select(A, B) is equivalent to + :: + + [[0, 0, 1], [[0.1, 0.2], + [1, 0, 0], * [0.3, 0.4], + [0, 1, 0]] [0.5, 0.6]] + + In all other cases, perform a normal matmul. + + Parameters + ---------- + A : torch.Tensor + lhs tensor + B : torch.Tensor + rhs tensor + + Returns + ------- + C : torch.Tensor + result tensor + """ + if A.dtype == th.int64 and len(A.shape) == 1: + return B.index_select(0, A) + else: + return th.matmul(A, B) + +def bmm_maybe_select(A, B, index): + """Slice submatrices of A by the given index and perform bmm. + + B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of + N matrices of shape (D1, D2). The input index is an integer vector of length M. + A could be either: + (1) a dense tensor of shape (M, D1), + (2) an integer vector of length M. 
+ The result C is a 2D matrix of shape (M, D2) + + For case (1), C is computed by bmm: + :: + + C[i, :] = matmul(A[i, :], B[index[i], :, :]) + + For case (2), C is computed by index select: + :: + + C[i, :] = B[index[i], A[i], :] + + Parameters + ---------- + A : torch.Tensor + lhs tensor + B : torch.Tensor + rhs tensor + index : torch.Tensor + index tensor + + Returns + ------- + C : torch.Tensor + return tensor + """ + if A.dtype == th.int64 and len(A.shape) == 1: + # following is a faster version of B[index, A, :] + B = B.view(-1, B.shape[2]) + flatidx = index * B.shape[1] + A + return B.index_select(0, flatidx) + else: + BB = B.index_select(0, index) + return th.bmm(A.unsqueeze(1), BB).squeeze() diff --git a/tests/mxnet/test_nn.py b/tests/mxnet/test_nn.py index e151958f8d31..6cf58132a783 100644 --- a/tests/mxnet/test_nn.py +++ b/tests/mxnet/test_nn.py @@ -1,10 +1,11 @@ import mxnet as mx import networkx as nx import numpy as np +import scipy as sp import dgl import dgl.nn.mxnet as nn import backend as F -from mxnet import autograd, gluon +from mxnet import autograd, gluon, nd def check_close(a, b): assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4) @@ -182,9 +183,61 @@ def test_edge_softmax(): assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(), 1e-4, 1e-4) +def test_rgcn(): + ctx = F.ctx() + etype = [] + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + # 5 etypes + R = 5 + for i in range(g.number_of_edges()): + etype.append(i % 5) + B = 2 + I = 10 + O = 8 + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + rgc_basis.initialize(ctx=ctx) + h = nd.random.randn(100, I, ctx=ctx) + r = nd.array(etype, ctx=ctx) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) + rgc_bdd.initialize(ctx=ctx) + h = nd.random.randn(100, I, ctx=ctx) + r = nd.array(etype, ctx=ctx) + h_new = rgc_bdd(g, h, r) + assert list(h_new.shape) == [100, O] + + # with norm + norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx) + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + rgc_basis.initialize(ctx=ctx) + h = nd.random.randn(100, I, ctx=ctx) + r = nd.array(etype, ctx=ctx) + h_new = rgc_basis(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) + rgc_bdd.initialize(ctx=ctx) + h = nd.random.randn(100, I, ctx=ctx) + r = nd.array(etype, ctx=ctx) + h_new = rgc_bdd(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + # id input + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + rgc_basis.initialize(ctx=ctx) + h = nd.random.randint(0, I, (100,), ctx=ctx) + r = nd.array(etype, ctx=ctx) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] + if __name__ == '__main__': test_graph_conv() test_edge_softmax() test_set2set() test_glob_att_pool() test_simple_pool() + test_rgcn() diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index 3fbfdbadcf9e..f37270db5b78 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -260,6 +260,51 @@ def generate_rand_graph(n): assert len(g.edata) == 2 assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend +def test_rgcn(): + ctx = F.ctx() + etype = [] + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + # 5 etypes + R = 5 + for i in range(g.number_of_edges()): + etype.append(i % 5) + B = 2 + I = 10 + O = 8 + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx) + h = th.randn((100, I)).to(ctx) + r = 
th.tensor(etype).to(ctx) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx) + h = th.randn((100, I)).to(ctx) + r = th.tensor(etype).to(ctx) + h_new = rgc_bdd(g, h, r) + assert list(h_new.shape) == [100, O] + + # with norm + norm = th.zeros((g.number_of_edges(), 1)).to(ctx) + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx) + h = th.randn((100, I)).to(ctx) + r = th.tensor(etype).to(ctx) + h_new = rgc_basis(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx) + h = th.randn((100, I)).to(ctx) + r = th.tensor(etype).to(ctx) + h_new = rgc_bdd(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + # id input + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx) + h = th.randint(0, I, (100,)).to(ctx) + r = th.tensor(etype).to(ctx) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] if __name__ == '__main__': test_graph_conv() @@ -268,3 +313,4 @@ def generate_rand_graph(n): test_glob_att_pool() test_simple_pool() test_set_trans() + test_rgcn()
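
For reviewers who want a compact end-to-end view, here is a minimal usage sketch of the new PyTorch `RelGraphConv` module, assembled from the documented signature `forward(g, x, etypes, norm=None)` and the patterns in the new `tests/pytorch/test_nn.py`. The graph, feature sizes, relation count, and the all-ones edge normalizer are illustrative assumptions, not values taken from the patch; the MXNet module added in `python/dgl/nn/mxnet/conv.py` exposes the same call signature with `mx.nd` arrays.

```python
import torch as th
import scipy.sparse as sp
import dgl
from dgl.nn.pytorch import RelGraphConv

# a small random read-only graph, mirroring the new unit tests
g = dgl.DGLGraph(sp.random(100, 100, density=0.1), readonly=True)
num_rels = 5
etypes = th.tensor([i % num_rels for i in range(g.number_of_edges())])

# basis regularizer with dense node features
conv_basis = RelGraphConv(10, 8, num_rels, regularizer="basis", num_bases=2)
feat = th.randn(100, 10)
h = conv_basis(g, feat, etypes)            # -> shape (100, 8)

# block-diagonal-decomposition regularizer (in/out sizes divisible by num_bases)
conv_bdd = RelGraphConv(10, 8, num_rels, regularizer="bdd", num_bases=2)
h_bdd = conv_bdd(g, feat, etypes)          # -> shape (100, 8)

# featureless graph: pass int64 node IDs; the layer treats them as one-hot
# inputs, so in_feat must equal the number of nodes (as in entity_classify.py)
conv_id = RelGraphConv(100, 8, num_rels, regularizer="basis", num_bases=2)
h_id = conv_id(g, th.arange(100), etypes)  # -> shape (100, 8)

# optional per-edge normalizer of shape (E, 1); an all-ones placeholder here
norm = th.ones(g.number_of_edges(), 1)
h_norm = conv_basis(g, feat, etypes, norm)
```

In the shipped examples the normalizer is not all ones: it comes from inverse in-degrees and is mapped from nodes to edges (see `comp_deg_norm` and the new `node_norm_to_edge_norm` helper in `link_predict.py`); the sketch uses ones only to stay self-contained.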
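The docstrings of the new `matmul_maybe_select` and `bmm_maybe_select` helpers describe two shortcuts: an integer ID vector behaves like a one-hot matrix, so a dense matmul collapses to a row lookup, and `B[index, A, :]` can be gathered from a flattened `(N*D1, D2)` view. The following sketch checks both equivalences numerically; the tensor sizes are made up, and the flattened index is written out from the math in the docstring rather than copied from the library code.

```python
import torch as th

# one-hot equivalence from the matmul_maybe_select docstring
A_ids = th.tensor([2, 0, 1])
B = th.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
one_hot = th.eye(3)[A_ids]                      # explicit one-hot encoding of A_ids
assert th.allclose(one_hot @ B, B.index_select(0, A_ids))

# gather equivalence behind bmm_maybe_select's ID fast path:
# picking row A2[i] of matrix B2[index[i]] equals one index_select
# on the flattened (N*D1, D2) view with flat index index*D1 + A2
N, D1, D2, M = 4, 3, 5, 6
B2 = th.randn(N, D1, D2)
index = th.randint(0, N, (M,))
A2 = th.randint(0, D1, (M,))
flat = B2.view(-1, D2)
flatidx = index * D1 + A2
assert th.allclose(flat.index_select(0, flatidx), B2[index, A2, :])
```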