Skip to content

Commit

Permalink
add prefetcher for neighbor sampler (dmlc#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored and zheng-da committed Dec 16, 2018
1 parent d7a3b2a commit 3010740
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 4 deletions.
135 changes: 131 additions & 4 deletions python/dgl/contrib/sampling/sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# This file contains subgraph samplers.

import random
import sys
import threading
import traceback
try:
    import Queue as queue
except ImportError:
    import queue

import numpy as np

from ... import backend as F
from ... import utils
from ...subgraph import DGLSubGraph

__all__ = ['NeighborSampler']

Expand Down Expand Up @@ -77,10 +84,124 @@ def __next__(self):
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._subgraphs.pop(0), aux_infos

class _Prefetcher(object):
"""Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
or Process-based implementation."""
_dataq = None # Data queue transmits prefetched elements
_controlq = None # Control queue to instruct thread / process shutdown
_errorq = None # Error queue to transmit exceptions from worker to master

_checked_start = False # True once startup has been checkd by _check_start

def __init__(self, loader, num_prefetch):
super(_Prefetcher, self).__init__()
self.loader = loader
assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
self.num_prefetch = num_prefetch

def run(self):
"""Method representing the process’s activity."""
# Startup - Master waits for this
try:
loader_iter = iter(self.loader)
self._errorq.put(None)
except Exception as e: # pylint: disable=broad-except
tb = traceback.format_exc()
self._errorq.put((e, tb))

while True:
try: # Check control queue
c = self._controlq.get(False)
if c is None:
break
else:
raise RuntimeError('Got unexpected control code {}'.format(repr(c)))
except queue.Empty:
pass
except RuntimeError as e:
tb = traceback.format_exc()
self._errorq.put((e, tb))
self._dataq.put(None)

try:
data = next(loader_iter)
error = None
except Exception as e: # pylint: disable=broad-except
tb = traceback.format_exc()
error = (e, tb)
data = None
finally:
self._errorq.put(error)
self._dataq.put(data)

def __next__(self):
next_item = self._dataq.get()
next_error = self._errorq.get()

if next_error is None:
return next_item
else:
self._controlq.put(None)
if isinstance(next_error[0], StopIteration):
raise StopIteration
else:
return self._reraise(*next_error)

def _reraise(self, e, tb):
print('Reraising exception from Prefetcher', file=sys.stderr)
print(tb, file=sys.stderr)
raise e

def _check_start(self):
assert not self._checked_start
self._checked_start = True
next_error = self._errorq.get(block=True)
if next_error is not None:
self._reraise(*next_error)

def next(self):
return self.__next__()


class _ThreadPrefetcher(_Prefetcher, threading.Thread):
    """Thread-backed prefetcher.

    Creates the three communication queues, launches itself as a daemon
    worker thread, and blocks until the worker reports a successful startup.
    """

    def __init__(self, *args, **kwargs):
        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
        # Bounded queues keep the worker from running arbitrarily far ahead.
        self._errorq = queue.Queue(self.num_prefetch)
        self._dataq = queue.Queue(self.num_prefetch)
        self._controlq = queue.Queue()
        # Daemon so an abandoned prefetcher never blocks interpreter exit.
        self.daemon = True
        self.start()
        self._check_start()

class _PrefetchingLoader(object):
"""Prefetcher for a Loader in a separate Thread or Process.
This iterator will create another thread or process to perform
``iter_next`` and then store the data in memory. It potentially accelerates
the data read, at the cost of more memory usage.
Parameters
----------
loader : an iterator
Source loader.
num_prefetch : int, default 1
Number of elements to prefetch from the loader. Must be greater 0.
"""

def __init__(self, loader, num_prefetch=1):
self._loader = loader
self._num_prefetch = num_prefetch
if num_prefetch < 1:
raise ValueError('num_prefetch must be greater 0.')

def __iter__(self):
return _ThreadPrefetcher(self._loader, self._num_prefetch)

def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, max_subgraph_size=None,
return_seed_id=False):
return_seed_id=False, prefetch=False):
'''Create a sampler that samples neighborhood.
.. note:: This method currently only supports MXNet backend. Set
Expand Down Expand Up @@ -129,12 +250,18 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
GPU doesn't support very large subgraphs.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
The seed Ids are in the parent graph.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
Returns
-------
A subgraph iterator
The iterator returns a list of batched subgraphs and a dictionary of additional
information about the subgraphs.
'''
return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
if not prefetch:
return loader
else:
return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
11 changes: 11 additions & 0 deletions tests/mxnet/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def test_1neighbor_sampler():
assert subg.number_of_edges() <= 5
verify_subgraph(g, subg, seed_ids)

def test_prefetch_neighbor_sampler():
    """Single-seed neighbor sampling with prefetch=True must satisfy the
    same contract as the non-prefetched sampler."""
    g = generate_rand_graph(100)
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 1, 5, neighbor_type='in', num_workers=4,
        return_seed_id=True, prefetch=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        # One seed per batch (batch_size=1); at most 5 sampled in-neighbors,
        # so at most 6 nodes and 5 edges in the subgraph.
        assert len(seed_ids) == 1
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)

def test_10neighbor_sampler_all():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
Expand Down

0 comments on commit 3010740

Please sign in to comment.