[Sampler] improve random shuffle performance in sampler (dmlc#228)
* fix.

* make it generic.

* add the API.

* fix.

* remove mxnet.
zheng-da authored Dec 4, 2018
1 parent 78269ce commit a5a35d1
Showing 6 changed files with 50 additions and 11 deletions.
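The commit replaces numpy-level shuffling of the sampler's seed nodes with a backend-generic rand_shuffle API: the abstract interface lands in python/dgl/backend/backend.py, and the MXNet, numpy, and PyTorch backends each implement it with a native primitive, avoiding the Python-level element access that np.random.shuffle incurs on framework tensors (presumably the slow path the commit title targets). Below is a minimal sketch of the dispatch pattern, using a hypothetical dict registry in place of DGL's actual backend module loader:

import numpy as np

def _numpy_rand_shuffle(arr):
    out = np.copy(arr)      # shuffle a copy; the caller's array is untouched
    np.random.shuffle(out)
    return out

# Hypothetical registry for illustration only; real DGL resolves the active
# backend module (mxnet / pytorch / numpy) at import time instead.
_BACKENDS = {'numpy': _numpy_rand_shuffle}

def rand_shuffle(arr, backend='numpy'):
    """Return a new array whose first dimension is randomly shuffled."""
    return _BACKENDS[backend](arr)

print(rand_shuffle(np.arange(10)))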
27 changes: 18 additions & 9 deletions examples/mxnet/sse/sse_batch.py
@@ -158,13 +158,12 @@ def copy_to_gpu(subg, ctx):
         subg.ndata[key] = frame[key].as_in_context(ctx)
 
 class CachedSubgraph(object):
-    def __init__(self, subg, seeds, subg_seeds):
+    def __init__(self, subg, seeds):
         # We can't cache the input subgraph because it contains node frames
         # and data frames.
         self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
                                     subg._graph)
         self.seeds = seeds
-        self.subg_seeds = subg_seeds
 
 class CachedSubgraphLoader(object):
     def __init__(self, loader, shuffle):
@@ -184,14 +183,13 @@ def __iter__(self):
     def __next__(self):
         if len(self._subgraphs) > 0:
             s = self._subgraphs.pop(0)
-            subg, seeds, subg_seeds = s.subg, s.seeds, s.subg_seeds
+            subg, seeds = s.subg, s.seeds
         elif self._gen_subgraph:
             subg, seeds = self._loader.__next__()
-            subg_seeds = subg.map_to_subgraph_nid(seeds)
         else:
             raise StopIteration
-        self._cached.append(CachedSubgraph(subg, seeds, subg_seeds))
-        return subg, seeds, subg_seeds
+        self._cached.append(CachedSubgraph(subg, seeds))
+        return subg, seeds
 
 def main(args, data):
     if isinstance(data.features, mx.nd.NDArray):
@@ -266,15 +264,16 @@ def main(args, data):
     sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
             neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
             shuffle=True)
-    sampler = CachedSubgraphLoader(sampler, shuffle=True)
+    if args.cache_subgraph:
+        sampler = CachedSubgraphLoader(sampler, shuffle=True)
     for epoch in range(args.n_epochs):
         t0 = time.time()
         train_loss = 0
         i = 0
         num_batches = len(train_vs) / args.batch_size
         start1 = time.time()
-        sampler.restart()
-        for subg, seeds, subg_seeds in sampler:
+        for subg, seeds in sampler:
+            subg_seeds = subg.map_to_subgraph_nid(seeds)
             subg.copy_from_parent()
 
             losses = []
@@ -308,6 +307,14 @@ def main(args, data):
             if i > num_batches / 3:
                 break
 
+        if args.cache_subgraph:
+            sampler.restart()
+        else:
+            sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
+                                                           neighbor_type='in',
+                                                           num_workers=args.num_parallel_subgraphs,
+                                                           seed_nodes=train_vs, shuffle=True)
+
         # prediction.
         logits = model_infer(g, eval_vs)
         eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
@@ -381,11 +388,13 @@ def __init__(self, csr, num_feats):
     parser.add_argument("--use-spmv", action="store_true",
                         help="use SpMV for faster speed.")
    parser.add_argument("--dgl", action="store_true")
+    parser.add_argument("--cache-subgraph", default=False, action="store_false")
     parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
                         help="the number of subgraphs to construct in parallel.")
     parser.add_argument("--neigh-expand", type=int, default=16,
                         help="the number of neighbors to sample.")
     args = parser.parse_args()
+    print("cache: " + str(args.cache_subgraph))
 
     # load and preprocess dataset
     if args.graph_file != '':
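One oddity in the final hunk: --cache-subgraph pairs default=False with action="store_false", so args.cache_subgraph evaluates to False whether or not the flag is passed, and the CachedSubgraphLoader branch above can never be enabled from the command line. For contrast, a sketch of the conventional opt-in flag (not what the commit ships):

import argparse

parser = argparse.ArgumentParser()
# store_true flips the False default to True when the flag is present.
parser.add_argument("--cache-subgraph", default=False, action="store_true",
                    help="cache sampled subgraphs across epochs.")

print(parser.parse_args([]).cache_subgraph)                    # False
print(parser.parse_args(["--cache-subgraph"]).cache_subgraph)  # True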
17 changes: 17 additions & 0 deletions python/dgl/backend/backend.py
@@ -681,6 +681,23 @@ def arange(start, stop):
     """
     pass
 
+def rand_shuffle(arr):
+    """Randomly shuffle the data along the first dimension of the array.
+
+    The shuffled data is stored in a new array.
+
+    Parameters
+    ----------
+    arr : Tensor
+        The data tensor
+
+    Returns
+    -------
+    Tensor
+        The result tensor
+    """
+    pass
+
 def zerocopy_to_dlpack(input):
     """Create a dlpack tensor that shares the input memory.
3 changes: 3 additions & 0 deletions python/dgl/backend/mxnet/tensor.py
@@ -179,6 +179,9 @@ def sort_1d(input):
 def arange(start, stop):
     return nd.arange(start, stop, dtype=np.int64)
 
+def rand_shuffle(arr):
+    return mx.nd.random.shuffle(arr)
+
 def zerocopy_to_dlpack(arr):
     return arr.to_dlpack_for_read()
 
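For reference, mx.nd.random.shuffle shuffles along the first axis and returns a new NDArray, with the work done inside MXNet's execution engine rather than element by element in Python. A small usage sketch:

import mxnet as mx

arr = mx.nd.arange(10)
shuffled = mx.nd.random.shuffle(arr)  # new NDArray, shuffled along axis 0
print(shuffled.asnumpy())             # arr itself is unchanged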
5 changes: 5 additions & 0 deletions python/dgl/backend/numpy/tensor.py
@@ -128,6 +128,11 @@ def sort_1d(input):
 def arange(start, stop):
     return np.arange(start, stop, dtype=np.int64)
 
+def rand_shuffle(arr):
+    copy = np.copy(arr)
+    np.random.shuffle(copy)
+    return copy
+
 # zerocopy_to_dlpack not enabled
 
 # zerocopy_from_dlpack not enabled
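The copy-then-shuffle pair keeps the caller's array intact, matching the out-of-place contract in backend.py; np.random.permutation would give the same behavior in a single call:

import numpy as np

arr = np.arange(10)
shuffled = np.random.permutation(arr)  # shuffled copy; arr is unchanged
print(arr, shuffled)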
4 changes: 4 additions & 0 deletions python/dgl/backend/pytorch/tensor.py
@@ -136,6 +136,10 @@ def sort_1d(input):
 def arange(start, stop):
     return th.arange(start, stop, dtype=th.int64)
 
+def rand_shuffle(arr):
+    idx = th.randperm(len(arr))
+    return arr[idx]
+
 def zerocopy_to_dlpack(input):
     return dlpack.to_dlpack(input.contiguous())
 
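The PyTorch version draws a random permutation of row indices and materializes the shuffled copy via advanced indexing, again leaving the input untouched. A standalone sketch:

import torch as th

arr = th.arange(10)
idx = th.randperm(len(arr))  # random permutation of 0..9
print(arr[idx])              # advanced indexing returns a new, shuffled tensor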
5 changes: 3 additions & 2 deletions python/dgl/contrib/sampling/sampler.py
@@ -4,6 +4,7 @@
 
 from ... import utils
 from ...subgraph import DGLSubGraph
+from ... import backend as F
 
 __all__ = ['NeighborSampler']
 
@@ -22,11 +23,11 @@ def __init__(self, g, batch_size, expand_factor, num_hops=1,
         assert self._node_prob.shape[0] == g.number_of_nodes(), \
             "We need to know the sampling probability of every node"
         if seed_nodes is None:
-            self._seed_nodes = np.arange(0, g.number_of_nodes(), dtype=np.int64)
+            self._seed_nodes = F.arange(0, g.number_of_nodes())
         else:
             self._seed_nodes = seed_nodes
         if shuffle:
-            np.random.shuffle(self._seed_nodes)
+            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
         self._num_workers = num_workers
         if max_subgraph_size is None:
             # This size is set temporarily.
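Besides routing through the backend, the second hunk changes semantics: np.random.shuffle mutated self._seed_nodes in place, which could silently reorder a user-supplied seed_nodes array, whereas F.rand_shuffle returns a fresh shuffled array that is reassigned. A minimal numpy illustration of the difference:

import numpy as np

seeds = np.arange(5)
np.random.shuffle(seeds)                 # old path: the caller's array is reordered in place

seeds = np.arange(5)
shuffled = np.random.permutation(seeds)  # out-of-place, like F.rand_shuffle
assert (seeds == np.arange(5)).all()     # the caller's array is untouched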
