Skip to content

Commit

Permalink
[KVStore] Support group barrier (dmlc#1880)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

Co-authored-by: Da Zheng <[email protected]>
  • Loading branch information
aksnzhy and zheng-da authored Jul 30, 2020
1 parent 444becf commit 5b515cf
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 33 deletions.
110 changes: 89 additions & 21 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,26 +193,26 @@ class BarrierRequest(rpc.Request):
Parameters
----------
msg : string
string msg
role : string
client role
"""
def __init__(self, msg):
self.msg = msg
def __init__(self, role):
self.role = role

def __getstate__(self):
return self.msg
return self.role

def __setstate__(self, state):
self.msg = state
self.role = state

def process_request(self, server_state):
assert self.msg == BARRIER_MSG
kv_store = server_state.kv_store
kv_store.barrier_count = kv_store.barrier_count + 1
if kv_store.barrier_count == kv_store.num_clients:
kv_store.barrier_count = 0
count = kv_store.barrier_count[self.role]
kv_store.barrier_count[self.role] = count + 1
if kv_store.barrier_count[self.role] == len(kv_store.role[self.role]):
kv_store.barrier_count[self.role] = 0
res_list = []
for target_id in range(kv_store.num_clients):
for target_id in kv_store.role[self.role]:
res_list.append((target_id, BarrierResponse(BARRIER_MSG)))
return res_list
return None
Expand Down Expand Up @@ -506,6 +506,52 @@ def process_request(self, server_state):
res = DeleteDataResponse(DELETE_MSG)
return res

# Service ID passed to rpc.register_service() for the
# RegisterRoleRequest / RegisterRoleResponse pair.
REGISTER_ROLE = 901241
# Confirmation string carried by RegisterRoleResponse and checked by the client.
ROLE_MSG = "Register_Role"

class RegisterRoleResponse(rpc.Response):
    """Acknowledgement sent back to a client for a RegisterRoleRequest.

    Parameters
    ----------
    msg : str
        Short confirmation string returned to the client.
    """
    def __init__(self, msg):
        self.msg = msg

    def __getstate__(self):
        # The message string is the whole serialized state.
        return self.msg

    def __setstate__(self, state):
        # Inverse of __getstate__: the state is the bare message string.
        self.msg = state

class RegisterRoleRequest(rpc.Request):
    """Tell the server which role a client belongs to.

    Parameters
    ----------
    client_id : int
        ID of client
    role : str
        role of client
    """
    def __init__(self, client_id, role):
        self.client_id = client_id
        self.role = role

    def __getstate__(self):
        # Serialized as a (client_id, role) pair.
        return self.client_id, self.role

    def __setstate__(self, state):
        self.client_id, self.role = state

    def process_request(self, server_state):
        """Record this client under its role and return a confirmation.

        The first registration of a role creates its client set and
        initializes that role's barrier counter to zero.
        """
        store = server_state.kv_store
        groups = store.role
        role_name = self.role
        if role_name in groups:
            groups[role_name].add(self.client_id)
        else:
            # First client seen for this role.
            groups[role_name] = {self.client_id}
            store.barrier_count[role_name] = 0
        return RegisterRoleResponse(ROLE_MSG)

############################ KVServer ###############################

def default_push_handler(target, name, id_tensor, data_tensor):
Expand Down Expand Up @@ -604,6 +650,9 @@ def __init__(self, server_id, ip_config, num_clients):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
Expand All @@ -620,10 +669,12 @@ def __init__(self, server_id, ip_config, num_clients):
# We assume partition_id is equal to machine_id
self._part_id = self._machine_id
self._num_clients = num_clients
self._barrier_count = 0
self._barrier_count = {}
# push and pull handler
self._push_handlers = {}
self._pull_handlers = {}
# store client role
self._role = {}

@property
def server_id(self):
Expand Down Expand Up @@ -665,6 +716,11 @@ def push_handlers(self):
"""Get push handler"""
return self._push_handlers

@property
def role(self):
    """Get client role

    Returns the object stored in ``self._role``.
    """
    # NOTE(review): RegisterRoleRequest.process_request treats kv_store.role
    # as a mapping from role name to a set of client IDs -- confirm that this
    # property is the one it reads.
    return self._role

@property
def pull_handlers(self):
"""Get pull handler"""
Expand Down Expand Up @@ -748,8 +804,10 @@ class KVClient(object):
----------
ip_config : str
Path of IP configuration file.
role : str
We can set different role for kvstore.
"""
def __init__(self, ip_config):
def __init__(self, ip_config, role='default'):
assert rpc.get_rank() != -1, 'Please invoke rpc.connect_to_server() \
before creating KVClient.'
assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
Expand Down Expand Up @@ -784,6 +842,9 @@ def __init__(self, ip_config):
rpc.register_service(DELETE_DATA,
DeleteDataRequest,
DeleteDataResponse)
rpc.register_service(REGISTER_ROLE,
RegisterRoleRequest,
RegisterRoleResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store the partition information with specified data name
Expand All @@ -805,12 +866,23 @@ def __init__(self, ip_config):
# push and pull handler
self._pull_handlers = {}
self._push_handlers = {}
# register role on server-0
self._role = role
request = RegisterRoleRequest(self._client_id, self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == ROLE_MSG

@property
def client_id(self):
"""Get client ID"""
return self._client_id

@property
def role(self):
    """Get client role

    Returns the value stored in ``self._role``.
    """
    return self._role

@property
def machine_id(self):
"""Get machine ID"""
Expand All @@ -821,14 +893,10 @@ def barrier(self):
This API will be blocked untill all the clients invoke this API.
"""
request = BarrierRequest(BARRIER_MSG)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == BARRIER_MSG
request = BarrierRequest(self._role)
rpc.send_request(0, request)
response = rpc.recv_response()
assert response.msg == BARRIER_MSG

def register_push_handler(self, name, func):
"""Register UDF push function.
Expand Down
55 changes: 54 additions & 1 deletion python/dgl/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'get_num_client', 'set_num_client']
'get_num_client', 'set_num_client', 'client_barrier']

REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}
Expand Down Expand Up @@ -899,6 +899,13 @@ def recv_rpc_message(timeout=0):
_CAPI_DGLRPCRecvRPCMessage(timeout, msg)
return msg

