From 30107407fd0e9289f0c363dfe598c885a2ef784e Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Sat, 15 Dec 2018 19:14:17 -0800
Subject: [PATCH] add prefetcher for neighbor sampler (#298)

---
 python/dgl/contrib/sampling/sampler.py | 149 ++++++++++++++++++++++++-
 tests/mxnet/test_sampler.py            |  11 ++
 2 files changed, 156 insertions(+), 4 deletions(-)

diff --git a/python/dgl/contrib/sampling/sampler.py b/python/dgl/contrib/sampling/sampler.py
index 5be17cf7f5cd..c4bf3ebeff38 100644
--- a/python/dgl/contrib/sampling/sampler.py
+++ b/python/dgl/contrib/sampling/sampler.py
@@ -1,10 +1,18 @@
 # This file contains subgraph samplers.
 
 import numpy as np
+import sys
+import threading
+import random
+import traceback
 
 from ... import utils
 from ...subgraph import DGLSubGraph
 from ... import backend as F
+try:
+    import Queue as queue
+except ImportError:
+    import queue
 
 __all__ = ['NeighborSampler']
 
@@ -77,10 +85,137 @@ def __next__(self):
             aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
         return self._subgraphs.pop(0), aux_infos
 
+class _Prefetcher(object):
+    """Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
+    or Process-based implementation."""
+    _dataq = None # Data queue transmits prefetched elements
+    _controlq = None # Control queue to instruct thread / process shutdown
+    _errorq = None # Error queue to transmit exceptions from worker to master
+
+    _checked_start = False # True once startup has been checked by _check_start
+
+    def __init__(self, loader, num_prefetch):
+        super(_Prefetcher, self).__init__()
+        self.loader = loader
+        assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
+        self.num_prefetch = num_prefetch
+
+    def run(self):
+        """Worker loop: pull items from the loader and push them to the queues.
+
+        Every fetch step puts exactly one entry on the error queue and one on
+        the data queue, so the consumer can read them in lockstep.
+        """
+        # Startup - Master waits for this
+        try:
+            loader_iter = iter(self.loader)
+            self._errorq.put(None)
+        except Exception as e: # pylint: disable=broad-except
+            tb = traceback.format_exc()
+            self._errorq.put((e, tb))
+            # No iterator to read from: stop instead of entering the loop
+            # below with `loader_iter` unbound.
+            return
+
+        while True:
+            try: # Check control queue
+                c = self._controlq.get(False)
+                if c is None:
+                    break
+                else:
+                    raise RuntimeError('Got unexpected control code {}'.format(repr(c)))
+            except queue.Empty:
+                pass
+            except RuntimeError as e:
+                tb = traceback.format_exc()
+                self._errorq.put((e, tb))
+                self._dataq.put(None)
+
+            try:
+                data = next(loader_iter)
+                error = None
+            except Exception as e: # pylint: disable=broad-except
+                tb = traceback.format_exc()
+                error = (e, tb)
+                data = None
+            finally:
+                self._errorq.put(error)
+                self._dataq.put(data)
+
+    def __next__(self):
+        """Return the next prefetched item, or re-raise the worker's error."""
+        next_item = self._dataq.get()
+        next_error = self._errorq.get()
+
+        if next_error is None:
+            return next_item
+        else:
+            # Any error ends the iteration, so tell the worker to shut down.
+            self._controlq.put(None)
+            if isinstance(next_error[0], StopIteration):
+                raise StopIteration
+            else:
+                return self._reraise(*next_error)
+
+    def _reraise(self, e, tb):
+        """Write the worker traceback ``tb`` to stderr, then raise ``e``."""
+        # sys.stderr.write keeps this valid without print_function on Python 2.
+        sys.stderr.write('Reraising exception from Prefetcher\n')
+        sys.stderr.write('{}\n'.format(tb))
+        raise e
+
+    def _check_start(self):
+        """Wait for the worker's startup report and re-raise a startup failure."""
+        assert not self._checked_start
+        self._checked_start = True
+        next_error = self._errorq.get(block=True)
+        if next_error is not None:
+            self._reraise(*next_error)
+
+    def next(self):
+        """Python 2 spelling of ``__next__``."""
+        return self.__next__()
+
+
+class _ThreadPrefetcher(_Prefetcher, threading.Thread):
+    """Internal threaded prefetcher."""
+
+    def __init__(self, *args, **kwargs):
+        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
+        self._dataq = queue.Queue(self.num_prefetch)
+        self._controlq = queue.Queue()
+        self._errorq = queue.Queue(self.num_prefetch)
+        self.daemon = True
+        self.start()
+        self._check_start()
+
+class _PrefetchingLoader(object):
+    """Prefetcher for a Loader in a separate Thread or Process.
+    This iterator will create another thread or process to perform
+    ``iter_next`` and then store the data in memory. It potentially accelerates
+    the data read, at the cost of more memory usage.
+
+    Parameters
+    ----------
+    loader : an iterator
+        Source loader.
+    num_prefetch : int, default 1
+        Number of elements to prefetch from the loader. Must be greater than 0.
+    """
+
+    def __init__(self, loader, num_prefetch=1):
+        self._loader = loader
+        self._num_prefetch = num_prefetch
+        if num_prefetch < 1:
+            raise ValueError('num_prefetch must be greater 0.')
+
+    def __iter__(self):
+        return _ThreadPrefetcher(self._loader, self._num_prefetch)
+
 def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
                     neighbor_type='in', node_prob=None, seed_nodes=None,
                     shuffle=False, num_workers=1, max_subgraph_size=None,
-                    return_seed_id=False):
+                    return_seed_id=False, prefetch=False):
     '''Create a sampler that samples neighborhood.
 
     .. note:: This method currently only supports MXNet backend. Set
@@ -129,12 +264,18 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
         GPU doesn't support very large subgraphs.
     return_seed_id: indicates whether to return seed ids along with the subgraphs.
         The seed Ids are in the parent graph.
-    
+    prefetch : bool, default False
+        Whether to prefetch the samples in the next batch.
+
     Returns
     -------
     A subgraph iterator
         The iterator returns a list of batched subgraphs
         and a dictionary of additional information about the subgraphs.
     '''
-    return NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
-                            seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
+    loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
+                              seed_nodes, shuffle, num_workers, max_subgraph_size, return_seed_id)
+    if not prefetch:
+        return loader
+    else:
+        return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
diff --git a/tests/mxnet/test_sampler.py b/tests/mxnet/test_sampler.py
index e88fc33631a4..bbff369491b8 100644
--- a/tests/mxnet/test_sampler.py
+++ b/tests/mxnet/test_sampler.py
@@ -61,6 +61,17 @@ def test_1neighbor_sampler():
         assert subg.number_of_edges() <= 5
         verify_subgraph(g, subg, seed_ids)
 
+def test_prefetch_neighbor_sampler():
+    g = generate_rand_graph(100)
+    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
+    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
+                                                          num_workers=4, return_seed_id=True, prefetch=True):
+        seed_ids = aux['seeds']
+        assert len(seed_ids) == 1
+        assert subg.number_of_nodes() <= 6
+        assert subg.number_of_edges() <= 5
+        verify_subgraph(g, subg, seed_ids)
+
 def test_10neighbor_sampler_all():
     g = generate_rand_graph(100)
     # In this case, NeighborSampling simply gets the neighborhood of a single vertex.