diff --git a/examples/mxnet/sse/sse_batch.py b/examples/mxnet/sse/sse_batch.py index bc5084c018ec..bdc39ee6bb7f 100644 --- a/examples/mxnet/sse/sse_batch.py +++ b/examples/mxnet/sse/sse_batch.py @@ -158,13 +158,12 @@ def copy_to_gpu(subg, ctx): subg.ndata[key] = frame[key].as_in_context(ctx) class CachedSubgraph(object): - def __init__(self, subg, seeds, subg_seeds): + def __init__(self, subg, seeds): # We can't cache the input subgraph because it contains node frames # and data frames. self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid, subg._graph) self.seeds = seeds - self.subg_seeds = subg_seeds class CachedSubgraphLoader(object): def __init__(self, loader, shuffle): @@ -184,14 +183,13 @@ def __iter__(self): def __next__(self): if len(self._subgraphs) > 0: s = self._subgraphs.pop(0) - subg, seeds, subg_seeds = s.subg, s.seeds, s.subg_seeds + subg, seeds = s.subg, s.seeds elif self._gen_subgraph: subg, seeds = self._loader.__next__() - subg_seeds = subg.map_to_subgraph_nid(seeds) else: raise StopIteration - self._cached.append(CachedSubgraph(subg, seeds, subg_seeds)) - return subg, seeds, subg_seeds + self._cached.append(CachedSubgraph(subg, seeds)) + return subg, seeds def main(args, data): if isinstance(data.features, mx.nd.NDArray): @@ -266,15 +264,16 @@ def main(args, data): sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand, neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs, shuffle=True) - sampler = CachedSubgraphLoader(sampler, shuffle=True) + if args.cache_subgraph: + sampler = CachedSubgraphLoader(sampler, shuffle=True) for epoch in range(args.n_epochs): t0 = time.time() train_loss = 0 i = 0 num_batches = len(train_vs) / args.batch_size start1 = time.time() - sampler.restart() - for subg, seeds, subg_seeds in sampler: + for subg, seeds in sampler: + subg_seeds = subg.map_to_subgraph_nid(seeds) subg.copy_from_parent() losses = [] @@ -308,6 +307,14 @@ def main(args, data): if i > num_batches / 3: break + if args.cache_subgraph: + sampler.restart() + else: + sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand, + neighbor_type='in', + num_workers=args.num_parallel_subgraphs, + seed_nodes=train_vs, shuffle=True) + # prediction. logits = model_infer(g, eval_vs) eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels) @@ -381,11 +388,13 @@ def __init__(self, csr, num_feats): parser.add_argument("--use-spmv", action="store_true", help="use SpMV for faster speed.") parser.add_argument("--dgl", action="store_true") + parser.add_argument("--cache-subgraph", default=False, action="store_false") parser.add_argument("--num-parallel-subgraphs", type=int, default=1, help="the number of subgraphs to construct in parallel.") parser.add_argument("--neigh-expand", type=int, default=16, help="the number of neighbors to sample.") args = parser.parse_args() + print("cache: " + str(args.cache_subgraph)) # load and preprocess dataset if args.graph_file != '': diff --git a/python/dgl/backend/backend.py b/python/dgl/backend/backend.py index 117eb2b964e8..f6fa4fb8c93a 100644 --- a/python/dgl/backend/backend.py +++ b/python/dgl/backend/backend.py @@ -681,6 +681,23 @@ def arange(start, stop): """ pass +def rand_shuffle(arr): + """Random shuffle the data in the first dimension of the array. + + The shuffled data is stored in a new array. + + Parameters + ---------- + arr : Tensor + The data tensor + + Returns + ------- + Tensor + The result tensor + """ + pass + def zerocopy_to_dlpack(input): """Create a dlpack tensor that shares the input memory. diff --git a/python/dgl/backend/mxnet/tensor.py b/python/dgl/backend/mxnet/tensor.py index e8775c0d81cc..dd8f281ce6ba 100644 --- a/python/dgl/backend/mxnet/tensor.py +++ b/python/dgl/backend/mxnet/tensor.py @@ -179,6 +179,9 @@ def sort_1d(input): def arange(start, stop): return nd.arange(start, stop, dtype=np.int64) +def rand_shuffle(arr): + return mx.nd.random.shuffle(arr) + def zerocopy_to_dlpack(arr): return arr.to_dlpack_for_read() diff --git a/python/dgl/backend/numpy/tensor.py b/python/dgl/backend/numpy/tensor.py index 0727363be48b..aab7ddf4db43 100644 --- a/python/dgl/backend/numpy/tensor.py +++ b/python/dgl/backend/numpy/tensor.py @@ -128,6 +128,11 @@ def sort_1d(input): def arange(start, stop): return np.arange(start, stop, dtype=np.int64) +def rand_shuffle(arr): + copy = np.copy(arr) + np.random.shuffle(copy) + return copy + # zerocopy_to_dlpack not enabled # zerocopy_from_dlpack not enabled diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py index d79aef309d33..907beba8c60a 100644 --- a/python/dgl/backend/pytorch/tensor.py +++ b/python/dgl/backend/pytorch/tensor.py @@ -136,6 +136,10 @@ def sort_1d(input): def arange(start, stop): return th.arange(start, stop, dtype=th.int64) +def rand_shuffle(arr): + idx = th.randperm(len(arr)) + return arr[idx] + def zerocopy_to_dlpack(input): return dlpack.to_dlpack(input.contiguous()) diff --git a/python/dgl/contrib/sampling/sampler.py b/python/dgl/contrib/sampling/sampler.py index 7b02bbff1cb3..804b37d6457a 100644 --- a/python/dgl/contrib/sampling/sampler.py +++ b/python/dgl/contrib/sampling/sampler.py @@ -4,6 +4,7 @@ from ... import utils from ...subgraph import DGLSubGraph +from ... import backend as F __all__ = ['NeighborSampler'] @@ -22,11 +23,11 @@ def __init__(self, g, batch_size, expand_factor, num_hops=1, assert self._node_prob.shape[0] == g.number_of_nodes(), \ "We need to know the sampling probability of every node" if seed_nodes is None: - self._seed_nodes = np.arange(0, g.number_of_nodes(), dtype=np.int64) + self._seed_nodes = F.arange(0, g.number_of_nodes()) else: self._seed_nodes = seed_nodes if shuffle: - np.random.shuffle(self._seed_nodes) + self._seed_nodes = F.rand_shuffle(self._seed_nodes) self._num_workers = num_workers if max_subgraph_size is None: # This size is set temporarily.