def client_barrier():
    """Barrier all client processes

    Sends a ClientBarrierRequest to target 0 and waits for the reply,
    whose message must be the string 'barrier'.
    """
    send_request(0, ClientBarrierRequest())
    reply = recv_response()
    assert reply.msg == 'barrier'

def finalize_server():
"""Finalize resources of current server
"""
Expand Down Expand Up @@ -1068,4 +1075,50 @@ def process_request(self, server_state):
res = GetNumberClientsResponse(get_num_client())
return res

# Service ID for the ClientBarrierRequest / ClientBarrierResponse pair.
CLIENT_BARRIER = 22454

class ClientBarrierResponse(Response):
    """Barrier confirmation delivered to a client.

    Parameters
    ----------
    msg : str
        string msg
    """
    def __init__(self, msg='barrier'):
        self.msg = msg

    def __getstate__(self):
        # Only the message string is serialized.
        return self.msg

    def __setstate__(self, state):
        self.msg = state

class ClientBarrierRequest(Request):
    """Barrier notification sent from a client to the server.

    Parameters
    ----------
    msg : str
        string msg
    """
    def __init__(self, msg='barrier'):
        self.msg = msg

    def __getstate__(self):
        # Only the message string is serialized.
        return self.msg

    def __setstate__(self, state):
        self.msg = state

    def process_request(self, server_state):
        """Count this arrival; release every client once all have arrived.

        Returns None while the barrier count is below get_num_client().
        Once the count reaches it, resets the count to zero and returns a
        list of (client_id, ClientBarrierResponse) pairs, one per client.
        """
        _CAPI_DGLRPCSetBarrierCount(_CAPI_DGLRPCGetBarrierCount()+1)
        if _CAPI_DGLRPCGetBarrierCount() != get_num_client():
            # The count has not reached the number of clients yet.
            return None
        _CAPI_DGLRPCSetBarrierCount(0)
        return [(client, ClientBarrierResponse())
                for client in range(get_num_client())]

_init_api("dgl.distributed.rpc")
4 changes: 4 additions & 0 deletions python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.register_ctrl_c()
server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook)
Expand Down Expand Up @@ -199,6 +202,7 @@ def exit_client():
"""Register exit callback.
"""
# Only client with rank_0 will send shutdown request to servers.
rpc.client_barrier()
shutdown_servers()
finalize_client()
atexit.unregister(exit_client)
Expand Down
3 changes: 3 additions & 0 deletions python/dgl/distributed/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def start_server(server_id, ip_config, num_clients, server_state, \
rpc.register_service(rpc.GET_NUM_CLIENT,
rpc.GetNumberClientsRequest,
rpc.GetNumberClientsResponse)
rpc.register_service(rpc.CLIENT_BARRIER,
rpc.ClientBarrierRequest,
rpc.ClientBarrierResponse)
rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0]
Expand Down
11 changes: 11 additions & 0 deletions src/rpc/rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
RPCContext::ThreadLocal()->msg_seq = msg_seq;
});

// Registers "distributed.rpc._CAPI_DGLRPCGetBarrierCount": returns the
// barrier_count field of the context obtained from RPCContext::ThreadLocal().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->barrier_count;
});

// Registers "distributed.rpc._CAPI_DGLRPCSetBarrierCount": stores its first
// argument (read as int32_t) into the barrier_count field of the context
// obtained from RPCContext::ThreadLocal().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t count = args[0];
RPCContext::ThreadLocal()->barrier_count = count;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->machine_id;
Expand Down
5 changes: 5 additions & 0 deletions src/rpc/rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ struct RPCContext {
*/
int32_t num_clients = 0;

/*!
* \brief Current barrier count
*/
int32_t barrier_count = 0;

/*!
* \brief Total number of server per machine.
*/
Expand Down
Loading

0 comments on commit 5b515cf

Please sign in to comment.