Skip to content

Commit

Permalink
[Distributed] Fix a bug in sampling an empty frontier (dmlc#3298)
Browse files Browse the repository at this point in the history
* handle empty frontiers.

* fix lint.

* fix

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: xiang song(charlie.song) <[email protected]>
  • Loading branch information
3 people authored Aug 29, 2021
1 parent 152760f commit e5ed7ad
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 31 deletions.
9 changes: 8 additions & 1 deletion python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A set of graph services of getting subgraphs from DistGraph"""
from collections import namedtuple
import numpy as np

from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
Expand Down Expand Up @@ -383,12 +384,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
return sampled_graph

def _frontier_to_heterogeneous_graph(g, frontier, gpb):
# We need to handle empty frontiers correctly.
if frontier.number_of_edges() == 0:
data_dict = {etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes}
return heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)

etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)

Expand Down
131 changes: 101 additions & 30 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,24 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
assert F.array_equal(u, du)
assert F.array_equal(v, dv)

def create_random_hetero():
num_nodes = {'n1': 10000, 'n2': 10010, 'n3': 10020}
def create_random_hetero(dense=False, empty=False):
num_nodes = {'n1': 210, 'n2': 200, 'n3': 220} if dense else \
{'n1': 1010, 'n2': 1000, 'n3': 1020}
etypes = [('n1', 'r1', 'n2'),
('n1', 'r2', 'n3'),
('n2', 'r3', 'n3')]
edges = {}
random.seed(42)
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(num_nodes[src_ntype], num_nodes[dst_ntype], density=0.001, format='coo',
random_state=100)
arr = spsp.random(num_nodes[src_ntype] - 10 if empty else num_nodes[src_ntype],
num_nodes[dst_ntype] - 10 if empty else num_nodes[dst_ntype],
density=0.1 if dense else 0.001,
format='coo', random_state=100)
edges[etype] = (arr.row, arr.col)
return dgl.heterograph(edges, num_nodes)
g = dgl.heterograph(edges, num_nodes)
g.nodes['n1'].data['feat'] = F.ones((g.number_of_nodes('n1'), 10), F.float32, F.cpu())
return g

def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w")
Expand Down Expand Up @@ -298,24 +304,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

def create_random_hetero(dense=False):
num_nodes = {'n1': 210, 'n2': 200, 'n3': 220} if dense else \
{'n1': 1010, 'n2': 1000, 'n3': 1020}
etypes = [('n1', 'r1', 'n2'),
('n1', 'r2', 'n3'),
('n2', 'r3', 'n3')]
edges = {}
random.seed(42)
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(num_nodes[src_ntype], num_nodes[dst_ntype], density=0.1 if dense else 0.001, format='coo',
random_state=100)
edges[etype] = (arr.row, arr.col)
g = dgl.heterograph(edges, num_nodes)
g.nodes['n1'].data['feat'] = F.ones((g.number_of_nodes('n1'), 10), F.float32, F.cpu())
return g

def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
Expand All @@ -327,7 +316,6 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
nodes = {'n3': [0, 10, 99, 66, 124, 208]}
sampled_graph = sample_neighbors(dist_graph, nodes, 3)
block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
Expand All @@ -337,7 +325,8 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
dgl.distributed.exit_client()
return block, gpb

def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3):
def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
nodes={'n3': [0, 10, 99, 66, 124, 208]}):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
Expand All @@ -349,7 +338,6 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3)
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
nodes = {'n3': [0, 10, 99, 66, 124, 208]}
sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout)
block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
Expand Down Expand Up @@ -381,7 +369,8 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
pserver_list.append(p)

time.sleep(3)
block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1)
block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,
nodes = {'n3': [0, 10, 99, 66, 124, 208]})
print("Done sampling")
for p in pserver_list:
p.join()
Expand Down Expand Up @@ -417,6 +406,49 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst)

def get_degrees(g, nids, ntype):
deg = F.zeros((len(nids),), dtype=F.int64)
for srctype, etype, dsttype in g.canonical_etypes:
if srctype == ntype:
deg += g.out_degrees(u=nids, etype=etype)
elif dsttype == ntype:
deg += g.in_degrees(v=nids, etype=etype)
return deg

def check_rpc_hetero_sampling_empty_shuffle(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close()

g = create_random_hetero(empty=True)
num_parts = num_server
num_hops = 1

orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis',
reshuffle=True, return_mapping=True)

pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p.start()
time.sleep(1)
pserver_list.append(p)

time.sleep(3)
deg = get_degrees(g, orig_nids['n3'], 'n3')
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_sample_client(0, tmpdir, num_server > 1,
nodes = {'n3': empty_nids})
print("Done sampling")
for p in pserver_list:
p.join()

assert block.number_of_edges() == 0
assert len(block.etypes) == len(g.etypes)

def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
Expand All @@ -439,7 +471,8 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):

time.sleep(3)
fanout = 3
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout)
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
nodes={'n3': [0, 10, 99, 66, 124, 208]})
print("Done sampling")
for p in pserver_list:
p.join()
Expand Down Expand Up @@ -480,6 +513,40 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst)

def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
ip_config = open("rpc_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close()
g = create_random_hetero(dense=True, empty=True)
num_parts = num_server
num_hops = 1

orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis',
reshuffle=True, return_mapping=True)

pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_sampling'))
p.start()
time.sleep(1)
pserver_list.append(p)

time.sleep(3)
fanout = 3
deg = get_degrees(g, orig_nids['n3'], 'n3')
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_etype_sample_client(0, tmpdir, num_server > 1, fanout,
nodes={'n3': empty_nids})
print("Done sampling")
for p in pserver_list:
p.join()

assert block.number_of_edges() == 0
assert len(block.etypes) == len(g.etypes)

# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
Expand All @@ -491,7 +558,9 @@ def test_rpc_sampling_shuffle(num_server):
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server)

def check_standalone_sampling(tmpdir, reshuffle):
g = CitationGraphDataset("cora")[0]
Expand Down Expand Up @@ -654,8 +723,6 @@ def test_standalone_etype_sampling():
os.environ['DGL_DIST_MODE'] = 'standalone'
check_standalone_etype_sampling_heterograph(Path(tmpdirname), True)

test_rpc_sampling_shuffle(1)
test_rpc_sampling_shuffle(2)
with tempfile.TemporaryDirectory() as tmpdirname:
os.environ['DGL_DIST_MODE'] = 'standalone'
check_standalone_etype_sampling(Path(tmpdirname), True)
Expand All @@ -676,3 +743,7 @@ def test_standalone_etype_sampling():
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)

0 comments on commit e5ed7ad

Please sign in to comment.