Skip to content

Commit

Permalink
Add tests for MXNet and throw NotImplementedError if not implemented (d…
Browse files Browse the repository at this point in the history
…mlc#258)

* add more unit tests for mxnet.

* fix.
  • Loading branch information
zheng-da authored Dec 5, 2018
1 parent 2c5b48a commit 899d125
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def stack(seq, dim):
return nd.stack(*seq, dim=dim)

def split(x, sizes_or_sections, dim):
if isinstance(sizes_or_sections, list):
if isinstance(sizes_or_sections, list) or isinstance(sizes_or_sections, np.ndarray):
# TODO: fallback to numpy is unfortunate
np_arr = x.asnumpy()
indices = np.cumsum(sizes_or_sections)[:-1]
Expand Down
81 changes: 79 additions & 2 deletions python/dgl/immutable_graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,25 @@ def edge_ids(self, u, v):
u, v, ids = self._sparse.edge_ids(u, v)
return utils.toindex(u), utils.toindex(v), utils.toindex(ids)

def find_edges(self, eid):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
eid : utils.Index
The edge ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
raise NotImplementedError('immutable graph doesn\'t implement find_edges for now.')

def in_edges(self, v):
"""Return the in edges of the node(s).
Expand Down Expand Up @@ -442,6 +461,21 @@ def node_subgraphs(self, vs_arr):
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]

def edge_subgraph(self, e):
"""Return the induced edge subgraph.
Parameters
----------
e : utils.Index
The edges.
Returns
-------
SubgraphIndex
The subgraph index.
"""
raise NotImplementedError('immutable graph doesn\'t implement edge_subgraph for now.')

def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
node_prob, max_subgraph_size):
if len(seed_ids) == 0:
Expand Down Expand Up @@ -519,7 +553,7 @@ def incidence_matrix(self, type, ctx):
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
raise Exception('immutable graph doesn\'t support incidence_matrix for now.')
raise NotImplementedError('immutable graph doesn\'t implement incidence_matrix for now.')

def to_networkx(self):
"""Convert to networkx graph.
Expand Down Expand Up @@ -581,7 +615,7 @@ def line_graph(self, backtracking=True):
ImmutableGraphIndex
The line graph of this graph.
"""
raise Exception('immutable graph doesn\'t support line_graph')
raise NotImplementedError('immutable graph doesn\'t implement line_graph')

class ImmutableSubgraphIndex(ImmutableGraphIndex):
"""Graph index for an immutable subgraph.
Expand Down Expand Up @@ -626,6 +660,49 @@ def induced_nodes(self):
"""
return utils.toindex(self._induced_nodes)

def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
The new graph will include all the nodes/edges in the given graphs.
Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes
in the given sequence order. For example, giving input [g1, g2, g3], where
they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
in the result graph. Edge ids are re-assigned similarly.
Parameters
----------
graphs : iterable of GraphIndex
The input graphs
Returns
-------
GraphIndex
The disjoint union
"""
raise NotImplementedError('immutable graph doesn\'t implement disjoint_union for now.')

def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly.
This is a reverse operation of DisjointUnion. The graph will be partitioned
into num graphs. This requires the given number of partitions to evenly
divides the number of nodes in the graph. If the a size list is given,
the sum of the given sizes is equal.
Parameters
----------
graph : GraphIndex
The graph to be partitioned
num_or_size_splits : int or utils.Index
The partition number of size splits
Returns
-------
list of GraphIndex
The partitioned graphs
"""
raise NotImplementedError('immutable graph doesn\'t implement disjoint_partition for now.')

def create_immutable_graph_index(graph_data=None):
"""Create a graph index object.
Expand Down
73 changes: 73 additions & 0 deletions tests/mxnet/test_propagate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
os.environ['DGLBACKEND'] = 'mxnet'
import dgl
import networkx as nx
import numpy as np
import mxnet as mx

def mfunc(edges):
return {'m' : edges.src['x']}

def rfunc(nodes):
msg = mx.nd.sum(nodes.mailbox['m'], 1)
return {'x' : nodes.data['x'] + msg}

def test_prop_nodes_bfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.ndata['x'] = mx.nd.ones(shape=(5, 2))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)

dgl.prop_nodes_bfs(g, 0)
# pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))

def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)

g.ndata['x'] = mx.nd.ones(shape=(5, 2))
dgl.prop_edges_dfs(g, 0)
# snr using dfs results in a cumsum
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]))

g.ndata['x'] = mx.nd.ones(shape=(5, 2))
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True)
# result is cumsum[i] + cumsum[i-1]
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]))

g.ndata['x'] = mx.nd.ones(shape=(5, 2))
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True)
# result is cumsum[i] + cumsum[i+1]
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]))

def test_prop_nodes_topo():
# bi-directional chain
g = dgl.DGLGraph(nx.path_graph(5))

