Skip to content

Commit

Permalink
[RPC] Rpc exit with explicit invocation (dmlc#1825)
Browse files Browse the repository at this point in the history
* exit client

* update

* update

* update

* update

* update

* update

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* update
  • Loading branch information
aksnzhy authored Jul 20, 2020
1 parent 6455781 commit 5c92f6c
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 18 deletions.
2 changes: 0 additions & 2 deletions examples/pytorch/graphsage/experimental/train_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def run(args, device, data):
# clean up
if not args.standalone:
g._client.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()

def main(args):
if not args.standalone:
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
from .rpc_client import connect_to_server, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .graph_services import sample_neighbors, in_subgraph
Expand Down
1 change: 0 additions & 1 deletion python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,6 @@ def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
self._num_nodes += int(part_md['num_nodes'])
self._num_edges += int(part_md['num_edges'])


def init_ndata(self, name, shape, dtype, init_func=None):
'''Initialize node data
Expand Down
10 changes: 10 additions & 0 deletions python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import socket
import atexit

from . import rpc
from .constants import MAX_QUEUE_SIZE
Expand Down Expand Up @@ -169,6 +170,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
atexit.register(exit_client)

def finalize_client():
"""Release resources of this client."""
Expand All @@ -186,3 +188,11 @@ def shutdown_servers():
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)

def exit_client():
    """Exit the client explicitly: shut down servers and release client resources.

    Sends shutdown requests to the servers (via ``shutdown_servers``; per the
    note below, presumably only the rank-0 client actually sends them — confirm
    in ``shutdown_servers``), then releases this client's resources. Finally
    unregisters itself from ``atexit`` so the cleanup does not run a second
    time at interpreter exit after an explicit call (it is registered as an
    ``atexit`` callback in ``connect_to_server``).
    """
    # Only client with rank_0 will send shutdown request to servers.
    shutdown_servers()
    finalize_client()
    atexit.unregister(exit_client)
4 changes: 0 additions & 4 deletions tests/distributed/test_dist_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ def check_dist_graph(g, num_nodes, num_edges):
for n in nodes:
assert n in local_nids

# clean up
if os.environ['DGL_DIST_MODE'] == 'distributed':
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print('end')

def check_server_client(shared_mem):
Expand Down
6 changes: 2 additions & 4 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
dgl.distributed.exit_client()
return sampled_graph


Expand Down Expand Up @@ -162,8 +161,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
dgl.distributed.exit_client()
return sampled_graph


Expand Down
2 changes: 0 additions & 2 deletions tests/distributed/test_new_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,6 @@ def start_client(num_clients):
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# clean up
kvclient.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()

@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store():
Expand Down
4 changes: 0 additions & 4 deletions tests/distributed/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ def start_client(ip_config):

# clean up
time.sleep(2)
if dgl.distributed.get_rank() == 0:
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
print("Get rank: %d" % dgl.distributed.get_rank())

def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
Expand Down

0 comments on commit 5c92f6c

Please sign in to comment.