Skip to content

Commit

Permalink
[RPC] Add send_request_to_machine() and remote_call_to_machine() (dml…
Browse files Browse the repository at this point in the history
…c#1619)

* add send_request_to_machine()

* update

* update

* update

* update

* update

* update

* fix lint

* update
  • Loading branch information
aksnzhy authored Jun 11, 2020
1 parent 50962c4 commit a67ba94
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 12 deletions.
11 changes: 2 additions & 9 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Define distributed kvstore"""

import os
import random
import numpy as np

from . import rpc
Expand Down Expand Up @@ -993,10 +992,7 @@ def push(self, name, id_tensor, data_tensor):
local_data = partial_data
else: # push data to remote server
request = PushRequest(name, partial_id, partial_data)
# randomly select a server node in target machine for load-balance
server_id = random.randint(machine_idx*self._group_count, \
(machine_idx+1)*self._group_count-1)
rpc.send_request(server_id, request)
rpc.send_request_to_machine(machine_idx, request)
start += count[idx]
if local_id is not None: # local push
self._push_handler(self._data_store, name, local_id, local_data)
Expand Down Expand Up @@ -1041,10 +1037,7 @@ def pull(self, name, id_tensor):
local_id = self._part_policy[name].to_local(partial_id)
else: # pull data from remote server
request = PullRequest(name, partial_id)
# randomly select a server node in target machine for load-balance
server_id = random.randint(machine_idx*self._group_count, \
(machine_idx+1)*self._group_count-1)
rpc.send_request(server_id, request)
rpc.send_request_to_machine(machine_idx, request)
pull_count += 1
start += count[idx]
# recv response
Expand Down
106 changes: 105 additions & 1 deletion python/dgl/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
server and clients."""
import abc
import pickle
import random

from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
Expand All @@ -12,7 +13,8 @@
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call']
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine']

REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}
Expand Down Expand Up @@ -220,6 +222,16 @@ def get_num_server():
"""
return _CAPI_DGLRPCGetNumServer()

def set_num_server_per_machine(num_server):
    """Set the total number of servers per machine.

    Stores the value in the C RPC context via
    ``_CAPI_DGLRPCSetNumServerPerMachine``; it is later read back by
    :func:`get_num_server_per_machine` when picking a random server on a
    target machine.

    Parameters
    ----------
    num_server : int
        Number of server nodes running on each machine.
    """
    _CAPI_DGLRPCSetNumServerPerMachine(num_server)

def get_num_server_per_machine():
    """Get the total number of servers per machine.

    Reads the value previously stored by :func:`set_num_server_per_machine`
    from the C RPC context.

    Returns
    -------
    int
        Number of server nodes running on each machine.
    """
    return _CAPI_DGLRPCGetNumServerPerMachine()

def incr_msg_seq():
"""Increment the message sequence number and return the old one.
Expand Down Expand Up @@ -517,6 +529,35 @@ def send_request(target, request):
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
send_rpc_message(msg, server_id)

