[Feature] Enable UVA for Weighted Samplers (dmlc#4314)
* enable UVA for weighted neighbor sampler and biased random walk

* add unit tests

* fix for mxnet/tf

* fix typo
yaox12 authored Aug 1, 2022
1 parent 9a16a5e commit 44b6864
Showing 6 changed files with 137 additions and 84 deletions.
21 changes: 11 additions & 10 deletions python/dgl/sampling/randomwalks.py
@@ -30,7 +30,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
If a random walk stops in advance, DGL pads the trace with -1 to have the same
length.
This function supports the graph on GPU.
This function supports the graph on GPU and UVA sampling.
Parameters
----------
@@ -39,8 +39,9 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
nodes : Tensor
Node ID tensor from which the random walk traces start.
The tensor must be on the same device as the graph and have the same dtype as the ID type
of the graph.
The tensor must have the same dtype as the ID type of the graph.
The tensor must be on the same device as the graph or
on the GPU when the graph is pinned (UVA sampling).
metapath : list[str or tuple of str], optional
Metapath, specified as a list of edge types.
@@ -69,6 +70,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
Probability to terminate the current trace before each transition.
If a tensor is given, :attr:`restart_prob` should be on the same device as the graph
or on the GPU when the graph is pinned (UVA sampling),
and have the same length as :attr:`metapath` or :attr:`length`.
return_eids : bool, optional
If True, additionally return the edge IDs traversed.
@@ -180,19 +182,16 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
metapath = F.to_dgl_nd(F.astype(F.tensor(metapath), g.idtype))

# Load the probability tensor from the edge frames
ctx = utils.to_dgl_context(g.device)
if prob is None:
p_nd = [nd.array([], ctx=nodes.ctx) for _ in g.canonical_etypes]
p_nd = [nd.array([], ctx=ctx) for _ in g.canonical_etypes]
else:
p_nd = []
for etype in g.canonical_etypes:
if prob in g.edges[etype].data:
prob_nd = F.to_dgl_nd(g.edges[etype].data[prob])
if prob_nd.ctx != nodes.ctx:
raise ValueError(
'context of seed node array and edges[%s].data[%s] are different' %
(etype, prob))
else:
prob_nd = nd.array([], ctx=nodes.ctx)
prob_nd = nd.array([], ctx=ctx)
p_nd.append(prob_nd)

# Actual random walk
@@ -202,9 +201,11 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
restart_prob = F.to_dgl_nd(restart_prob)
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
gidx, nodes, metapath, p_nd, restart_prob)
else:
elif isinstance(restart_prob, float):
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob)
else:
raise TypeError("restart_prob should be float or Tensor.")

traces = F.from_dgl_nd(traces)
types = F.from_dgl_nd(types)
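With this change, a biased random walk can run in UVA mode: the graph stays in pinned host memory while the seed nodes live on the GPU. A minimal sketch of the new usage, assuming a CUDA device and the PyTorch backend (mirroring the pattern in the updated tests below):

# Minimal sketch: UVA biased random walk (assumes CUDA + PyTorch backend).
import dgl
import torch

g = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]))
g.edata['p'] = torch.tensor([3., 0., 3., 3., 3.])  # weights stay on the CPU graph
g.create_formats_()   # materialize sparse formats before pinning
g.pin_memory_()       # pin the graph for zero-copy access from the GPU

seeds = torch.tensor([0, 1, 2, 3], device='cuda')  # seeds on the GPU
traces, eids, types = dgl.sampling.random_walk(
    g, seeds, length=4, prob='p', return_eids=True)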
3 changes: 2 additions & 1 deletion src/array/array.cc
@@ -553,7 +553,8 @@ COOMatrix CSRRowWiseSampling(
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
});
} else {
CHECK_SAME_CONTEXT(rows, prob);
// A pinned prob array with rows on the GPU is a valid combination
CHECK_VALID_CONTEXT(prob, rows);
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
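At the Python level, this relaxed context check is what enables weighted neighbor sampling over a pinned graph with GPU seed nodes: the prob array may be pinned while the rows live on the device. A hedged sketch of the usage this unlocks (assumed per the commit's intent; not shown verbatim in the diff):

# Sketch: weighted (non-uniform) neighbor sampling with a pinned graph.
import dgl
import torch

g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]))
g.edata['p'] = torch.tensor([1., 1., 0., 0., 1., 1., 0., 0.])
g.create_formats_()
g.pin_memory_()                               # prob 'p' is pinned with the graph
seeds = torch.tensor([0, 1], device='cuda')   # rows live on the GPU
sg = dgl.sampling.sample_neighbors(g, seeds, 2, prob='p')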
3 changes: 0 additions & 3 deletions src/graph/sampling/randomwalks/randomwalk_gpu.cu
@@ -400,7 +400,6 @@ std::pair<IdArray, IdArray> RandomWalk(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
});
return ret;
@@ -442,7 +441,6 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(
hg, seeds, metapath, prob, restart_prob_array);
});
@@ -471,7 +469,6 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
});
return ret;
7 changes: 6 additions & 1 deletion src/graph/sampling/randomwalks/randomwalks.cc
@@ -44,9 +44,14 @@ void CheckRandomWalkInputs(
}
for (uint64_t i = 0; i < prob.size(); ++i) {
FloatArray p = prob[i];
CHECK_EQ(hg->Context(), p->ctx) << "Expected prob (" << p->ctx << ")" << " to have the same " \
<< "context as graph (" << hg->Context() << ").";
CHECK_FLOAT(p, "probability");
if (p.GetSize() != 0)
if (p.GetSize() != 0) {
CHECK_EQ(hg->IsPinned(), p.IsPinned())
<< "The prob array should have the same pinning status as the graph";
CHECK_NDIM(p, 1, "probability");
}
}
}

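The new check requires each non-empty prob array to share the graph's pinning status. Following the updated tests, the way to satisfy it is to keep the edge weights on the CPU graph before pinning, so that pin_memory_() pins the edge data together with the structure (an assumption this sketch relies on):

# Sketch: keeping weights on the CPU graph so pinning covers both.
import dgl
import torch

g = dgl.heterograph({
    ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])})
g.edata['p'] = torch.tensor([3., 0., 3., 3., 3.])  # on CPU, same as g
g.create_formats_()
g.pin_memory_()        # structure and edge frames are now both pinned
assert g.is_pinned()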
98 changes: 58 additions & 40 deletions tests/compute/test_sampling.py
@@ -24,58 +24,74 @@ def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
u, v = g.find_edges(trace_eids[i, j], etype=metapath[j])
assert (u == traces[i, j]) and (v == traces[i, j + 1])

def test_non_uniform_random_walk():
@pytest.mark.parametrize('use_uva', [True, False])
def test_non_uniform_random_walk(use_uva):
if use_uva:
if F.ctx() == F.cpu():
pytest.skip('UVA biased random walk requires a GPU.')
if dgl.backend.backend_name != 'pytorch':
pytest.skip('UVA biased random walk is only supported with PyTorch.')
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])
}).to(F.ctx())
})
g4 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])
}).to(F.ctx())
})

g2.edata['p'] = F.tensor([3, 0, 3, 3, 3], dtype=F.float32)
g2.edata['p2'] = F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32)
g4.edges['follow'].data['p'] = F.tensor([3, 0, 3, 3, 3], dtype=F.float32)
g4.edges['viewed-by'].data['p'] = F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32)
g2.edata['p'] = F.copy_to(F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu())
g2.edata['p2'] = F.copy_to(F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32), F.cpu())
g4.edges['follow'].data['p'] = F.copy_to(F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu())
g4.edges['viewed-by'].data['p'] = F.copy_to(F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32), F.cpu())

traces, eids, ntypes = dgl.sampling.random_walk(
g2, [0, 1, 2, 3, 0, 1, 2, 3], length=4, prob='p', return_eids=True)
check_random_walk(g2, ['follow'] * 4, traces, ntypes, 'p', trace_eids=eids)
if use_uva:
for g in (g2, g4):
g.create_formats_()
g.pin_memory_()
elif F._default_context_str == 'gpu':
g2 = g2.to(F.ctx())
g4 = g4.to(F.ctx())

try:
traces, ntypes = dgl.sampling.random_walk(
g2, [0, 1, 2, 3, 0, 1, 2, 3], length=4, prob='p2')
fail = False
except dgl.DGLError:
fail = True
assert fail
traces, eids, ntypes = dgl.sampling.random_walk(
g2, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
length=4, prob='p', return_eids=True)
check_random_walk(g2, ['follow'] * 4, traces, ntypes, 'p', trace_eids=eids)

metapath = ['follow', 'view', 'viewed-by'] * 2
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p', return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p', restart_prob=0., return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p',
restart_prob=F.zeros((6,), F.float32, F.ctx()), return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath + ['follow'], prob='p',
restart_prob=F.tensor([0, 0, 0, 0, 0, 0, 1], F.float32), return_eids=True)
check_random_walk(g4, metapath, traces[:, :7], ntypes[:7], 'p', trace_eids=eids)
assert (F.asnumpy(traces[:, 7]) == -1).all()

def _use_uva():
if F._default_context_str == 'cpu':
return [False]
else:
return [True, False]
with pytest.raises(dgl.DGLError):
traces, ntypes = dgl.sampling.random_walk(
g2, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
length=4, prob='p2')

metapath = ['follow', 'view', 'viewed-by'] * 2
traces, eids, ntypes = dgl.sampling.random_walk(
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p', return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p', restart_prob=0., return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p',
restart_prob=F.zeros((6,), F.float32, F.ctx()), return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath + ['follow'], prob='p',
restart_prob=F.tensor([0, 0, 0, 0, 0, 0, 1], F.float32), return_eids=True)
check_random_walk(g4, metapath, traces[:, :7], ntypes[:7], 'p', trace_eids=eids)
assert (F.asnumpy(traces[:, 7]) == -1).all()
finally:
for g in (g2, g4):
g.unpin_memory_()

@pytest.mark.parametrize('use_uva', _use_uva())
@pytest.mark.parametrize('use_uva', [True, False])
def test_uniform_random_walk(use_uva):
if use_uva and F.ctx() == F.cpu():
pytest.skip('UVA random walk requires a GPU.')
g1 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 2], [1, 2, 0])
})
@@ -178,8 +194,10 @@ def test_pack_traces():
assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64))
assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64))

@pytest.mark.parametrize('use_uva', _use_uva())
@pytest.mark.parametrize('use_uva', [True, False])
def test_pinsage_sampling(use_uva):
if use_uva and F.ctx() == F.cpu():
pytest.skip('UVA sampling requires a GPU.')
def _test_sampler(g, sampler, ntype):
seeds = F.copy_to(F.tensor([0, 2], dtype=g.idtype), F.ctx())
neighbor_g = sampler(seeds)
89 changes: 60 additions & 29 deletions tests/pytorch/test_dataloader.py
@@ -5,6 +5,7 @@
import backend as F
import unittest
import torch
import torch.distributed as dist
from functools import partial
from torch.utils.data import DataLoader
from collections import defaultdict
@@ -70,43 +71,73 @@ def test_saint(num_workers, mode):
for sg in dataloader:
pass

@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]))
@parametrize_idtype
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp):
if mode != 'cpu' and F.ctx() == F.cpu():
pytest.skip('UVA and GPU sampling require a GPU.')
if use_ddp:
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0)
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
if mode in ('cpu', 'uva_cpu_indices'):
indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
else:
indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
if mode == 'pure_gpu':
g = g.to(F.cuda())
use_uva = mode.startswith('uva')

sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
dataloader = dgl.dataloading.NodeDataLoader(g, [0, 1], sampler, batch_size=1, device=F.ctx())
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes.item()
neighbors = set(input_nodes[1:].cpu().numpy())
if seed == 1:
assert neighbors == {5, 6}
elif seed == 0:
assert neighbors == {1, 2}
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
dataloader = dgl.dataloading.NodeDataLoader(
g, indices, sampler,
batch_size=1, device=F.ctx(),
num_workers=num_workers,
use_uva=use_uva,
use_ddp=use_ddp)
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes.item()
neighbors = set(input_nodes[1:].cpu().numpy())
if seed == 1:
assert neighbors == {5, 6}
elif seed == 0:
assert neighbors == {1, 2}

g = dgl.heterograph({
('B', 'BA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
})
}).astype(idtype)
g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
dataloader = dgl.dataloading.NodeDataLoader(
g, {'A': [0, 1]}, sampler, batch_size=1, device=F.ctx())
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes['A'].item()
# Seed and neighbors are of different node types so slicing is not necessary here.
neighbors = set(input_nodes['B'].cpu().numpy())
if seed == 1:
assert neighbors == {5, 6}
elif seed == 0:
assert neighbors == {1, 2}

neighbors = set(input_nodes['C'].cpu().numpy())
if seed == 1:
assert neighbors == {7, 8}
elif seed == 0:
assert neighbors == {3, 4}
if mode == 'pure_gpu':
g = g.to(F.cuda())
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
dataloader = dgl.dataloading.NodeDataLoader(
g, {'A': indices}, sampler,
batch_size=1, device=F.ctx(),
num_workers=num_workers,
use_uva=use_uva,
use_ddp=use_ddp)
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes['A'].item()
# Seed and neighbors are of different node types so slicing is not necessary here.
neighbors = set(input_nodes['B'].cpu().numpy())
if seed == 1:
assert neighbors == {5, 6}
elif seed == 0:
assert neighbors == {1, 2}

neighbors = set(input_nodes['C'].cpu().numpy())
if seed == 1:
assert neighbors == {7, 8}
elif seed == 0:
assert neighbors == {3, 4}

if use_ddp:
dist.destroy_process_group()

def _check_dtype(data, dtype, attr_name):
if isinstance(data, dict):
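The dataloader tests cover the end-to-end path: weighted neighbor sampling driven by NodeDataLoader with use_uva=True. A condensed sketch of that usage (assuming, as the test does, that the loader handles pinning itself when given an unpinned CPU graph):

# Sketch: weighted neighbor sampling through NodeDataLoader with UVA.
import dgl
import torch

g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]))
g.edata['p'] = torch.tensor([1., 1., 0., 0., 1., 1., 0., 0.])
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
dataloader = dgl.dataloading.NodeDataLoader(
    g, torch.tensor([0, 1], device='cuda'), sampler,
    batch_size=1, device='cuda', use_uva=True)
for input_nodes, output_nodes, blocks in dataloader:
    pass  # neighbors of each seed are drawn in proportion to 'p'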
