Skip to content

Commit

Permalink
[sampler] Adjust the sampler API for the future extension. (dmlc#243)
Browse files Browse the repository at this point in the history
* return seed ids.

* fix tests.

* implement.
  • Loading branch information
zheng-da authored Dec 5, 2018
1 parent 40506ec commit 7c7cc7e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
8 changes: 5 additions & 3 deletions examples/mxnet/sse/sse_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def main(args, data):
dur = []
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
shuffle=True)
shuffle=True, return_seed_id=True)
if args.cache_subgraph:
sampler = CachedSubgraphLoader(sampler, shuffle=True)
for epoch in range(args.n_epochs):
Expand All @@ -272,7 +272,8 @@ def main(args, data):
i = 0
num_batches = len(train_vs) / args.batch_size
start1 = time.time()
for subg, seeds in sampler:
for subg, aux_infos in sampler:
seeds = aux_infos['seeds']
subg_seeds = subg.map_to_subgraph_nid(seeds)
subg.copy_from_parent()

Expand Down Expand Up @@ -313,7 +314,8 @@ def main(args, data):
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in',
num_workers=args.num_parallel_subgraphs,
seed_nodes=train_vs, shuffle=True)
seed_nodes=train_vs, shuffle=True,
return_seed_id=True)

# prediction.
logits = model_infer(g, eval_vs)
Expand Down
28 changes: 21 additions & 7 deletions python/dgl/contrib/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
class NSSubgraphLoader(object):
def __init__(self, g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None):
shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False):
self._g = g
if not g._graph.is_readonly():
raise NotImplementedError("subgraph loader only support read-only graphs.")
self._batch_size = batch_size
self._expand_factor = expand_factor
self._num_hops = num_hops
self._node_prob = node_prob
self._return_seed_id = return_seed_id
if self._node_prob is not None:
assert self._node_prob.shape[0] == g.number_of_nodes(), \
"We need to know the sampling probability of every node"
Expand Down Expand Up @@ -56,7 +58,8 @@ def _prefetch(self):
subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \
i) for i in sgi]
self._subgraphs.extend(subgraphs)
self._seed_ids.extend(seed_ids)
if self._return_seed_id:
self._seed_ids.extend(seed_ids)

def __iter__(self):
return self
Expand All @@ -69,11 +72,15 @@ def __next__(self):
# iterate all subgraphs and we should stop the iterator now.
if len(self._subgraphs) == 0:
raise StopIteration
return self._subgraphs.pop(0), self._seed_ids.pop(0).tousertensor()
aux_infos = {}
if self._return_seed_id:
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._subgraphs.pop(0), aux_infos

def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None):
shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False):
'''
This creates a subgraph data loader that samples subgraphs from the input graph
with neighbor sampling. This sampling method is implemented in C and can perform
Expand All @@ -86,6 +93,11 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
that connect the source nodes and the sampled neighbor nodes of the source
nodes.
The subgraph loader returns a list of subgraphs and a dictionary of additional
information about the subgraphs. The size of the subgraph list is the number of workers.
The dictionary contains:
'seeds': a list of 1D tensors of seed Ids, if return_seed_id is True.
Parameters
----------
g: the DGLGraph where we sample subgraphs.
Expand All @@ -109,11 +121,13 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
num_workers: the number of worker threads that sample subgraphs in parallel.
max_subgraph_size: the maximal subgraph size in terms of the number of nodes.
GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
Returns
-------
A subgraph loader that returns a batch of subgraphs and
the Ids of the seed vertices used in the batch.
A subgraph loader that returns a list of batched subgraphs and a dictionary of
additional information about the subgraphs.
'''
return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, max_subgraph_size)
seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
21 changes: 13 additions & 8 deletions tests/mxnet/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def generate_rand_graph(n):
def test_1neighbor_sampler_all():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
num_workers=4):
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all')
# Test if there is a self loop
Expand Down Expand Up @@ -52,8 +53,9 @@ def verify_subgraph(g, subg, seed_id):
def test_1neighbor_sampler():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
num_workers=4):
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
assert len(seed_ids) == 1
assert subg.number_of_nodes() <= 6
assert subg.number_of_edges() <= 5
Expand All @@ -62,8 +64,9 @@ def test_1neighbor_sampler():
def test_10neighbor_sampler_all():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4):
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
src, dst, eid = g.in_edges(seed_ids, form='all')

child_ids = subg.map_to_subgraph_nid(seed_ids)
Expand All @@ -74,8 +77,10 @@ def test_10neighbor_sampler_all():

def check_10neighbor_sampler(g, seeds):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg, seed_ids in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
num_workers=4, seed_nodes=seeds):
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
num_workers=4, seed_nodes=seeds,
return_seed_id=True):
seed_ids = aux['seeds']
assert subg.number_of_nodes() <= 6 * len(seed_ids)
assert subg.number_of_edges() <= 5 * len(seed_ids)
for seed_id in seed_ids:
Expand Down

0 comments on commit 7c7cc7e

Please sign in to comment.