diff --git a/examples/mxnet/gat/gat_batch.py b/examples/mxnet/gat/gat_batch.py index 3697db7fb966..aee13852215d 100644 --- a/examples/mxnet/gat/gat_batch.py +++ b/examples/mxnet/gat/gat_batch.py @@ -18,18 +18,18 @@ def elu(data): return mx.nd.LeakyReLU(data, act_type='elu') -def gat_message(src, edge): - return {'ft' : src['ft'], 'a2' : src['a2']} +def gat_message(edges): + return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']} class GATReduce(gluon.Block): def __init__(self, attn_drop): super(GATReduce, self).__init__() self.attn_drop = attn_drop - def forward(self, node, msgs): - a1 = mx.nd.expand_dims(node['a1'], 1) # shape (B, 1, 1) - a2 = msgs['a2'] # shape (B, deg, 1) - ft = msgs['ft'] # shape (B, deg, D) + def forward(self, nodes): + a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1) + a2 = nodes.mailbox['a2'] # shape (B, deg, 1) + ft = nodes.mailbox['ft'] # shape (B, deg, D) # attention a = a1 + a2 # shape (B, deg, 1) e = mx.nd.softmax(mx.nd.LeakyReLU(a)) @@ -48,13 +48,13 @@ def __init__(self, headid, indim, hiddendim, activation, residual): if indim != hiddendim: self.residual_fc = gluon.nn.Dense(hiddendim) - def forward(self, node): - ret = node['accum'] + def forward(self, nodes): + ret = nodes.data['accum'] if self.residual: if self.residual_fc is not None: - ret = self.residual_fc(node['h']) + ret + ret = self.residual_fc(nodes.data['h']) + ret else: - ret = node['h'] + ret + ret = nodes.data['h'] + ret return {'head%d' % self.headid : self.activation(ret)} class GATPrepare(gluon.Block): diff --git a/examples/mxnet/gcn/gcn_batch.py b/examples/mxnet/gcn/gcn_batch.py index 2cc8dc833362..c095637bcf21 100644 --- a/examples/mxnet/gcn/gcn_batch.py +++ b/examples/mxnet/gcn/gcn_batch.py @@ -14,11 +14,11 @@ from dgl import DGLGraph from dgl.data import register_data_args, load_data -def gcn_msg(src, edge): - return src +def gcn_msg(edge): + return {'m': edge.src['h']} -def gcn_reduce(node, msgs): - return mx.nd.sum(msgs, 1) +def gcn_reduce(node): + return {'accum': mx.nd.sum(node.mailbox['m'], 1)} class NodeUpdateModule(gluon.Block): def __init__(self, out_feats, activation=None): @@ -26,7 +26,7 @@ def __init__(self, out_feats, activation=None): self.linear = gluon.nn.Dense(out_feats, activation=activation) def forward(self, node): - return self.linear(node) + return {'h': self.linear(node.data['accum'])} class GCN(gluon.Block): def __init__(self, @@ -50,14 +50,14 @@ def __init__(self, self.layers.add(NodeUpdateModule(n_classes)) def forward(self, features): - self.g.set_n_repr(features) + self.g.ndata['h'] = features for layer in self.layers: # apply dropout if self.dropout: - val = F.dropout(self.g.get_n_repr(), p=self.dropout) - self.g.set_n_repr(val) + val = F.dropout(self.g.ndata['h'], p=self.dropout) + self.g.ndata['h'] = val self.g.update_all(gcn_msg, gcn_reduce, layer) - return self.g.pop_n_repr() + return self.g.ndata.pop('h') def main(args): # load and preprocess dataset diff --git a/examples/mxnet/sse/sse_batch.py b/examples/mxnet/sse/sse_batch.py index 21937aaea8bf..41d62ac60088 100644 --- a/examples/mxnet/sse/sse_batch.py +++ b/examples/mxnet/sse/sse_batch.py @@ -6,49 +6,52 @@ import argparse import numpy as np import time +import math import mxnet as mx from mxnet import gluon import dgl import dgl.function as fn -from dgl import DGLGraph, utils +from dgl import DGLGraph from dgl.data import register_data_args, load_data -def gcn_msg(src, edge): +def gcn_msg(edges): # TODO should we use concat? 
- return {'m': mx.nd.concat(src['in'], src['h'], dim=1)} + return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)} -def gcn_reduce(node, msgs): - return {'accum': mx.nd.sum(msgs['m'], 1)} +def gcn_reduce(nodes): + return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)} class NodeUpdate(gluon.Block): - def __init__(self, out_feats, activation=None, alpha=0.9): - super(NodeUpdate, self).__init__() + def __init__(self, out_feats, activation=None, alpha=0.9, **kwargs): + super(NodeUpdate, self).__init__(**kwargs) self.linear1 = gluon.nn.Dense(out_feats, activation=activation) # TODO what is the dimension here? self.linear2 = gluon.nn.Dense(out_feats) self.alpha = alpha - def forward(self, node): - tmp = mx.nd.concat(node['in'], node['accum'], dim=1) - hidden = self.linear2(self.linear1(tmp)) - return {'h': node['h'] * (1 - self.alpha) + self.alpha * hidden} + def forward(self, nodes): + hidden = mx.nd.concat(nodes.data['in'], nodes.data['accum'], dim=1) + hidden = self.linear2(self.linear1(hidden)) + return {'h': nodes.data['h'] * (1 - self.alpha) + self.alpha * hidden} class SSEUpdateHidden(gluon.Block): def __init__(self, n_hidden, activation, dropout, - use_spmv): - super(SSEUpdateHidden, self).__init__() - self.layer = NodeUpdate(n_hidden, activation) + use_spmv, + **kwargs): + super(SSEUpdateHidden, self).__init__(**kwargs) + with self.name_scope(): + self.layer = NodeUpdate(n_hidden, activation) self.dropout = dropout self.use_spmv = use_spmv def forward(self, g, vertices): if self.use_spmv: - feat = g.get_n_repr()['in'] - h = g.get_n_repr()['h'] - g.set_n_repr({'cat': mx.nd.concat(feat, h, dim=1)}) + feat = g.ndata['in'] + h = g.ndata['h'] + g.ndata['cat'] = mx.nd.concat(feat, h, dim=1) msg_func = fn.copy_src(src='cat', out='tmp') reduce_func = fn.sum(msg='tmp', out='accum') @@ -56,24 +59,36 @@ def forward(self, g, vertices): msg_func = gcn_msg reduce_func = gcn_reduce if vertices is None: - g.update_all(msg_func, reduce_func, self.layer) - ret = g.get_n_repr()['h'] + g.update_all(msg_func, reduce_func, None) + if self.use_spmv: + g.ndata.pop('cat') + batch_size = 100000 + num_batches = int(math.ceil(g.number_of_nodes() / batch_size)) + for i in range(num_batches): + vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64) + g.apply_nodes(self.layer, vs, inplace=True) + g.ndata.pop('accum') + ret = g.ndata['h'] else: # We don't need dropout for inference. if self.dropout: # TODO here we apply dropout on all vertex representation. 
- val = mx.nd.Dropout(g.get_n_repr()['h'], p=self.dropout) - g.set_n_repr({'h': val}) + val = mx.nd.Dropout(g.ndata['h'], p=self.dropout) + g.ndata['h'] = val g.pull(vertices, msg_func, reduce_func, self.layer) - ctx = g.get_n_repr()['h'].context - ret = mx.nd.take(g.get_n_repr()['h'], vertices.tousertensor().as_in_context(ctx)) + ctx = g.ndata['h'].context + ret = mx.nd.take(g.ndata['h'], vertices.tousertensor().as_in_context(ctx)) + if self.use_spmv: + g.ndata.pop('cat') + g.ndata.pop('accum') return ret class SSEPredict(gluon.Block): - def __init__(self, update_hidden, out_feats, dropout): - super(SSEPredict, self).__init__() - self.linear1 = gluon.nn.Dense(out_feats, activation='relu') - self.linear2 = gluon.nn.Dense(out_feats) + def __init__(self, update_hidden, out_feats, dropout, **kwargs): + super(SSEPredict, self).__init__(**kwargs) + with self.name_scope(): + self.linear1 = gluon.nn.Dense(out_feats, activation='relu') + self.linear2 = gluon.nn.Dense(out_feats) self.update_hidden = update_hidden self.dropout = dropout @@ -83,10 +98,11 @@ def forward(self, g, vertices): hidden = mx.nd.Dropout(hidden, p=self.dropout) return self.linear2(self.linear1(hidden)) -def subgraph_gen(g, seed_vertices): +def subgraph_gen(g, seed_vertices, ctxs): + assert len(seed_vertices) % len(ctxs) == 0 vertices = [] for seed in seed_vertices: - src, _, _ = g.in_edges(seed) + src, _ = g.in_edges(seed) vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0) vs = mx.nd.array(np.unique(vs), dtype=np.int64) vertices.append(vs) @@ -94,12 +110,23 @@ def subgraph_gen(g, seed_vertices): nids = [] for i, subg in enumerate(subgs): subg.copy_from_parent() - nids.append(subg.map_to_subgraph_nid(utils.toindex(seed_vertices[i]))) + nids.append(subg.map_to_subgraph_nid(seed_vertices[i])) return subgs, nids +def copy_to_gpu(subg, ctx): + frame = subg.ndata + for key in frame: + subg.ndata[key] = frame[key].as_in_context(ctx) + def main(args, data): - features = mx.nd.array(data.features) - labels = mx.nd.array(data.labels) + if isinstance(data.features, mx.nd.NDArray): + features = data.features + else: + features = mx.nd.array(data.features) + if isinstance(data.labels, mx.nd.NDArray): + labels = data.labels + else: + labels = mx.nd.array(data.labels) train_size = len(labels) * args.train_percent train_vs = np.arange(train_size, dtype='int64') eval_vs = np.arange(train_size, len(labels), dtype='int64') @@ -111,42 +138,45 @@ def main(args, data): n_classes = data.num_labels n_edges = data.graph.number_of_edges() - if args.gpu <= 0: - cuda = False - ctx = mx.cpu(0) - else: - cuda = True - features = features.as_in_context(mx.gpu(0)) - train_labels = train_labels.as_in_context(mx.gpu(0)) - eval_labels = eval_labels.as_in_context(mx.gpu(0)) - ctx = mx.gpu(0) - # create the SSE model try: graph = data.graph.get_graph() except AttributeError: graph = data.graph g = DGLGraph(graph, readonly=True) - g.set_n_repr({'in': features, 'h': mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden), - ctx=ctx)}) + g.ndata['in'] = features + g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden), + ctx=mx.cpu(0)) - update_hidden = SSEUpdateHidden(args.n_hidden, 'relu', args.update_dropout, args.use_spmv) - model = SSEPredict(update_hidden, args.n_hidden, args.predict_dropout) - model.initialize(ctx=ctx) + update_hidden_infer = SSEUpdateHidden(args.n_hidden, 'relu', + args.update_dropout, args.use_spmv, prefix='sse') + update_hidden_infer.initialize(ctx=mx.cpu(0)) + + train_ctxs = [] + update_hidden_train 
= SSEUpdateHidden(args.n_hidden, 'relu', + args.update_dropout, args.use_spmv, prefix='sse') + model = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app') + if args.gpu <= 0: + model.initialize(ctx=mx.cpu(0)) + train_ctxs.append(mx.cpu(0)) + else: + for i in range(args.gpu): + train_ctxs.append(mx.gpu(i)) + model.initialize(ctx=train_ctxs) # use optimizer num_batches = int(g.number_of_nodes() / args.batch_size) scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches, args.lr * 10, 0, 0, args.lr/5) trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr, - 'lr_scheduler': scheduler}) + 'lr_scheduler': scheduler}, kvstore=mx.kv.create('device')) + + # compute vertex embedding. + update_hidden_infer(g, None) # initialize graph dur = [] for epoch in range(args.n_epochs): - # compute vertex embedding. - update_hidden(g, None) - t0 = time.time() permute = np.random.permutation(len(train_vs)) randv = train_vs[permute] @@ -162,14 +192,27 @@ def main(args, data): if len(data) < args.num_parallel_subgraphs: continue - subgs, seed_ids = subgraph_gen(g, data) + subgs, seed_ids = subgraph_gen(g, data, train_ctxs) + + losses = [] + i = 0 for subg, seed_id, label, d in zip(subgs, seed_ids, labels, data): + if args.gpu > 0: + ctx = mx.gpu(i % args.gpu) + copy_to_gpu(subg, ctx) with mx.autograd.record(): logits = model(subg, seed_id) + if label.context != logits.context: + label = label.as_in_context(logits.context) loss = mx.nd.softmax_cross_entropy(logits, label) loss.backward() - trainer.step(d.shape[0]) - train_loss += loss.asnumpy()[0] + losses.append(loss) + i = i + 1 + if i % args.gpu == 0: + trainer.step(d.shape[0] * len(subgs)) + for loss in losses: + train_loss += loss.asnumpy()[0] + losses = [] data = [] labels = [] @@ -178,13 +221,48 @@ def main(args, data): #eval_loss = eval_loss.asnumpy()[0] eval_loss = 0 + # compute vertex embedding. 
+        infer_params = update_hidden_infer.collect_params()
+        for key in infer_params:
+            idx = trainer._param2idx[key]
+            trainer._kvstore.pull(idx, out=infer_params[key].data())
+        update_hidden_infer(g, None)
+
         dur.append(time.time() - t0)
         print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
             epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))
+class MXNetGraph(object):
+    """A simple graph object that uses scipy matrix."""
+    def __init__(self, mat):
+        self._mat = mat
+
+    def get_graph(self):
+        return self._mat
+
+    def number_of_nodes(self):
+        return self._mat.shape[0]
+
+    def number_of_edges(self):
+        return mx.nd.contrib.getnnz(self._mat)
+
+class GraphData:
+    def __init__(self, csr, num_feats):
+        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
+        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
+        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
+        self.graph = MXNetGraph(csr)
+        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
+        self.labels = mx.nd.floor(mx.nd.random.normal(loc=0, scale=10, shape=(csr.shape[0])))
+        self.num_labels = 10
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GCN')
     register_data_args(parser)
+    parser.add_argument("--graph-file", type=str, default="",
+            help="graph file")
+    parser.add_argument("--num-feats", type=int, default=10,
+            help="the number of features")
     parser.add_argument("--gpu", type=int, default=-1,
             help="gpu")
     parser.add_argument("--lr", type=float, default=1e-3,
@@ -210,5 +288,10 @@ def main(args, data):
     args = parser.parse_args()
 
     # load and preprocess dataset
-    data = load_data(args)
+    if args.graph_file != '':
+        csr = mx.nd.load(args.graph_file)[0]
+        data = GraphData(csr, args.num_feats)
+        csr = None
+    else:
+        data = load_data(args)
     main(args, data)
diff --git a/python/dgl/backend/mxnet/immutable_graph_index.py b/python/dgl/backend/mxnet/immutable_graph_index.py
index 360aa50e5fcd..5fbaa0c819fe 100644
--- a/python/dgl/backend/mxnet/immutable_graph_index.py
+++ b/python/dgl/backend/mxnet/immutable_graph_index.py
@@ -81,9 +81,13 @@ def edge_ids(self, u, v):
         NDArray
             The edge id array.
         """
+        if len(u) == 0 or len(v) == 0:
+            return [], [], []
         ids = mx.nd.contrib.edge_id(self._in_csr, v, u)
         ids = ids.asnumpy()
-        return ids[ids >= 0]
+        v = v.asnumpy()
+        u = u.asnumpy()
+        return u[ids >= 0], v[ids >= 0], ids[ids >= 0]
 
     def predecessors(self, v, radius=1):
         """Return the predecessors of the node.
diff --git a/python/dgl/backend/mxnet/tensor.py b/python/dgl/backend/mxnet/tensor.py
index d2d224435842..a07f3f3ff5d7 100644
--- a/python/dgl/backend/mxnet/tensor.py
+++ b/python/dgl/backend/mxnet/tensor.py
@@ -27,11 +27,11 @@ def sparse_matrix(data, index, shape, force_format=False):
             raise TypeError('MXNet backend only supports CSR format,'
                             ' but COO format is forced.')
         coord = index[1]
-        return nd.sparse.csr_matrix((data, (coord[0], coord[1])), shape)
+        return nd.sparse.csr_matrix((data, (coord[0], coord[1])), tuple(shape))
     elif fmt == 'csr':
         indices = index[1]
         indptr = index[2]
-        return nd.sparse.csr_matrix((data, indices, indptr), shape)
+        return nd.sparse.csr_matrix((data, indices, indptr), tuple(shape))
     else:
         raise TypeError('Invalid format: %s.'
% fmt) @@ -65,7 +65,7 @@ def sum(input, dim): return nd.sum(input, axis=dim) def max(input, dim): - return nd.max(input, axis=dim) + return nd.max(input, axis=dim).asnumpy()[0] def cat(seq, dim): return nd.concat(*seq, dim=dim) @@ -131,7 +131,7 @@ def nonzero_1d(input): def sort_1d(input): # TODO: this isn't an ideal implementation. - val = nd.sort(input, is_ascend=True) + val = nd.sort(input, axis=None, is_ascend=True) idx = nd.argsort(input, is_ascend=True) idx = nd.cast(idx, dtype='int64') return val, idx diff --git a/python/dgl/graph.py b/python/dgl/graph.py index 648c4d5ffc75..71203073a440 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -853,7 +853,7 @@ def register_apply_edge_func(self, func): """ self._apply_edge_func = func - def apply_nodes(self, func="default", v=ALL): + def apply_nodes(self, func="default", v=ALL, inplace=False): """Apply the function on the node features. Applying a None function will be ignored. @@ -865,7 +865,7 @@ def apply_nodes(self, func="default", v=ALL): v : int, iterable of int, tensor, optional The node id(s). """ - self._internal_apply_nodes(v, func) + self._internal_apply_nodes(v, func, inplace=inplace) def apply_edges(self, func="default", edges=ALL): """Apply the function on the edge features. @@ -1464,7 +1464,8 @@ def filter_edges(self, predicate, edges=ALL): edges = F.tensor(edges) return edges[e_mask] - def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None): + def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None, + inplace=False): """Internal apply nodes Parameters @@ -1478,7 +1479,7 @@ def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None) # Skip none function call. if reduce_accum is not None: # write reduce result back - self.set_n_repr(reduce_accum, v) + self.set_n_repr(reduce_accum, v, inplace=inplace) return # take out current node repr curr_repr = self.get_n_repr(v) @@ -1491,4 +1492,4 @@ def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None) # merge new node_repr with reduce output reduce_accum.update(new_repr) new_repr = reduce_accum - self.set_n_repr(new_repr, v) + self.set_n_repr(new_repr, v, inplace=inplace) diff --git a/python/dgl/immutable_graph_index.py b/python/dgl/immutable_graph_index.py index ae3b77fada2f..b7c209af003f 100644 --- a/python/dgl/immutable_graph_index.py +++ b/python/dgl/immutable_graph_index.py @@ -62,6 +62,17 @@ def clear(self): """Clear the graph.""" raise Exception('Immutable graph doesn\'t support clearing up') + def is_multigraph(self): + """Return whether the graph is a multigraph + + Returns + ------- + bool + True if it is a multigraph, False otherwise. + """ + # Immutable graph doesn't support multi-edge. + return False + def number_of_nodes(self): """Return the number of nodes. @@ -207,7 +218,7 @@ def edge_id(self, u, v): """ u = F.tensor([u], dtype=F.int64) v = F.tensor([v], dtype=F.int64) - id = self._sparse.edge_ids(u, v) + _, _, id = self._sparse.edge_ids(u, v) return utils.toindex(id) def edge_ids(self, u, v): @@ -223,12 +234,16 @@ def edge_ids(self, u, v): Returns ------- utils.Index - The edge id array. + The src nodes. + utils.Index + The dst nodes. + utils.Index + The edge ids. """ u = u.tousertensor() v = v.tousertensor() - ids = self._sparse.edge_ids(u, v) - return utils.toindex(ids) + u, v, ids = self._sparse.edge_ids(u, v) + return utils.toindex(u), utils.toindex(v), utils.toindex(ids) def in_edges(self, v): """Return the in edges of the node(s). 
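Note on the `edge_ids` change above, which threads through both the MXNet backend and this index class: the method now returns the matched `(src, dst, eid)` triple instead of a bare id array, so callers can tell which queried pairs actually correspond to edges. A minimal caller-side sketch, not part of the patch; building the readonly graph from a scipy CSR matrix follows the updated tests, and reaching the index through the private `_graph` attribute is an assumption made for illustration:

    import numpy as np
    import scipy.sparse as spsp
    from dgl import utils
    from dgl.graph import DGLGraph

    # Readonly graph with edges 0->1 and 1->2, built from a scipy CSR matrix.
    csr = spsp.csr_matrix((np.ones(2), ([0, 1], [1, 2])), shape=(3, 3))
    g = DGLGraph(csr, readonly=True)

    # Query two existing pairs and one missing pair (2, 0); only matches return.
    u = utils.toindex([0, 1, 2])
    v = utils.toindex([1, 2, 0])
    src, dst, eid = g._graph.edge_ids(u, v)  # _graph: assumed index handle
    print(src.tolist(), dst.tolist(), eid.tolist())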
diff --git a/python/dgl/subgraph.py b/python/dgl/subgraph.py index f44303a40d53..c4c4b4db4426 100644 --- a/python/dgl/subgraph.py +++ b/python/dgl/subgraph.py @@ -119,4 +119,4 @@ def copy_from_parent(self): self._parent._edge_frame[self._parent_eid])) def map_to_subgraph_nid(self, parent_vids): - return map_to_subgraph_nid(self._graph, parent_vids) + return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids)) diff --git a/tests/mxnet/test_basics.py b/tests/mxnet/test_basics.py index f0cb2a214c73..124143a5fbb8 100644 --- a/tests/mxnet/test_basics.py +++ b/tests/mxnet/test_basics.py @@ -3,6 +3,7 @@ import mxnet as mx import numpy as np from dgl.graph import DGLGraph +import scipy.sparse as spsp D = 5 reduce_msg_shapes = set() @@ -26,20 +27,39 @@ def reduce_func(nodes): def apply_node_func(nodes): return {'h' : nodes.data['h'] + nodes.data['m']} -def generate_graph(grad=False): - g = DGLGraph() - g.add_nodes(10) # 10 nodes. - # create a graph where 0 is the source and 9 is the sink - for i in range(1, 9): - g.add_edge(0, i) - g.add_edge(i, 9) - # add a back flow from 9 to 0 - g.add_edge(9, 0) - ncol = mx.nd.random.normal(shape=(10, D)) - if grad: - ncol.attach_grad() - g.ndata['h'] = ncol - return g +def generate_graph(grad=False, readonly=False): + if readonly: + row_idx = [] + col_idx = [] + for i in range(1, 9): + row_idx.append(0) + col_idx.append(i) + row_idx.append(i) + col_idx.append(9) + row_idx.append(9) + col_idx.append(0) + ones = np.ones(shape=(len(row_idx))) + csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(10, 10)) + g = DGLGraph(csr, readonly=True) + ncol = mx.nd.random.normal(shape=(10, D)) + if grad: + ncol.attach_grad() + g.ndata['h'] = ncol + return g + else: + g = DGLGraph() + g.add_nodes(10) # 10 nodes. + # create a graph where 0 is the source and 9 is the sink + for i in range(1, 9): + g.add_edge(0, i) + g.add_edge(i, 9) + # add a back flow from 9 to 0 + g.add_edge(9, 0) + ncol = mx.nd.random.normal(shape=(10, D)) + if grad: + ncol.attach_grad() + g.ndata['h'] = ncol + return g def test_batch_setter_getter(): def _pfc(x): @@ -121,7 +141,7 @@ def _pfc(x): def test_batch_setter_autograd(): with mx.autograd.record(): - g = generate_graph(grad=True) + g = generate_graph(grad=True, readonly=True) h1 = g.ndata['h'] h1.attach_grad() # partial set @@ -153,9 +173,9 @@ def _fmsg(edges): v = mx.nd.array([9], dtype='int64') g.send((u, v)) -def test_batch_recv(): +def check_batch_recv(readonly): # basic recv test - g = generate_graph() + g = generate_graph(readonly=readonly) g.register_message_func(message_func) g.register_reduce_func(reduce_func) g.register_apply_node_func(apply_node_func) @@ -167,8 +187,12 @@ def test_batch_recv(): #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) #reduce_msg_shapes.clear() -def test_update_routines(): - g = generate_graph() +def test_batch_recv(): + check_batch_recv(True) + check_batch_recv(False) + +def check_update_routines(readonly): + g = generate_graph(readonly=readonly) g.register_message_func(message_func) g.register_reduce_func(reduce_func) g.register_apply_node_func(apply_node_func) @@ -201,13 +225,27 @@ def test_update_routines(): assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)}) reduce_msg_shapes.clear() -def test_reduce_0deg(): - g = DGLGraph() - g.add_nodes(5) - g.add_edge(1, 0) - g.add_edge(2, 0) - g.add_edge(3, 0) - g.add_edge(4, 0) +def test_update_routines(): + check_update_routines(True) + check_update_routines(False) + +def check_reduce_0deg(readonly): + if readonly: + row_idx = [] + col_idx = [] + for i in 
range(1, 5): + row_idx.append(i) + col_idx.append(0) + ones = np.ones(shape=(len(row_idx))) + csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(5, 5)) + g = DGLGraph(csr, readonly=True) + else: + g = DGLGraph() + g.add_nodes(5) + g.add_edge(1, 0) + g.add_edge(2, 0) + g.add_edge(3, 0) + g.add_edge(4, 0) def _message(edges): return {'m' : edges.src['h']} def _reduce(nodes): @@ -220,10 +258,23 @@ def _reduce(nodes): assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy()) -def test_pull_0deg(): - g = DGLGraph() - g.add_nodes(2) - g.add_edge(0, 1) +def test_reduce_0deg(): + check_reduce_0deg(True) + check_reduce_0deg(False) + +def check_pull_0deg(readonly): + if readonly: + row_idx = [] + col_idx = [] + row_idx.append(0) + col_idx.append(1) + ones = np.ones(shape=(len(row_idx))) + csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2)) + g = DGLGraph(csr, readonly=True) + else: + g = DGLGraph() + g.add_nodes(2) + g.add_edge(0, 1) def _message(edges): return {'m' : edges.src['h']} def _reduce(nodes): @@ -246,6 +297,10 @@ def _reduce(nodes): assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy()) +def test_pull_0deg(): + check_pull_0deg(True) + check_pull_0deg(False) + if __name__ == '__main__': test_batch_setter_getter() test_batch_setter_autograd() diff --git a/tests/mxnet/test_graph_index.py b/tests/mxnet/test_graph_index.py index 49b066484af0..1380a39a19b0 100644 --- a/tests/mxnet/test_graph_index.py +++ b/tests/mxnet/test_graph_index.py @@ -67,7 +67,7 @@ def check_basics(g, ig): assert g.has_edge_between(u, v) == ig.has_edge_between(u, v) randv = utils.toindex(randv) ids = g.edge_ids(randv, randv)[2].tolist() - assert sum(ig.edge_ids(randv, randv).tolist() == ids) == len(ids) + assert sum(ig.edge_ids(randv, randv)[2].tolist() == ids) == len(ids) assert sum(g.has_edges_between(randv, randv).tolist() == ig.has_edges_between(randv, randv).tolist()) == len(randv)
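For quick reference, the example-side changes in this patch all follow the same migration: message UDFs now take an edge batch (`edges.src`, `edges.dst`, `edges.data`), reduce UDFs take a node batch (`nodes.data`, `nodes.mailbox`), and node features move from `set_n_repr`/`get_n_repr` to the `g.ndata` dict. A minimal end-to-end sketch of the new convention, with a made-up toy graph and features:

    import numpy as np
    import scipy.sparse as spsp
    import mxnet as mx
    from dgl import DGLGraph

    def msg_func(edges):
        # Batched over edges: edges.src['h'] stacks the source node features.
        return {'m': edges.src['h']}

    def reduce_func(nodes):
        # nodes.mailbox['m'] has shape (B, deg, D); nodes are bucketed by degree.
        return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        # Node UDFs read and write features through nodes.data.
        return {'h': nodes.data['h'] + nodes.data['accum']}

    # Toy readonly 3-cycle (0->1, 1->2, 2->0) so every node receives a message.
    csr = spsp.csr_matrix((np.ones(3), ([0, 1, 2], [1, 2, 0])), shape=(3, 3))
    g = DGLGraph(csr, readonly=True)
    g.ndata['h'] = mx.nd.ones((3, 4))   # features now live in g.ndata
    g.update_all(msg_func, reduce_func, apply_func)
    print(g.ndata['h'])                 # every entry is now 2.0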