Skip to content

Commit

Permalink
Send any shape of tensor rather than matrix (dmlc#942)
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy authored Oct 21, 2019
1 parent ae3102d commit ff9f67e
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 19 deletions.
16 changes: 12 additions & 4 deletions examples/mxnet/dis_kvstore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ def start_client(args):
# Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')

tensor_id = mx.nd.array([0, 1, 2], dtype='int64')
tensor_data = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])

for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, mx.nd.array([2., 2., 2.]))

tensor_id = mx.nd.array([6, 7, 8], dtype='int64')
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, mx.nd.array([3., 3., 3.]))

client.barrier()

Expand All @@ -39,16 +42,21 @@ def start_client(args):
new_tensor_0 = client.pull('embed_0', tensor_id)
tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9,10], dtype='int64')
new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)

client.push_all('embed_0', new_tensor_0)
client.push_all('embed_1', new_tensor_1)
client.push_all('embed_2', new_tensor_2)

new_tensor_2 = client.pull_all('embed_0')
new_tensor_3 = client.pull_all('embed_1')
new_tensor_3 = client.pull_all('embed_0')
new_tensor_4 = client.pull_all('embed_1')
new_tensor_5 = client.pull_all('embed_2')
print("embed_0: ")
print(new_tensor_2)
print("embed_1: ")
print(new_tensor_3)
print("embed_1: ")
print(new_tensor_4)
print("embed_2: ")
print(new_tensor_5)

# Shut-down all the servers
if client.get_id() == 0:
Expand Down
17 changes: 13 additions & 4 deletions examples/pytorch/dis_kvstore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import time
import argparse
import torch as th

server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')

Expand All @@ -19,18 +20,21 @@ def start_client(args):
# Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')

tensor_id = torch.tensor([0, 1, 2])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])

for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, th.tensor([2., 2., 2.]))

tensor_id = torch.tensor([6, 7, 8])
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, th.tensor([3., 3., 3.]))

client.barrier()

Expand All @@ -39,16 +43,21 @@ def start_client(args):
new_tensor_0 = client.pull('embed_0', tensor_id)
tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9,10])
new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)

client.push_all('embed_0', new_tensor_0)
client.push_all('embed_1', new_tensor_1)
client.push_all('embed_2', new_tensor_2)

new_tensor_2 = client.pull_all('embed_0')
new_tensor_3 = client.pull_all('embed_1')
new_tensor_3 = client.pull_all('embed_0')
new_tensor_4 = client.pull_all('embed_1')
new_tensor_5 = client.pull_all('embed_2')
print("embed_0:")
print(new_tensor_2)
print("embed_1:")
print(new_tensor_3)
print("embed_1:")
print(new_tensor_4)
print("embed_2:")
print(new_tensor_5)

# Shut-down all the servers
if client.get_id() == 0:
Expand Down
26 changes: 19 additions & 7 deletions python/dgl/contrib/dis_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
def ReadNetworkConfigure(filename):
"""Read networking configuration from file.
The config file is like:
server 172.31.40.143:50050 0
client 172.31.40.143:50051 0
client 172.31.36.140:50051 1
client 172.31.47.147:50051 2
client 172.31.30.180:50051 3
Here we have 1 server node and 4 client nodes.
Parameters
----------
filename : str
Expand Down Expand Up @@ -251,14 +261,16 @@ def _remap_id(self, name, ID):
class KVClient(object):
"""KVClient is used to push/pull tensors to/from KVServer on DGL trainer.
There are three operations supported by KVClient:
There are five operations supported by KVClient:
* init_data(name, shape, low, high): initialize tensor on KVServer
* push(name, id, data): push data to KVServer
* pull(name, id): pull data from KVServer
* init_data(name, shape, init_type, low, high): initialize tensor on KVServer
* push(name, id, data): push sparse data to KVServer given specified IDs
* pull(name, id): pull sparse data from KVServer given specified IDs
* push_all(name, data): push dense data to KVServer
* pull_all(name): pull sense data from KVServer
* shut_down(): shut down all KVServer nodes
DO NOT use KVClient in multiple threads!
Note that, DO NOT use KVClient in multiple threads!
Parameters
----------
Expand All @@ -277,9 +289,9 @@ class KVClient(object):
networking type, e.g., 'socket' (default) or 'mpi'.
"""
def __init__(self, client_id, server_namebook, client_addr, net_type='socket'):
assert client_id >= 0, 'client_id cannot be a nagative number.'
assert client_id >= 0, 'client_id (%d) cannot be a nagative number.' % client_id
assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
assert len(client_addr.split(':')) == 2, 'Incorrect IP format.'
assert len(client_addr.split(':')) == 2, 'Incorrect IP format: %s' % client_addr
# self._data_size is a key-value store where the key is data name
# and value is the size of tensor. It is used to partition data into
# different KVServer nodes.
Expand Down
12 changes: 8 additions & 4 deletions src/graph/network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ char* ArrayMeta::Serialize(int64_t* size) {
buffer_size += sizeof(data_shape_.size());
buffer_size += sizeof(int64_t) * data_shape_.size();
}
// In the future, we should have a better memory management.
// In the future, we should have a better memory management as
// allocating a large chunk of memory can be very expensive.
buffer = new char[buffer_size];
char* pointer = buffer;
Expand Down Expand Up @@ -124,7 +124,7 @@ char* KVStoreMsg::Serialize(int64_t* size) {
buffer_size += sizeof(this->name.size());
buffer_size += this->name.size();
}
// In the future, we should have a better memory management.
// In the future, we should have a better memory management as
// allocating a large chunk of memory can be very expensive.
buffer = new char[buffer_size];
char* pointer = buffer;
Expand Down Expand Up @@ -532,9 +532,13 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
if (kv_msg->msg_type != kPullMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[2], 2);
CHECK_GE(meta.data_shape_[2], 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
}
kv_msg->data = CreateNDArrayFromRaw(
{meta.data_shape_[3], meta.data_shape_[4]},
vec_shape,
DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0},
recv_data_msg.data);
Expand Down
8 changes: 8 additions & 0 deletions tests/compute/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def start_client():

client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')

tensor_id = torch.tensor([0, 1, 2])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
Expand All @@ -38,16 +39,19 @@ def start_client():
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, torch.tensor([2., 2., 2.]))

tensor_id = torch.tensor([6, 7, 8])
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, torch.tensor([3., 3., 3.]))

# Pull
tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
new_tensor_0 = client.pull('embed_0', tensor_id)
new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)

target_tensor = torch.tensor(
[[ 0., 0., 0.],
Expand All @@ -60,6 +64,10 @@ def start_client():
assert torch.equal(new_tensor_0, target_tensor) == True
assert torch.equal(new_tensor_1, target_tensor) == True

target_tensor = tensor.tensor([10., 10., 10., 15., 15., 15.])

assert torch.equal(new_tensor_2, target_tensor) == True

client.shut_down()

if __name__ == '__main__':
Expand Down

0 comments on commit ff9f67e

Please sign in to comment.