Skip to content

Commit

Permalink
[Graph][Bugfix] Fix the API of map_to_subgraph_nid (dmlc#226)
Browse files Browse the repository at this point in the history
* Correct the vid (node ID) mapping API: `map_to_subgraph_nid` now returns a user tensor.

* Fix the SSE example (`sse_batch.py`) so it no longer calls `.tousertensor()` on the mapped seed IDs.
  • Loading branch information
zheng-da authored and jermainewang committed Dec 3, 2018
1 parent 419ffbd commit 2c170a8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/mxnet/sse/sse_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def main(args, data):
copy_to_gpu(subg, ctx)

with mx.autograd.record():
logits = model_train(subg, subg_seeds.tousertensor())
logits = model_train(subg, subg_seeds)
batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
loss.backward()
Expand Down
14 changes: 13 additions & 1 deletion python/dgl/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,16 @@ def copy_from_parent(self):
self._parent._edge_frame[self._get_parent_eid()]))

def map_to_subgraph_nid(self, parent_vids):
return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids))
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids)).tousertensor()
23 changes: 11 additions & 12 deletions tests/mxnet/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,35 @@ def test_1neighbor_sampler_all():
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
num_workers=4):
assert len(seed_ids) == 1
src, dst, eid = g._graph.in_edges(utils.toindex(seed_ids))
src, dst, eid = g.in_edges(seed_ids, form='all')
# Test if there is a self loop
self_loop = mx.nd.sum(src.tousertensor() == dst.tousertensor()).asnumpy() == 1
self_loop = mx.nd.sum(src == dst).asnumpy() == 1
if self_loop:
assert subg.number_of_nodes() == len(src)
else:
assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() >= len(src)

child_ids = subg.map_to_subgraph_nid(seed_ids)
child_src, child_dst, child_eid = subg._graph.in_edges(child_ids)
child_src, child_dst, child_eid = subg.in_edges(child_ids, form='all')

child_src1 = subg.map_to_subgraph_nid(src)
assert mx.nd.sum(child_src1.tousertensor() == child_src.tousertensor()).asnumpy() == len(src)
assert mx.nd.sum(child_src1 == child_src).asnumpy() == len(src)

def is_sorted(arr):
    """Return True if the 1-D sequence ``arr`` is in non-decreasing order.

    Parameters
    ----------
    arr : array-like
        One-dimensional sequence of comparable values.

    Returns
    -------
    bool
        True when every element is less than or equal to its successor.
        An empty or single-element input counts as sorted.
    """
    values = np.asarray(arr)
    # Comparing each element with its successor checks the order in a single
    # pass, without allocating the sorted copy that np.sort would build.
    return bool(np.all(values[:-1] <= values[1:]))

def verify_subgraph(g, subg, seed_id):
seed_id = utils.toindex(seed_id)
src, dst, eid = g._graph.in_edges(utils.toindex(seed_id))
src, dst, eid = g.in_edges(seed_id, form='all')
child_id = subg.map_to_subgraph_nid(seed_id)
child_src, child_dst, child_eid = subg._graph.in_edges(child_id)
child_src = child_src.tousertensor().asnumpy()
child_src, child_dst, child_eid = subg.in_edges(child_id, form='all')
child_src = child_src.asnumpy()
# We don't allow duplicate elements in the neighbor list.
assert(len(np.unique(child_src)) == len(child_src))
# The neighbor list also needs to be sorted.
assert(is_sorted(child_src))

child_src1 = subg.map_to_subgraph_nid(src).tousertensor().asnumpy()
child_src1 = subg.map_to_subgraph_nid(src).asnumpy()
child_src1 = child_src1[child_src1 >= 0]
for i in child_src:
assert i in child_src1
Expand All @@ -65,13 +64,13 @@ def test_10neighbor_sampler_all():
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4):
src, dst, eid = g._graph.in_edges(utils.toindex(seed_ids))
src, dst, eid = g.in_edges(seed_ids, form='all')

child_ids = subg.map_to_subgraph_nid(seed_ids)
child_src, child_dst, child_eid = subg._graph.in_edges(child_ids)
child_src, child_dst, child_eid = subg.in_edges(child_ids, form='all')

child_src1 = subg.map_to_subgraph_nid(src)
assert mx.nd.sum(child_src1.tousertensor() == child_src.tousertensor()).asnumpy() == len(src)
assert mx.nd.sum(child_src1 == child_src).asnumpy() == len(src)

def check_10neighbor_sampler(g, seeds):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
Expand Down

0 comments on commit 2c170a8

Please sign in to comment.