From a67ba946bc5f63404660cdfab3928422f9324327 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Thu, 11 Jun 2020 13:04:06 +0800 Subject: [PATCH] [RPC] Add send_request_to_machine() and remote_call_to_machine() (#1619) * add send_request_to_machine() * update * update * update * update * update * update * fix lint * update --- python/dgl/distributed/kvstore.py | 11 +-- python/dgl/distributed/rpc.py | 106 ++++++++++++++++++++++++++- python/dgl/distributed/rpc_client.py | 1 + src/rpc/rpc.cc | 11 +++ src/rpc/rpc.h | 5 ++ tests/compute/test_rpc.py | 17 ++++- 6 files changed, 139 insertions(+), 12 deletions(-) diff --git a/python/dgl/distributed/kvstore.py b/python/dgl/distributed/kvstore.py index d703a2ae30fa..e73cc186f7a4 100644 --- a/python/dgl/distributed/kvstore.py +++ b/python/dgl/distributed/kvstore.py @@ -1,7 +1,6 @@ """Define distributed kvstore""" import os -import random import numpy as np from . import rpc @@ -993,10 +992,7 @@ def push(self, name, id_tensor, data_tensor): local_data = partial_data else: # push data to remote server request = PushRequest(name, partial_id, partial_data) - # randomly select a server node in target machine for load-balance - server_id = random.randint(machine_idx*self._group_count, \ - (machine_idx+1)*self._group_count-1) - rpc.send_request(server_id, request) + rpc.send_request_to_machine(machine_idx, request) start += count[idx] if local_id is not None: # local push self._push_handler(self._data_store, name, local_id, local_data) @@ -1041,10 +1037,7 @@ def pull(self, name, id_tensor): local_id = self._part_policy[name].to_local(partial_id) else: # pull data from remote server request = PullRequest(name, partial_id) - # randomly select a server node in target machine for load-balance - server_id = random.randint(machine_idx*self._group_count, \ - (machine_idx+1)*self._group_count-1) - rpc.send_request(server_id, request) + rpc.send_request_to_machine(machine_idx, request) pull_count += 1 start += count[idx] # recv response diff --git 
a/python/dgl/distributed/rpc.py b/python/dgl/distributed/rpc.py index aac6b238e5a4..1c5283b3e2e6 100644 --- a/python/dgl/distributed/rpc.py +++ b/python/dgl/distributed/rpc.py @@ -2,6 +2,7 @@ server and clients.""" import abc import pickle +import random from .._ffi.object import register_object, ObjectBase from .._ffi.function import _init_api @@ -12,7 +13,8 @@ 'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \ 'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \ 'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \ -'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call'] +'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \ +'send_request_to_machine', 'remote_call_to_machine'] REQUEST_CLASS_TO_SERVICE_ID = {} RESPONSE_CLASS_TO_SERVICE_ID = {} @@ -220,6 +222,16 @@ def get_num_server(): """ return _CAPI_DGLRPCGetNumServer() +def set_num_server_per_machine(num_server): + """Set the total number of servers per machine + """ + _CAPI_DGLRPCSetNumServerPerMachine(num_server) + +def get_num_server_per_machine(): + """Get the total number of servers per machine + """ + return _CAPI_DGLRPCGetNumServerPerMachine() + def incr_msg_seq(): """Increment the message sequence number and return the old one. @@ -517,6 +529,35 @@ def send_request(target, request): msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors) send_rpc_message(msg, server_id) +def send_request_to_machine(target, request): + """Send one request to the target machine, which will randomly + select a server node to process this request. + + The operation is non-blocking -- it does not guarantee the payloads have + reached the target or even have left the sender process. However, + all the payloads (i.e., data and arrays) can be safely freed after this + function returns. + + Parameters + ---------- + target : int + ID of target machine. 
+ request : Request + The request to send. + + Raises + ------ + ConnectionError if there is any problem with the connection. + """ + service_id = request.service_id + msg_seq = incr_msg_seq() + client_id = get_rank() + server_id = random.randint(target*get_num_server_per_machine(), + (target+1)*get_num_server_per_machine()-1) + data, tensors = serialize_to_payload(request) + msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors) + send_rpc_message(msg, server_id) + def send_response(target, response): """Send one response to the target client. @@ -644,6 +685,69 @@ def remote_call(target_and_requests, timeout=0): Responses for each target-request pair. If the request does not have response, None is placed. + Raises + ------ + ConnectionError if there is any problem with the connection. + """ + # TODO(chao): handle timeout + all_res = [None] * len(target_and_requests) + msgseq2pos = {} + num_res = 0 + myrank = get_rank() + for pos, (target, request) in enumerate(target_and_requests): + # send request + service_id = request.service_id + msg_seq = incr_msg_seq() + client_id = get_rank() + server_id = random.randint(target*get_num_server_per_machine(), + (target+1)*get_num_server_per_machine()-1) + data, tensors = serialize_to_payload(request) + msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors) + send_rpc_message(msg, server_id) + # check if has response + res_cls = get_service_property(service_id)[1] + if res_cls is not None: + num_res += 1 + msgseq2pos[msg_seq] = pos + while num_res != 0: + # recv response + msg = recv_rpc_message(timeout) + num_res -= 1 + _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id] + if res_cls is None: + raise DGLError('Got response message from service ID {}, ' + 'but no response class is registered.'.format(msg.service_id)) + res = deserialize_from_payload(res_cls, msg.data, msg.tensors) + if msg.client_id != myrank: + raise DGLError('Got reponse of request sent by client {}, ' + 'different 
from my rank {}!'.format(msg.client_id, myrank)) + # set response + all_res[msgseq2pos[msg.msg_seq]] = res + return all_res + +def remote_call_to_machine(target_and_requests, timeout=0): + """Invoke registered services on remote machine + (which will randomly select a server to process the request) and collect responses. + + The operation is blocking -- it returns when it receives all responses + or it times out. + + If the target server state is available locally, it invokes local computation + to calculate the response. + + Parameters + ---------- + target_and_requests : list[(int, Request)] + A list of requests and the machine they should be sent to. + timeout : int, optional + The timeout value in milliseconds. If zero, wait indefinitely. + + Returns + ------- + list[Response] + Responses for each target-request pair. If the request does not have + response, None is placed. + Raises ------ ConnectionError if there is any problem with the connection. diff --git a/python/dgl/distributed/rpc_client.py b/python/dgl/distributed/rpc_client.py index 793f08fe395f..7a44604bdc2b 100644 --- a/python/dgl/distributed/rpc_client.py +++ b/python/dgl/distributed/rpc_client.py @@ -122,6 +122,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket group_count.append(server_info[3]) if server_info[0] > max_machine_id: max_machine_id = server_info[0] + rpc.set_num_server_per_machine(group_count[0]) num_machines = max_machine_id+1 rpc.set_num_machines(num_machines) machine_id = get_local_machine_id(server_namebook) diff --git a/src/rpc/rpc.cc b/src/rpc/rpc.cc index 87371ef96221..08f8b0b50272 100644 --- a/src/rpc/rpc.cc +++ b/src/rpc/rpc.cc @@ -159,6 +159,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer") .set_body([] (DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::ThreadLocal()->num_servers; }); +DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine") +.set_body([] (DGLArgs args, DGLRetValue* rv) { + const int32_t num_servers = args[0]; + *rv = 
RPCContext::ThreadLocal()->num_servers_per_machine = num_servers; }); + +DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine") +.set_body([] (DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::ThreadLocal()->num_servers_per_machine; }); + DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq") .set_body([] (DGLArgs args, DGLRetValue* rv) { *rv = (RPCContext::ThreadLocal()->msg_seq)++; diff --git a/src/rpc/rpc.h b/src/rpc/rpc.h index baad7102988b..5f716c8ddc96 100644 --- a/src/rpc/rpc.h +++ b/src/rpc/rpc.h @@ -56,6 +56,11 @@ struct RPCContext { */ int32_t num_servers = 0; + /*! + * \brief Total number of servers per machine. + */ + int32_t num_servers_per_machine = 0; + /*! * \brief Sender communicator. */ diff --git a/tests/compute/test_rpc.py b/tests/compute/test_rpc.py index d782daa19fbf..a0d6782c23aa 100644 --- a/tests/compute/test_rpc.py +++ b/tests/compute/test_rpc.py @@ -100,6 +100,21 @@ def start_client(): for i in range(10): target_and_requests.append((0, req)) res_list = dgl.distributed.remote_call(target_and_requests) + for res in res_list: + assert res.hello_str == STR + assert res.integer == INTEGER + assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) + # test send_request_to_machine + dgl.distributed.send_request_to_machine(0, req) + res = dgl.distributed.recv_response() + assert res.hello_str == STR + assert res.integer == INTEGER + assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) + # test remote_call_to_machine + target_and_requests = [] + for i in range(10): + target_and_requests.append((0, req)) + res_list = dgl.distributed.remote_call_to_machine(target_and_requests) for res in res_list: assert res.hello_str == STR assert res.integer == INTEGER assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR)) @@ -153,8 +168,6 @@ def test_rpc(): start_client() if __name__ == '__main__': - test_rank() - test_msg_seq() test_serialize() test_rpc_msg() test_rpc()