# tree
tree = dgl.DGLGraph()
tree.add_nodes(5)
tree.add_edge(1, 0)
tree.add_edge(2, 0)
tree.add_edge(3, 2)
tree.add_edge(4, 2)
tree.register_message_func(mfunc)
tree.register_reduce_func(rfunc)
# init node feature data
tree.ndata['x'] = mx.nd.zeros(shape=(5, 2))
# set all leaf nodes to be ones
tree.nodes[[1, 3, 4]].data['x'] = mx.nd.ones(shape=(3, 2))
dgl.prop_nodes_topo(tree)
# root node get the sum
assert np.allclose(tree.nodes[0].data['x'].asnumpy(), np.array([[3., 3.]]))

if __name__ == '__main__':
test_prop_nodes_bfs()
test_prop_edges_dfs()
test_prop_nodes_topo()
126 changes: 126 additions & 0 deletions tests/mxnet/test_traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
os.environ['DGLBACKEND'] = 'mxnet'
import random
import sys
import time

import dgl
import networkx as nx
import numpy as np
import scipy.sparse as sp
import mxnet as mx

import itertools

np.random.seed(42)

def toset(x):
return set(x.asnumpy().tolist())

def test_bfs(n=1000):
def _bfs_nx(g_nx, src):
edges = nx.bfs_edges(g_nx, src)
layers_nx = [set([src])]
edges_nx = []
frontier = set()
edge_frontier = set()
for u, v in edges:
if u in layers_nx[-1]:
frontier.add(v)
edge_frontier.add(g.edge_id(u, v))
else:
layers_nx.append(frontier)
edges_nx.append(edge_frontier)
frontier = set([v])
edge_frontier = set([g.edge_id(u, v)])
layers_nx.append(frontier)
edges_nx.append(edge_frontier)
return layers_nx, edges_nx

g = dgl.DGLGraph()
a = sp.random(n, n, 10 / n, data_rvs=lambda n: np.ones(n))
g.from_scipy_sparse_matrix(a)
g_nx = g.to_networkx()
src = random.choice(range(n))
layers_nx, _ = _bfs_nx(g_nx, src)
layers_dgl = dgl.bfs_nodes_generator(g, src)
assert len(layers_dgl) == len(layers_nx)
assert all(toset(x) == y for x, y in zip(layers_dgl, layers_nx))

g_nx = nx.random_tree(n, seed=42)
g = dgl.DGLGraph()
g.from_networkx(g_nx)
src = 0
_, edges_nx = _bfs_nx(g_nx, src)
edges_dgl = dgl.bfs_edges_generator(g, src)
assert len(edges_dgl) == len(edges_nx)
assert all(toset(x) == y for x, y in zip(edges_dgl, edges_nx))

def test_topological_nodes(n=1000):
g = dgl.DGLGraph()
a = sp.random(n, n, 10 / n, data_rvs=lambda n: np.ones(n))
b = sp.tril(a, -1).tocoo()
g.from_scipy_sparse_matrix(b)

layers_dgl = dgl.topological_nodes_generator(g)

adjmat = g.adjacency_matrix()
def tensor_topo_traverse():
n = g.number_of_nodes()
mask = mx.nd.ones(shape=(n, 1))
degree = mx.nd.dot(adjmat, mask)
while mx.nd.sum(mask) != 0.:
v = (degree == 0.).astype(np.float32)
v = v * mask
mask = mask - v
tmp = np.nonzero(mx.nd.squeeze(v).asnumpy())[0]
frontier = mx.nd.array(tmp, dtype=tmp.dtype)
yield frontier
degree -= mx.nd.dot(adjmat, v)

layers_spmv = list(tensor_topo_traverse())

assert len(layers_dgl) == len(layers_spmv)
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))

DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def test_dfs_labeled_edges(n=1000, example=False):
dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(
dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True)
dgl_edges = [toset(t) for t in dgl_edges]
dgl_labels = [toset(t) for t in dgl_labels]

g1_solutions = [
# edges labels
[[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]],
[[2, 2, 0, 1, 0], [0, 1, 0, 2, 1]],
]
g2_solutions = [
# edges labels
[[3, 3, 4, 4], [0, 1, 0, 1]],
[[4, 4, 3, 3], [0, 1, 0, 1]],
]

def combine_frontiers(sol):
es, ls = zip(*sol)
es = [set(i for i in t if i is not None)
for t in itertools.zip_longest(*es)]
ls = [set(i for i in t if i is not None)
for t in itertools.zip_longest(*ls)]
return es, ls

for sol_set in itertools.product(g1_solutions, g2_solutions):
es, ls = combine_frontiers(sol_set)
if es == dgl_edges and ls == dgl_labels:
break
else:
assert False


if __name__ == '__main__':
test_bfs()
test_topological_nodes()
test_dfs_labeled_edges()

0 comments on commit 899d125

Please sign in to comment.