Skip to content

Commit

Permalink
[Fix] sleep for a while when launching clients which will connect to multiple servers (dmlc#3704)
Browse files Browse the repository at this point in the history

* [Fix] sleep for a while when launching clients which will connect to multiple servers

* pre-allocate more ports

* no multiple partitions on single machine
  • Loading branch information
Rhett-Ying authored Jan 30, 2022
1 parent 701b4fc commit 9c8c162
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
18 changes: 9 additions & 9 deletions tests/distributed/test_dist_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_client_empty(graph_name, part_id, server_count, num_clients, num_nodes,
check_dist_graph_empty(g, num_clients, num_nodes, num_edges)

def check_server_client_empty(shared_mem, num_servers, num_clients):
prepare_dist()
prepare_dist(num_servers)
g = create_random_graph(10000)

# Partition the graph
Expand Down Expand Up @@ -281,7 +281,7 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
print('end')

def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist()
prepare_dist(num_servers)
g = create_random_graph(10000)

# Partition the graph
Expand Down Expand Up @@ -311,6 +311,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group
g.number_of_edges(),
group_id))
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
cli_ps.append(p)

for p in cli_ps:
Expand All @@ -328,7 +329,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients, num_group
print('clients have terminated')

def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
prepare_dist()
prepare_dist(num_servers)
g = create_random_graph(10000)

# Partition the graph
Expand Down Expand Up @@ -357,6 +358,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
p = ctx.Process(target=run_client, args=(graph_name, 0, num_servers, num_clients, g.number_of_nodes(),
g.number_of_edges(), group_id))
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
cli_ps.append(p)
for p in cli_ps:
p.join()
Expand All @@ -372,7 +374,7 @@ def check_server_client(shared_mem, num_servers, num_clients, num_groups=1):
print('clients have terminated')

def check_server_client_hierarchy(shared_mem, num_servers, num_clients):
prepare_dist()
prepare_dist(num_servers)
g = create_random_graph(10000)

# Partition the graph
Expand Down Expand Up @@ -535,7 +537,7 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
print('end')

def check_server_client_hetero(shared_mem, num_servers, num_clients):
prepare_dist()
prepare_dist(num_servers)
g = create_random_hetero()

# Partition the graph
Expand Down Expand Up @@ -641,7 +643,6 @@ def test_standalone_node_emb():

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split():
#prepare_dist()
g = create_random_graph(10000)
num_parts = 4
num_hops = 2
Expand Down Expand Up @@ -696,7 +697,6 @@ def set_roles(num_clients):

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split_even():
#prepare_dist(1)
g = create_random_graph(10000)
num_parts = 4
num_hops = 2
Expand Down Expand Up @@ -763,8 +763,8 @@ def set_roles(num_clients):
assert np.all(all_nodes == F.asnumpy(all_nodes2))
assert np.all(all_edges == F.asnumpy(all_edges2))

def prepare_dist():
generate_ip_config("kv_ip_config.txt", 1, 1)
def prepare_dist(num_servers=1):
generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)

if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True)
Expand Down
3 changes: 2 additions & 1 deletion tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
for group_id in range(num_groups):
p = ctx.Process(target=start_sample_client_shuffle, args=(client_id, tmpdir, num_server > 1, g, num_server, group_id))
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
pclient_list.append(p)
for p in pclient_list:
p.join()
Expand Down Expand Up @@ -563,7 +564,7 @@ def test_rpc_sampling_shuffle(num_server):
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=5)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2)
check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
Expand Down
5 changes: 5 additions & 0 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def check_neg_dataloader(g, tmpdir, num_server, num_workers):
@pytest.mark.parametrize("num_groups", [1, 2])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups):
reset_envs()
# No multiple partitions on single machine for
# multiple client groups in case of race condition.
if num_groups > 1:
num_server = 1
generate_ip_config("mp_ip_config.txt", num_server, num_server)

g = CitationGraphDataset("cora")[0]
Expand Down Expand Up @@ -246,6 +250,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle,
p = ctx.Process(target=start_dist_dataloader, args=(
trainer_id, tmpdir, num_server, drop_last, orig_nid, orig_eid, group_id))
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
ptrainer_list.append(p)

for p in ptrainer_list:
Expand Down

0 comments on commit 9c8c162

Please sign in to comment.