Skip to content

Commit

Permalink
[KVStore] make pull/push handler per tensor. (dmlc#1646)
Browse files Browse the repository at this point in the history
* make pull/push handler per tensor.

* update.
  • Loading branch information
zheng-da authored Jun 16, 2020
1 parent 41349dc commit e8a56dc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 49 deletions.
116 changes: 68 additions & 48 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def __setstate__(self, state):

def process_request(self, server_state):
kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False:
if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False:
if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
data = kv_store.pull_handler(kv_store.data_store, self.name, local_id)
data = kv_store.pull_handlers[self.name](kv_store.data_store, self.name, local_id)
res = PullResponse(kv_store.server_id, data)
return res

Expand Down Expand Up @@ -93,12 +93,13 @@ def __setstate__(self, state):

def process_request(self, server_state):
kv_store = server_state.kv_store
if kv_store.part_policy.__contains__(self.name) is False:
if self.name not in kv_store.part_policy:
raise RuntimeError("KVServer cannot find partition policy with name: %s" % self.name)
if kv_store.data_store.__contains__(self.name) is False:
if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)
kv_store.push_handler(kv_store.data_store, self.name, local_id, self.data_tensor)
kv_store.push_handlers[self.name](kv_store.data_store, self.name,
local_id, self.data_tensor)

INIT_DATA = 901233
INIT_MSG = 'Init'
Expand Down Expand Up @@ -244,18 +245,19 @@ class RegisterPullHandlerRequest(rpc.Request):
pull_func : func
UDF pull handler
"""
def __init__(self, pull_func):
def __init__(self, name, pull_func):
self.name = name
self.pull_func = pull_func

def __getstate__(self):
return self.pull_func
return self.name, self.pull_func

def __setstate__(self, state):
self.pull_func = state
self.name, self.pull_func = state

def process_request(self, server_state):
kv_store = server_state.kv_store
kv_store.pull_handler = self.pull_func
kv_store.pull_handlers[self.name] = self.pull_func
res = RegisterPullHandlerResponse(REGISTER_PULL_MSG)
return res

Expand Down Expand Up @@ -288,18 +290,19 @@ class RegisterPushHandlerRequest(rpc.Request):
push_func : func
UDF push handler
"""
def __init__(self, push_func):
def __init__(self, name, push_func):
self.name = name
self.push_func = push_func

def __getstate__(self):
return self.push_func
return self.name, self.push_func

def __setstate__(self, state):
self.push_func = state
self.name, self.push_func = state

def process_request(self, server_state):
kv_store = server_state.kv_store
kv_store.push_handler = self.push_func
kv_store.push_handlers[self.name] = self.push_func
res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG)
return res

Expand Down Expand Up @@ -569,8 +572,8 @@ def __init__(self, server_id, ip_config, num_clients):
self._num_clients = num_clients
self._barrier_count = 0
# push and pull handler
self._push_handler = default_push_handler
self._pull_handler = default_pull_handler
self._push_handlers = {}
self._pull_handlers = {}

@property
def server_id(self):
Expand Down Expand Up @@ -608,24 +611,14 @@ def part_id(self):
return self._part_id

@property
def push_handler(self):
def push_handlers(self):
"""Get push handler"""
return self._push_handler
return self._push_handlers

@property
def pull_handler(self):
def pull_handlers(self):
"""Get pull handler"""
return self._pull_handler

@pull_handler.setter
def pull_handler(self, pull_handler):
"""Set pull handler"""
self._pull_handler = pull_handler

@push_handler.setter
def push_handler(self, push_handler):
"""Set push handler"""
self._push_handler = push_handler
return self._pull_handlers

def is_backup_server(self):
"""Return True if current server is a backup server.
Expand Down Expand Up @@ -667,6 +660,8 @@ def init_data(self, name, policy_str, data_tensor=None):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name][:] = data_tensor[:]
self._part_policy[name] = self.find_policy(policy_str)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler

def find_policy(self, policy_str):
"""Find a partition policy from existing policy set
Expand Down Expand Up @@ -748,8 +743,8 @@ def __init__(self, ip_config):
self._part_id = self._machine_id
self._main_server_id = self._machine_id * self._group_count
# push and pull handler
self._pull_handler = default_pull_handler
self._push_handler = default_push_handler
self._pull_handlers = {}
self._push_handlers = {}

