Skip to content

Commit

Permalink
[KVstore] Fast-pull (dmlc#1446)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint
  • Loading branch information
aksnzhy authored Apr 16, 2020
1 parent ad222fb commit 338f24c
Show file tree
Hide file tree
Showing 3 changed files with 399 additions and 184 deletions.
196 changes: 106 additions & 90 deletions python/dgl/contrib/dis_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..network import _receiver_wait, _sender_connect
from ..network import _send_kv_msg, _recv_kv_msg
from ..network import _clear_kv_msg
from ..network import _fast_pull
from ..network import KVMsgType, KVStoreMsg

from .. import backend as F
Expand Down Expand Up @@ -146,6 +147,11 @@ def __init__(self, server_id, server_namebook, num_client, queue_size=20*1024*10
self._open_file_list = []
# record for total message count
self._msg_count = 0
# user-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# user-defined pull handler
self._udf_pull_handler = None


def __del__(self):
Expand Down Expand Up @@ -317,6 +323,8 @@ def start(self):
# Get connected with all client nodes
_receiver_wait(self._receiver, self._ip, self._port, self._client_count)

print("%d clients connected!" % self._client_count)

# recv client address information
addr_list = []
for i in range(self._client_count):
Expand Down Expand Up @@ -378,14 +386,20 @@ def start(self):
local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else:
local_id = msg.id
self._push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
if self._udf_push_handler is not None:
self._udf_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
# Pull message
elif msg.type == KVMsgType.PULL:
if (msg.name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else:
local_id = msg.id
res_tensor = self._pull_handler(msg.name+'-data-', local_id, self._data_store)
if self._udf_pull_handler is not None:
res_tensor = self._udf_pull_handler(msg.name+'-data-', local_id, self._data_store)
else:
res_tensor = self._default_pull_handler(msg.name+'-data-', local_id, self._data_store)
back_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=self._server_id,
Expand Down Expand Up @@ -500,7 +514,7 @@ def _read_data_shape(self, filename):
return data_shape


def _push_handler(self, name, ID, data, target):
def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor.
Expand All @@ -519,7 +533,7 @@ def _push_handler(self, name, ID, data, target):
target[name][ID] = data


def _pull_handler(self, name, ID, target):
def _default_pull_handler(self, name, ID, target):
"""Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor.
Expand Down Expand Up @@ -582,6 +596,7 @@ def __init__(self, server_namebook, queue_size=20*1024*1024*1024, net_type='sock
self._server_namebook = server_namebook
self._server_count = len(server_namebook)
self._group_count = server_namebook[0][3]
self._machine_count = int(self._server_count / self._group_count)
# client ID will be assign by server after connecting to server
self._client_id = -1
# Get local machine id via server_namebook
Expand All @@ -593,6 +608,11 @@ def __init__(self, server_namebook, queue_size=20*1024*1024*1024, net_type='sock
self._open_file_list = []
# Gargage_collection
self._garbage_msg = []
# User-defined pull handler
self._udf_pull_handler = None
# User-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# Used load-balance
random.seed(time.time())

Expand Down Expand Up @@ -812,7 +832,10 @@ def push(self, name, id_tensor, data_tensor):
start += count[idx]

if local_id is not None: # local push
self._push_handler(name+'-data-', local_id, local_data, self._data_store)
if self._udf_push_handler is not None:
self._udf_push_handler(name+'-data-', local_id, local_data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(name+'-data-', local_id, local_data, self._data_store)


def pull(self, name, id_tensor):
Expand All @@ -833,73 +856,88 @@ def pull(self, name, id_tensor):
assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.'

for msg in self._garbage_msg:
_clear_kv_msg(msg)
self._garbage_msg = []

# partition data
machine_id = self._data_store[name+'-part-'][id_tensor]
# sort index by machine id
sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
id_tensor = id_tensor[sorted_id]
machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
# pull data from server by order
start = 0
pull_count = 0
local_id = None
for idx in range(len(machine)):
end = start + count[idx]
if start == end: # No data for target machine
continue
partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id]
else:
local_id = partial_id
else: # pull data from remote server
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
if self._udf_pull_handler is None: # Use fast-pull
g2l = None
if name+'-g2l-' in self._data_store:
g2l = self._data_store[name+'-g2l-']
return _fast_pull(name, id_tensor,
self._machine_count,
self._group_count,
self._machine_id,
self._client_id,
self._data_store[name+'-part-'],
g2l,
self._data_store[name+'-data-'],
self._sender,
self._receiver)
else:
for msg in self._garbage_msg:
_clear_kv_msg(msg)
self._garbage_msg = []

# partition data
machine_id = self._data_store[name+'-part-'][id_tensor]
# sort index by machine id
sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
id_tensor = id_tensor[sorted_id]
machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
# pull data from server by order
start = 0
pull_count = 0
local_id = None
for idx in range(len(machine)):
end = start + count[idx]
if start == end: # No data for target machine
continue
partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id]
else:
local_id = partial_id
else: # pull data from remote server
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
name=name,
id=partial_id,
data=None,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1

start += count[idx]

msg_list = []
if local_id is not None: # local pull
local_data = self._udf_pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name,
id=partial_id,
data=None,
id=None,
data=local_data,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1

start += count[idx]

msg_list = []
if local_id is not None: # local pull
local_data = self._pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name,
id=None,
data=local_data,
c_ptr=None)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)

# wait message from server nodes
for idx in range(pull_count):
remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)
# wait message from server nodes
for idx in range(pull_count):
remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)

# sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
# sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)

return data_tensor[back_sorted_id] # return data with original index order
return data_tensor[back_sorted_id] # return data with original index order


def barrier(self):
Expand Down Expand Up @@ -1082,7 +1120,7 @@ def _takeId(self, elem):
return elem.rank


def _push_handler(self, name, ID, data, target):
def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor.
Expand All @@ -1099,26 +1137,4 @@ def _push_handler(self, name, ID, data, target):
self._data_store
"""
target[name][ID] = data


def _pull_handler(self, name, ID, target):
"""Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor.
Parameters
----------
name : str
data name
ID : tensor (mx.ndarray or torch.tensor)
a vector storing the ID list.
target : dict of data
self._data_store
Return
------
tensor
a tensor with the same row size of ID.
"""
return target[name][ID]

53 changes: 53 additions & 0 deletions python/dgl/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,56 @@ def _clear_kv_msg(msg):
F.sync()
if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr)


def _fast_pull(name, id_tensor,
machine_count, group_count, machine_id, client_id,
partition_book, g2l, local_data,
sender, receiver):
""" Pull message
Parameters
----------
name : str
data name string
id_tensor : tensor
tensor of ID
machine_count : int
count of total machine
group_count : int
count of server group
machine_id : int
current machine id
client_id : int
current client ID
partition_book : tensor
tensor of partition book
g2l : tensor
tensor of global2local
local_data : tensor
tensor of local shared data
sender : ctypes.c_void_p
C Sender handle
receiver : ctypes.c_void_p
C Receiver handle
Return
------
tensor
target tensor
"""
if g2l is not None:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'has_g2l',
F.zerocopy_to_dgl_ndarray(g2l))
else:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'no_g2l')

return F.zerocopy_from_dgl_ndarray(res_tensor)
Loading

0 comments on commit 338f24c

Please sign in to comment.