def send_request_to_machine(target, request):
    """Send one request to a randomly chosen server on the target machine.

    Load is balanced by picking one of the machine's server processes at
    random.  The operation is non-blocking -- it does not guarantee the
    payloads have reached the target or even have left the sender process.
    However, all the payloads (i.e., data and arrays) can be safely freed
    after this function returns.

    Parameters
    ----------
    target : int
        ID of target machine.
    request : Request
        The request to send.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    # Servers are numbered consecutively per machine; pick one of the
    # target machine's group uniformly at random for load balance.
    group_size = get_num_server_per_machine()
    first_server = target * group_size
    chosen_server = random.randint(first_server, first_server + group_size - 1)
    payload, payload_tensors = serialize_to_payload(request)
    rpc_msg = RPCMessage(request.service_id,
                         incr_msg_seq(),
                         get_rank(),
                         chosen_server,
                         payload,
                         payload_tensors)
    send_rpc_message(rpc_msg, chosen_server)

def send_response(target, response):
"""Send one response to the target client.
Expand Down Expand Up @@ -644,6 +685,69 @@ def remote_call(target_and_requests, timeout=0):
Responses for each target-request pair. If the request does not have
response, None is placed.
Raises
------
ConnectionError if there is any problem with the connection.
"""
# TODO(chao): handle timeout
all_res = [None] * len(target_and_requests)
msgseq2pos = {}
num_res = 0
myrank = get_rank()
for pos, (target, request) in enumerate(target_and_requests):
# send request
service_id = request.service_id
msg_seq = incr_msg_seq()
client_id = get_rank()
server_id = random.randint(target*get_num_server_per_machine(),
(target+1)*get_num_server_per_machine()-1)
data, tensors = serialize_to_payload(request)
msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
send_rpc_message(msg, server_id)
# check if has response
res_cls = get_service_property(service_id)[1]
if res_cls is not None:
num_res += 1
msgseq2pos[msg_seq] = pos
while num_res != 0:
# recv response
msg = recv_rpc_message(timeout)
num_res -= 1
_, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
if res_cls is None:
raise DGLError('Got response message from service ID {}, '
'but no response class is registered.'.format(msg.service_id))
res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
if msg.client_id != myrank:
raise DGLError('Got reponse of request sent by client {}, '
'different from my rank {}!'.format(msg.client_id, myrank))
# set response
all_res[msgseq2pos[msg.msg_seq]] = res
return all_res

def remote_call_to_machine(target_and_requests, timeout=0):
"""Invoke registered services on remote machine
(which will randomly select a server to process the request) and collect responses.
The operation is blocking -- it returns when it receives all responses
or it times out.
If the target server state is available locally, it invokes local computation
to calculate the response.
Parameters
----------
target_and_requests : list[(int, Request)]
A list of requests and the machine they should be sent to.
timeout : int, optional
The timeout value in milliseconds. If zero, wait indefinitely.
Returns
-------
list[Response]
Responses for each target-request pair. If the request does not have
response, None is placed.
Raises
------
ConnectionError if there is any problem with the connection.
Expand Down
1 change: 1 addition & 0 deletions python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
group_count.append(server_info[3])
if server_info[0] > max_machine_id:
max_machine_id = server_info[0]
rpc.set_num_server_per_machine(group_count[0])
num_machines = max_machine_id+1
rpc.set_num_machines(num_machines)
machine_id = get_local_machine_id(server_namebook)
Expand Down
11 changes: 11 additions & 0 deletions src/rpc/rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
*rv = RPCContext::ThreadLocal()->num_servers;
});

// Store the per-machine server-group size in the thread-local RPCContext.
// Invoked from Python's rpc.set_num_server_per_machine().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  // NOTE(review): this also writes the assigned value into *rv, so the CAPI
  // returns the new setting; the Python wrapper currently ignores the return.
  *rv = RPCContext::ThreadLocal()->num_servers_per_machine = num_servers;
});

// Read back the per-machine server-group size from the thread-local
// RPCContext.  Invoked from Python's rpc.get_num_server_per_machine().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->num_servers_per_machine;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = (RPCContext::ThreadLocal()->msg_seq)++;
Expand Down
5 changes: 5 additions & 0 deletions src/rpc/rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ struct RPCContext {
*/
int32_t num_servers = 0;

/*!
* \brief Total number of servers per machine.
*/
int32_t num_servers_per_machine = 0;

/*!
* \brief Sender communicator.
*/
Expand Down
17 changes: 15 additions & 2 deletions tests/compute/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ def start_client():
for i in range(10):
target_and_requests.append((0, req))
res_list = dgl.distributed.remote_call(target_and_requests)
for res in res_list:
assert res.hello_str == STR
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# test send_request_to_machine
dgl.distributed.send_request_to_machine(0, req)
res = dgl.distributed.recv_response()
assert res.hello_str == STR
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# test remote_call_to_machine
target_and_requests = []
for i in range(10):
target_and_requests.append((0, req))
res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
for res in res_list:
assert res.hello_str == STR
assert res.integer == INTEGER
Expand Down Expand Up @@ -153,8 +168,6 @@ def test_rpc():
start_client()

if __name__ == '__main__':
test_rank()
test_msg_seq()
test_serialize()
test_rpc_msg()
test_rpc()

0 comments on commit a67ba94

Please sign in to comment.