@property
def client_id(self):
Expand All @@ -775,48 +770,69 @@ def barrier(self):
response = rpc.recv_response()
assert response.msg == BARRIER_MSG

def register_push_handler(self, func):
"""Register UDF push function on server.
def register_push_handler(self, name, func):
"""Register UDF push function.
This UDF is triggered for every push. The signature of the UDF is
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
```
def push_handler(data_store, name, local_offset, data)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition. `data` is the new data
to be written.
Parameters
----------
func : UDF push function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if self._client_id == 0:
request = RegisterPushHandlerRequest(func)
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_count):
response = rpc.recv_response()
assert response.msg == REGISTER_PUSH_MSG
self._push_handler = func
self._push_handlers[name] = func
self.barrier()

def register_pull_handler(self, func):
"""Register UDF pull function on server.
def register_pull_handler(self, name, func):
"""Register UDF pull function.
client_0 will send this request to all servers, and the other
clients will just invoke the barrier() api.
This UDF is triggered for every pull. The signature of the UDF is
```
def pull_handler(data_store, name, local_offset)
```
`data_store` is a dict that contains all tensors in the kvstore. `name` is the name
of the tensor where new data is pushed to. `local_offset` is the offset where new
data should be written in the tensor in the local partition.
Parameters
----------
func : UDF pull function
name : str
The name of the tensor
func : callable
The function to be called.
"""
if self._client_id == 0:
request = RegisterPullHandlerRequest(func)
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
for server_id in range(self._server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(self._server_namebook):
response = rpc.recv_response()
assert response.msg == REGISTER_PULL_MSG
self._pull_handler = func
self._pull_handlers[name] = func
self.barrier()

def init_data(self, name, shape, dtype, policy_str, partition_book, init_func):
Expand Down Expand Up @@ -887,6 +903,8 @@ def init_data(self, name, shape, dtype, policy_str, partition_book, init_func):
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_name_list.add(name)
self._full_data_shape[name] = tuple(shape)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler

def map_shared_data(self, partition_book):
"""Mapping shared-memory tensor from server to client.
Expand All @@ -907,6 +925,8 @@ def map_shared_data(self, partition_book):
dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
# Get full data shape across servers
for name, meta in response.meta.items():
if name not in self._data_name_list:
Expand Down Expand Up @@ -995,7 +1015,7 @@ def push(self, name, id_tensor, data_tensor):
rpc.send_request_to_machine(machine_idx, request)
start += count[idx]
if local_id is not None: # local push
self._push_handler(self._data_store, name, local_id, local_data)
self._push_handlers[name](self._data_store, name, local_id, local_data)

def pull(self, name, id_tensor):
"""Pull message from KVServer.
Expand Down Expand Up @@ -1043,7 +1063,7 @@ def pull(self, name, id_tensor):
# recv response
response_list = []
if local_id is not None: # local pull
local_data = self._pull_handler(self._data_store, name, local_id)
local_data = self._pull_handlers[name](self._data_store, name, local_id)
server_id = self._main_server_id
local_response = PullResponse(server_id, local_data)
response_list.append(local_response)
Expand Down
4 changes: 3 additions & 1 deletion tests/distributed/test_new_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def start_client():
res = kvclient.pull(name='data_2', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# Register new push handler
kvclient.register_push_handler(udf_push)
kvclient.register_push_handler('data_0', udf_push)
kvclient.register_push_handler('data_1', udf_push)
kvclient.register_push_handler('data_2', udf_push)
# Test push and pull
kvclient.push(name='data_0',
id_tensor=id_tensor,
Expand Down

0 comments on commit e8a56dc

Please sign in to comment.