[KVStore] Fix bug of get_meta() API (dmlc#1491)
* fix get-shape

* update

* update

* update

* update

* update

* fix typo

* update

* update
aksnzhy authored May 2, 2020
1 parent 30b8074 commit 5fc334f
Showing 5 changed files with 86 additions and 19 deletions.
42 changes: 39 additions & 3 deletions python/dgl/contrib/dis_kvstore.py
@@ -500,6 +500,18 @@ def start(self):
shape=msg.shape,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, 0)
# Get shape message
elif msg.type == KVMsgType.GET_SHAPE:
data_shape = F.tensor(F.shape(self._data_store[msg.name+'-data-']))
back_msg = KVStoreMsg(
type=KVMsgType.GET_SHAPE_BACK,
rank=self._server_id,
name=msg.name,
id=None,
data=None,
shape=data_shape,
c_ptr=None)
_send_kv_msg(self._sender, back_msg, msg.rank)
# Barrier message
elif msg.type == KVMsgType.BARRIER:
self._barrier_count += 1
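The new branch above is the server side of the fix: on a GET_SHAPE request, the server looks up its local shard of the named tensor and replies with that shard's shape, routing the reply back to the requesting client (msg.rank). A standalone sketch of that logic, with a plain dict standing in for the server's data store and numpy for the backend (all names here are illustrative, not the DGL API):

import numpy as np

data_store = {'embed-data-': np.zeros((1000, 128))}   # this server's local shard

def handle_get_shape(msg, send_fn, server_id):
    # Reply with the shape of the local shard; the reply goes back to the
    # client that asked (msg['rank']), mirroring GET_SHAPE_BACK above.
    shape = np.array(data_store[msg['name'] + '-data-'].shape)
    back_msg = {'type': 'GET_SHAPE_BACK', 'rank': server_id,
                'name': msg['name'], 'shape': shape}
    send_fn(back_msg, recv_id=msg['rank'])

# A stub sender that just prints where the reply would go.
handle_get_shape({'name': 'embed', 'rank': 3},
                 send_fn=lambda m, recv_id: print(recv_id, m['shape']),
                 server_id=0)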
@@ -704,6 +716,7 @@ def __init__(self, server_namebook, queue_size=20*1024*1024*1024, net_type='sock
self._has_data = set()
# This is used to store local data, which can share memory with local KVServer.
self._data_store = {}
self._full_data_shape = {}
self._data_name_list = []
# Server information
self._server_namebook = server_namebook
@@ -792,10 +805,9 @@ def connect(self):
tensor_name, dtype = self._deserialize_shared_tensor(data)
while True:
if (os.path.exists(tensor_name+'shape-'+str(self._machine_id))):
time.sleep(2) # wait writing finish
break
else:
time.sleep(2) # wait until the file been created
time.sleep(1) # wait until the file been created
shape, data_type = self._read_data_shape_type(tensor_name+'shape-'+str(self._machine_id))
assert data_type == dtype
shared_data = empty_shared_mem(tensor_name, False, shape, dtype)
@@ -805,6 +817,29 @@ def connect(self):
self._data_name_list.append(tensor_name[0:-6])
self._has_data.add(tensor_name)

# Get full shape of each data
for name in self._data_name_list:
data_shape = list(F.shape(self._data_store[name+'-data-']))
data_shape[0] = 0
msg = KVStoreMsg(
type=KVMsgType.GET_SHAPE,
rank=self._client_id,
name=name,
id=None,
data=None,
shape=None,
c_ptr=None)
# send msg
for m_id in range(self._machine_count):
s_id = m_id * self._group_count
_send_kv_msg(self._sender, msg, s_id)
# recv msg
for m_id in range(self._machine_count):
back_msg = _recv_kv_msg(self._receiver)
assert back_msg.type == KVMsgType.GET_SHAPE_BACK
data_shape[0] += ((F.asnumpy(back_msg.shape)).tolist())[0]
self._full_data_shape[name] = tuple(data_shape)

print("KVClient %d connect to kvstore successfully!" % self.get_id())


@@ -872,6 +907,7 @@ def init_data(self, name, shape, dtype, target_name):
self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
self._has_data.add(name+'-data-')
self._data_name_list.append(name)
self._full_data_shape[name] = tuple(shape)


def print(self):
@@ -947,8 +983,8 @@ def get_data_meta(self, name):
assert name + '-data-' in self._has_data, 'Data (%s) does not exist!' % name

data_type = F.dtype(self._data_store[name+'-data-'])
data_shape = F.shape(self._data_store[name+'-data-'])
partition_book = self._data_store[name+'-part-']
data_shape = self._full_data_shape[name]

return (data_type, data_shape, partition_book)
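The change to get_data_meta() is the user-visible part of the fix: the shape in the returned tuple now comes from the _full_data_shape cache rather than from whatever shard happens to be mapped on the local machine. A minimal sketch of the new contract (illustrative names and sizes, not the DGL API):

# _full_data_shape is filled in connect()/init_data(); the local store only
# holds this machine's rows, so its shape would understate the tensor.
full_data_shape = {'embed': (2500, 128)}
local_num_rows = 1000                     # rows stored on this machine

def get_meta_shape(name):
    return full_data_shape[name]          # global shape, not the local shard's

assert get_meta_shape('embed') == (2500, 128)
assert get_meta_shape('embed')[0] != local_num_rows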

10 changes: 6 additions & 4 deletions python/dgl/network.py
@@ -192,6 +192,8 @@ class KVMsgType(Enum):
PULL_BACK = 5
BARRIER = 6
IP_ID = 7
GET_SHAPE = 8
GET_SHAPE_BACK = 9


KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data shape c_ptr")
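Both new message types reuse the existing KVStoreMsg namedtuple: GET_SHAPE carries only the sender's rank and the tensor name, while GET_SHAPE_BACK additionally carries the shape (a backend tensor in the real code). A sketch of the two payloads, with the structures mirrored locally for illustration:

from collections import namedtuple
from enum import Enum

# Local mirrors of the structures above, for illustration only.
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data shape c_ptr")

class KVMsgType(Enum):
    GET_SHAPE = 8
    GET_SHAPE_BACK = 9

request = KVStoreMsg(type=KVMsgType.GET_SHAPE, rank=3, name='embed',
                     id=None, data=None, shape=None, c_ptr=None)
reply = KVStoreMsg(type=KVMsgType.GET_SHAPE_BACK, rank=0, name='embed',
                   id=None, data=None, shape=(1000, 128), c_ptr=None)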
@@ -234,7 +236,7 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank,
msg.name,
tensor_id)
elif msg.type == KVMsgType.INIT:
elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
_CAPI_SenderSendKVMsg(
sender,
@@ -243,7 +245,7 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank,
msg.name,
tensor_shape)
elif msg.type == KVMsgType.IP_ID:
elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
@@ -296,7 +298,7 @@ def _recv_kv_msg(receiver):
shape=None,
c_ptr=msg_ptr)
return msg
elif msg_type == KVMsgType.INIT:
elif msg_type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
tensor_shape = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgShape(msg_ptr))
msg = KVStoreMsg(
@@ -308,7 +310,7 @@ def _recv_kv_msg(receiver):
shape=tensor_shape,
c_ptr=msg_ptr)
return msg
elif msg_type == KVMsgType.IP_ID:
elif msg_type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
msg = KVStoreMsg(
type=msg_type,
37 changes: 26 additions & 11 deletions src/graph/network.cc
@@ -498,14 +498,17 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg->msg_type != kFinalMsg &&
kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kIPIDMsg) {
kv_msg->msg_type != kIPIDMsg &&
kv_msg->msg_type != kGetShapeMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type);
if (kv_msg->msg_type != kInitMsg) {
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->id);
}
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->data);
}
if (kv_msg->msg_type != kPullMsg &&
@@ -523,7 +526,9 @@
}
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
if (kv_msg->msg_type != kInitMsg) {
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data);
send_id_msg.size = kv_msg->id.GetSize();
@@ -535,7 +540,9 @@
}
// Send data NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg->data->data);
send_data_msg.size = kv_msg->data.GetSize();
@@ -571,7 +578,8 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) {
kv_msg->msg_type == kIPIDMsg ||
kv_msg->msg_type == kGetShapeMsg) {
return kv_msg;
}
// Recv ArrayMeta
@@ -580,7 +588,8 @@
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
if (kv_msg->msg_type != kInitMsg) {
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
@@ -593,7 +602,8 @@
}
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg) {
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
int64_t ndim = meta.data_shape_[2];
@@ -644,18 +654,23 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
std::string name = args[args_count++];
kv_msg.name = name;
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) {
kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.id = args[args_count++];
}
if (kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg) {
kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.data = args[args_count++];
}
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPushMsg &&
kv_msg.msg_type != kPullBackMsg) {
kv_msg.msg_type != kPullBackMsg &&
kv_msg.msg_type != kGetShapeMsg) {
kv_msg.shape = args[args_count++];
}
}
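Taken together, the branches above decide which optional payloads follow the KVStoreMsg header on the wire; the two new message types stay lightweight. A rough summary of the per-type payloads as read off the conditions in send_kv_message/recv_kv_message (illustrative Python, not generated from the sources):

# Optional payloads that follow the message header, per message type.
wire_payloads = {
    'FINAL':          [],
    'BARRIER':        [],
    'IP_ID':          [],
    'GET_SHAPE':      [],                       # header only: type, rank, name
    'INIT':           ['array_meta'],           # shape travels inside the meta
    'GET_SHAPE_BACK': ['array_meta'],           # handled like INIT
    'PULL':           ['array_meta', 'id'],
    'PUSH':           ['array_meta', 'id', 'data'],
    'PULL_BACK':      ['array_meta', 'id', 'data'],
}
print(wire_payloads['GET_SHAPE_BACK'])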
10 changes: 9 additions & 1 deletion src/graph/network.h
@@ -64,7 +64,15 @@ enum MessageType {
/*!
* \brief IP and ID msg for KVStore
*/
kIPIDMsg = 7
kIPIDMsg = 7,
/*!
* \brief Get data shape msg for KVStore
*/
kGetShapeMsg = 8,
/*!
* \brief Get data shape back msg for KVStore
*/
kGetShapeBackMsg = 9
};


6 changes: 6 additions & 0 deletions tests/compute/test_kvstore.py
@@ -66,26 +66,32 @@ def start_client():

meta_0 = my_client.get_data_meta('data_0')
assert meta_0[0] == F.float32
assert meta_0[1] == tuple(F.shape(data_0))
assert_array_equal(meta_0[2], partition_0)

meta_1 = my_client.get_data_meta('data_1')
assert meta_1[0] == F.float32
assert meta_1[1] == tuple(F.shape(data_1))
assert_array_equal(meta_1[2], partition_1)

meta_2 = my_client.get_data_meta('data_2')
assert meta_2[0] == F.float32
assert meta_2[1] == tuple(F.shape(data_0))
assert_array_equal(meta_2[2], partition_0)

meta_3 = my_client.get_data_meta('data_3')
assert meta_3[0] == F.int64
assert meta_3[1] == tuple(F.shape(data_3))
assert_array_equal(meta_3[2], partition_0)

meta_4 = my_client.get_data_meta('data_4')
assert meta_4[0] == F.float64
assert meta_4[1] == tuple(F.shape(data_4))
assert_array_equal(meta_4[2], partition_0)

meta_5 = my_client.get_data_meta('data_5')
assert meta_5[0] == F.int32
assert meta_5[1] == tuple(F.shape(data_5))
assert_array_equal(meta_5[2], partition_0)

my_client.push(name='data_0', id_tensor=F.tensor([0, 1, 2]), data_tensor=F.tensor([[1.,1.,1.],[2.,2.,2.],[3.,3.,3.]]))
