From cccde032f42f9351bc7b7cb4e36928c212f5c7ce Mon Sep 17 00:00:00 2001
From: Chao Ma
Date: Mon, 4 Nov 2019 18:27:00 +0800
Subject: [PATCH] [kvstore] Performance improvement for distributed kvstore
 (#972)

* Performance improvement for distributed kvstore

* update

* update
---
 examples/mxnet/dis_kvstore/client.py   |  58 +++--
 examples/pytorch/dis_kvstore/client.py |  57 +++--
 python/dgl/contrib/dis_kvstore.py      | 296 ++++++++-----------------
 tests/compute/test_kvstore.py          |  90 +++-----
 third_party/dmlc-core                  |   2 +-
 5 files changed, 172 insertions(+), 331 deletions(-)

diff --git a/examples/mxnet/dis_kvstore/client.py b/examples/mxnet/dis_kvstore/client.py
index 7be4f2f6eda4..3786329ce92f 100644
--- a/examples/mxnet/dis_kvstore/client.py
+++ b/examples/mxnet/dis_kvstore/client.py
@@ -17,46 +17,42 @@ def start_client(args):
     client.connect()
 
     # Initialize data on server
-    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
-    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
-    client.init_data(name='embed_2', shape=[11], init_type='zero')
+    client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
+    client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
+    client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
+    client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
 
-    tensor_id = mx.nd.array([0, 1, 2], dtype='int64')
-    tensor_data = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
+    data_0 = mx.nd.array([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
+    data_1 = mx.nd.array([0., 1., 2.])
 
     for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, mx.nd.array([2., 2., 2.]))
-
-    tensor_id = mx.nd.array([6, 7, 8], dtype='int64')
-    for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, mx.nd.array([3., 3., 3.]))
+        client.push(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_0)
+        client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0)
+        client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
+        client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1)
 
     client.barrier()
 
     if client.get_id() == 0:
-        tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9], dtype='int64')
-        new_tensor_0 = client.pull('embed_0', tensor_id)
-        tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9,10], dtype='int64')
-        new_tensor_1 = client.pull('embed_1', tensor_id)
-        new_tensor_2 = client.pull('embed_2', tensor_id)
+        client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
+        server_id, new_tensor_0 = client.pull_wait()
+        assert server_id == 0
+        client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
+        server_id, new_tensor_1 = client.pull_wait()
+        assert server_id == 1
+
+        print("embed_0:")
+        print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
 
-        client.push_all('embed_0', new_tensor_0)
-        client.push_all('embed_1', new_tensor_1)
-        client.push_all('embed_2', new_tensor_2)
+        client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
+        server_id, new_tensor_0 = client.pull_wait()
+        assert server_id == 0
+        client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
+        server_id, new_tensor_1 = client.pull_wait()
+        assert server_id == 1
 
-        new_tensor_3 = client.pull_all('embed_0')
-        new_tensor_4 = client.pull_all('embed_1')
-        new_tensor_5 = client.pull_all('embed_2')
-        print("embed_0: ")
-        print(new_tensor_3)
-        print("embed_1: ")
-        print(new_tensor_4)
-        print("embed_2: ")
-        print(new_tensor_5)
+        print("embed_1:")
+        print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
 
     # Shut-down all the servers
     if client.get_id() == 0:
diff --git a/examples/pytorch/dis_kvstore/client.py b/examples/pytorch/dis_kvstore/client.py
index 4eddb70bd032..aa45d472b6be 100644
--- a/examples/pytorch/dis_kvstore/client.py
+++ b/examples/pytorch/dis_kvstore/client.py
@@ -1,7 +1,6 @@
 # This is a simple pytorch client demo shows how to use DGL distributed kvstore.
 # In this demo, we initialize two embeddings on server and push/pull data to/from it.
 import dgl
-import torch
 import time
 import argparse
 import torch as th
@@ -18,46 +17,42 @@ def start_client(args):
     client.connect()
 
     # Initialize data on server
-    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
-    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
-    client.init_data(name='embed_2', shape=[11], init_type='zero')
+    client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
+    client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
+    client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
+    client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
 
-    tensor_id = torch.tensor([0, 1, 2])
-    tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
+    data_0 = th.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
+    data_1 = th.tensor([0., 1., 2.])
 
     for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, th.tensor([2., 2., 2.]))
-
-    tensor_id = torch.tensor([6, 7, 8])
-    for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, th.tensor([3., 3., 3.]))
+        client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
+        client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0)
+        client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
+        client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1)
 
     client.barrier()
 
     if client.get_id() == 0:
-        tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9])
-        new_tensor_0 = client.pull('embed_0', tensor_id)
-        tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9,10])
-        new_tensor_1 = client.pull('embed_1', tensor_id)
-        new_tensor_2 = client.pull('embed_2', tensor_id)
-
-        client.push_all('embed_0', new_tensor_0)
-        client.push_all('embed_1', new_tensor_1)
-        client.push_all('embed_2', new_tensor_2)
+        client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
+        server_id, new_tensor_0 = client.pull_wait()
+        assert server_id == 0
+        client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
+        server_id, new_tensor_1 = client.pull_wait()
+        assert server_id == 1
 
-        new_tensor_3 = client.pull_all('embed_0')
-        new_tensor_4 = client.pull_all('embed_1')
-        new_tensor_5 = client.pull_all('embed_2')
         print("embed_0:")
-        print(new_tensor_3)
+        print(th.cat([new_tensor_0, new_tensor_1]))
+
+        client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
+        server_id, new_tensor_0 = client.pull_wait()
+        assert server_id == 0
+        client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
+        server_id, new_tensor_1 = client.pull_wait()
+        assert server_id == 1
+
         print("embed_1:")
-        print(new_tensor_4)
-        print("embed_2:")
-        print(new_tensor_5)
+        print(th.cat([new_tensor_0, new_tensor_1]))
 
     # Shut-down all the servers
     if client.get_id() == 0:
diff --git a/python/dgl/contrib/dis_kvstore.py b/python/dgl/contrib/dis_kvstore.py
index f558b926c688..eb7c562b5465 100644
--- a/python/dgl/contrib/dis_kvstore.py
+++ b/python/dgl/contrib/dis_kvstore.py
@@ -53,15 +53,10 @@ class KVServer(object):
     """KVServer is a lightweight key-value store service for DGL distributed training.
 
     In practice, developers can use KVServer to hold large-scale graph features or
-    graph embeddings across machines in a distributed setting or storing them in one standalone
-    machine with big memory capability. DGL KVServer uses a very simple range-partition scheme to
-    partition data into different KVServer nodes. For example, if the total embedding size is 200 and
-    we have two KVServer nodes, the data (0~99) will be stored in kvserver_0, and the data (100~199) will
-    be stored in kvserver_1.
+    graph embeddings across machines in a distributed setting. Users can re-write _push_handler
+    and _pull_handler to support flexible models.
 
-    For KVServer, user can re-wriite UDF function for _push_handler and _pull_handler.
-
-    DO NOT use KVServer in multiple threads!
+    Note: DO NOT use KVServer in multiple threads!
 
     Parameters
     ----------
@@ -77,14 +72,14 @@ class KVServer(object):
     server_addr : str
         IP address of current KVServer node, e.g., '127.0.0.1:50051'
     net_type : str
-        networking type, e.g., 'socket' (default) or 'mpi'.
+        networking type, e.g., 'socket' (default) or 'mpi' (not supported yet).
     """
     def __init__(self, server_id, client_namebook, server_addr, net_type='socket'):
-        assert server_id >= 0, 'server_id cannot be a negative number.'
+        assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
         assert len(client_namebook) > 0, 'client_namebook cannot be empty.'
-        assert len(server_addr.split(':')) == 2, 'Incorrect IP format.'
+        assert len(server_addr.split(':')) == 2, 'Incorrect IP format: %s' % server_addr
         self._is_init = set() # Contains tensor name
-        self._data_store = {} # Key is name string and value is tensor
+        self._data_store = {} # Key is name (string) and value is data (tensor)
         self._barrier_count = 0;
         self._server_id = server_id
         self._client_namebook = client_namebook
@@ -130,13 +125,9 @@ def start(self):
                         high=row_1[1])
                     self._is_init.add(msg.name)
                 elif msg.type == KVMsgType.PUSH:
-                    # convert global ID to local ID
-                    local_id = self._remap_id(msg.name, msg.id)
-                    self._push_handler(msg.name, local_id, msg.data)
+                    self._push_handler(msg.name, msg.id, msg.data)
                 elif msg.type == KVMsgType.PULL:
-                    # convert global ID to local ID
-                    local_id = self._remap_id(msg.name, msg.id)
-                    res_tensor = self._pull_handler(msg.name, local_id)
+                    res_tensor = self._pull_handler(msg.name, msg.id)
                     back_msg = KVStoreMsg(
                         type=KVMsgType.PULL_BACK,
                         rank=self._server_id,
@@ -157,7 +148,7 @@ def start(self):
                         _send_kv_msg(self._sender, back_msg, i)
                     self._barrier_count = 0
                 elif msg.type == KVMsgType.FINAL:
-                    print("Exit KVStore service, server ID: %d" % self.get_id())
+                    print("Exit KVStore service, server ID: %d" % self._server_id)
                     break # exit loop
                 else:
                     raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
@@ -204,7 +195,7 @@ def _init_data(self, name, shape, init_type, low, high):
             raise RuntimeError('Unknown initial method')
 
     def _push_handler(self, name, ID, data):
-        """User-defined handler for PUSH message.
+        """Default handler for PUSH message.
 
         On default, _push_handler perform ADD operation for the tensor.
 
@@ -213,15 +204,15 @@ def _push_handler(self, name, ID, data):
         name : str
             data name
         ID : tensor (mx.ndarray or torch.tensor)
-            a vector storing the IDs that has been re-mapped to local id.
+            a vector storing the ID list.
         data : tensor (mx.ndarray or torch.tensor)
-            a matrix with the same row size of id
+            a tensor with the same row size as ID
         """
-        for idx in range(ID.shape[0]): # For each row
+        for idx in range(ID.shape[0]):
            self._data_store[name][ID[idx]] += data[idx]
 
     def _pull_handler(self, name, ID):
-        """User-defined handler for PULL operation.
+        """Default handler for PULL operation.
 
         On default, _pull_handler perform gather_row() operation for the tensor.
 
@@ -235,40 +226,26 @@ def _pull_handler(self, name, ID):
         Return
         ------
         tensor
-            a matrix with the same row size of ID
+            a tensor with the same row size as ID
         """
         new_tensor = F.gather_row(self._data_store[name], ID)
         return new_tensor
 
-    def _remap_id(self, name, ID):
-        """Re-mapping global-ID to local-ID.
-
-        Parameters
-        ----------
-        name : str
-            data name
-        ID : tensor (mx.ndarray or torch.tensor)
-            a vector storing the global data ID
-
-        Return
-        ------
-        tensor
-            re-mapped lcoal ID
-        """
-        row_size = self._data_store[name].shape[0]
-        return ID % row_size
-
 class KVClient(object):
     """KVClient is used to push/pull tensors to/from KVServer on DGL trainer.
 
     There are five operations supported by KVClient:
 
-    * init_data(name, shape, init_type, low, high): initialize tensor on KVServer
-    * push(name, id, data): push sparse data to KVServer given specified IDs
-    * pull(name, id): pull sparse data from KVServer given specified IDs
-    * push_all(name, data): push dense data to KVServer
-    * pull_all(name): pull sense data from KVServer
-    * shut_down(): shut down all KVServer nodes
+    * init_data(name, server_id, shape, init_type, low, high):
+        initialize tensor on the target KVServer.
+    * push(name, server_id, id_tensor, data_tensor):
+        push sparse data to the target KVServer given the specified IDs.
+    * pull(name, server_id, id_tensor):
+        pull sparse data from the target KVServer given the specified IDs.
+    * pull_wait():
+        wait for a scheduled pull operation to finish.
+    * shut_down():
+        shut down all KVServer nodes.
 
     Note that, DO NOT use KVClient in multiple threads!
 
@@ -292,10 +269,6 @@ def __init__(self, client_id, server_namebook, client_addr, net_type='socket'):
         assert client_id >= 0, 'client_id (%d) cannot be a nagative number.' % client_id
         assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
         assert len(client_addr.split(':')) == 2, 'Incorrect IP format: %s' % client_addr
-        # self._data_size is a key-value store where the key is data name
-        # and value is the size of tensor. It is used to partition data into
-        # different KVServer nodes.
-        self._data_size = {}
         self._client_id = client_id
         self._server_namebook = server_namebook
         self._server_count = len(server_namebook)
@@ -319,13 +292,20 @@ def connect(self):
         client_ip, client_port = self._addr.split(':')
         _receiver_wait(self._receiver, client_ip, int(client_port), self._server_count)
 
-    def init_data(self, name, shape, init_type='zero', low=0.0, high=0.0):
+    def init_data(self, name, server_id, shape, init_type='zero', low=0.0, high=0.0):
         """Initialize kvstore tensor
 
+        We hack the msg format here: msg.id stores the shape of the target tensor,
+        and msg.data has two rows. The first row is the init_type, where
+        [0, 0] means 'zero' and [1, 1] means 'uniform'.
+        The second row stores the min & max thresholds.
+
         Parameters
         ----------
         name : str
             data name
+        server_id : int
+            target server id
         shape : list of int
             shape of tensor
         init_type : str
@@ -335,158 +315,86 @@ def init_data(self, name, shape, init_type='zero', low=0.0, high=0.0):
         high : float
             max threshold, if use 'uniform'
         """
-        self._data_size[name] = shape[0]
-        count = math.ceil(shape[0] / self._server_count)
-        # We hack the msg format here
+        tensor_shape = F.tensor(shape)
         init_type = 0.0 if init_type == 'zero' else 1.0
         threshold = F.tensor([[init_type, init_type], [low, high]])
-        # partition shape on server
-        for server_id in range(self._server_count):
-            par_shape = shape.copy()
-            if shape[0] - server_id*count >= count:
-                par_shape[0] = count
-            else:
-                par_shape[0] = shape[0] - server_id*count
-            tensor_shape = F.tensor(par_shape)
-            msg = KVStoreMsg(
-                type=KVMsgType.INIT,
-                rank=self._client_id,
-                name=name,
-                id=tensor_shape,
-                data=threshold)
-            _send_kv_msg(self._sender, msg, server_id)
+        msg = KVStoreMsg(
+            type=KVMsgType.INIT,
+            rank=self._client_id,
+            name=name,
+            id=tensor_shape,
+            data=threshold)
+        _send_kv_msg(self._sender, msg, server_id)
 
-    def push(self, name, ID, data):
-        """Push sparse message to KVServer
+    def push(self, name, server_id, id_tensor, data_tensor):
+        """Push sparse message to the target KVServer.
 
-        The push() API will partition message into different
-        KVServer nodes automatically.
-
-        Note that we assume the row Ids in ID is in the ascending order.
+        Note that push() is an async operation that returns immediately after the call.
 
         Parameters
         ----------
         name : str
             data name
-        ID : tensor (mx.ndarray or torch.tensor)
-            a vector storing the global IDs
-        data : tensor (mx.ndarray or torch.tensor)
+        server_id : int
+            target server id
+        id_tensor : tensor (mx.ndarray or torch.tensor)
+            a vector storing the ID list
+        data_tensor : tensor (mx.ndarray or torch.tensor)
             a tensor with the same row size of id
         """
-        assert F.ndim(ID) == 1, 'ID must be a vector.'
-        assert F.shape(ID)[0] == F.shape(data)[0], 'The data must has the same row size with ID.'
-        group_size = [0] * self._server_count
-        numpy_id = F.asnumpy(ID)
-        count = math.ceil(self._data_size[name] / self._server_count)
-        server_id = numpy_id / count
-        id_list, id_count = np.unique(server_id, return_counts=True)
-        for idx in range(len(id_list)):
-            group_size[int(id_list[idx])] += id_count[idx]
-        min_idx = 0
-        max_idx = 0
-        for idx in range(self._server_count):
-            if group_size[idx] == 0:
-                continue
-            max_idx += group_size[idx]
-            range_id = ID[min_idx:max_idx]
-            range_data = data[min_idx:max_idx]
-            min_idx = max_idx
-            msg = KVStoreMsg(
-                type=KVMsgType.PUSH,
-                rank=self._client_id,
-                name=name,
-                id=range_id,
-                data=range_data)
-            _send_kv_msg(self._sender, msg, idx)
-
-    def push_all(self, name, data):
-        """Push the whole data to KVServer
-
-        The push_all() API will partition message into different
-        KVServer nodes automatically.
-
-        Note that we assume the row Ids in ID is in the ascending order.
-
-        Parameters
-        ----------
-        name : str
-            data name
-        data : tensor (mx.ndarray or torch.tensor)
-            data tensor
-        """
-        ID = F.zerocopy_from_numpy(np.arange(F.shape(data)[0]))
-        self.push(name, ID, data)
+        assert server_id >= 0, 'server_id (%d) cannot be a negative number' % server_id
+        assert server_id < self._server_count, 'server_id (%d) must be smaller than server_count' % server_id
+        assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
+        assert F.shape(id_tensor)[0] == F.shape(data_tensor)[0], 'The data must have the same row size as ID.'
+        msg = KVStoreMsg(
+            type=KVMsgType.PUSH,
+            rank=self._client_id,
+            name=name,
+            id=id_tensor,
+            data=data_tensor)
+        _send_kv_msg(self._sender, msg, server_id)
 
-    def pull(self, name, ID):
+    def pull(self, name, server_id, id_tensor):
         """Pull sparse message from KVServer
 
-        Note that we assume the row Ids in ID is in the ascending order.
+        Note that pull() is an async operation that returns immediately after the call.
+        Users can use pull_wait() to get the data pulled from the KVServer. The order
+        of the data received from the same server is deterministic.
 
         Parameters
         ----------
         name : str
             data name
-        ID : tensor (mx.ndarray or torch.tensor)
-            a vector storing the IDs
+        server_id : int
+            target server id
+        id_tensor : tensor (mx.ndarray or torch.tensor)
+            a vector storing the ID list
 
-        Return
-        ------
-        tensor
-            a tensor with the same row size of ID
         """
-        assert F.ndim(ID) == 1, 'ID must be a vector.'
-        group_size = [0] * self._server_count
-        numpy_id = F.asnumpy(ID)
-        count = math.ceil(self._data_size[name] / self._server_count)
-        server_id = numpy_id / count
-        id_list, id_count = np.unique(server_id, return_counts=True)
-        for idx in range(len(id_list)):
-            group_size[int(id_list[idx])] += id_count[idx]
-        min_idx = 0
-        max_idx = 0
-        server_count = 0
-        for idx in range(self._server_count):
-            if group_size[idx] == 0:
-                continue
-            server_count += 1
-            max_idx += group_size[idx]
-            range_id = ID[min_idx:max_idx]
-            min_idx = max_idx
-            msg = KVStoreMsg(
-                type=KVMsgType.PULL,
-                rank=self._client_id,
-                name=name,
-                id=range_id,
-                data=None)
-            _send_kv_msg(self._sender, msg, idx)
-        # Recv back message
-        msg_list = []
-        for idx in range(self._server_count):
-            if group_size[idx] == 0:
-                continue
-            msg = _recv_kv_msg(self._receiver)
-            assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
-            msg_list.append(msg)
-
-        return self._merge_msg(msg_list)
+        assert server_id >= 0, 'server_id (%d) cannot be a negative number' % server_id
+        assert server_id < self._server_count, 'server_id (%d) must be smaller than server_count' % server_id
+        assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
+        msg = KVStoreMsg(
+            type=KVMsgType.PULL,
+            rank=self._client_id,
+            name=name,
+            id=id_tensor,
+            data=None)
+        _send_kv_msg(self._sender, msg, server_id)
 
-    def pull_all(self, name):
-        """Pull the whole data from KVServer
+    def pull_wait(self):
+        """Wait for a pull() request to finish its job.
 
-        Note that we assume the row Ids in ID is in the ascending order.
-
-        Parameters
-        ----------
-        name : str
-            data name
-
-        Return
-        ------
-        tensor
+        Returns
+        -------
+        server_id : int
+            ID of the KVServer that sent back the data
+        data : tensor
             target data tensor
         """
-        ID = F.zerocopy_from_numpy(np.arange(self._data_size[name]))
-        return self.pull(name, ID)
+        msg = _recv_kv_msg(self._receiver)
+        assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
+        return msg.rank, msg.data
 
     def barrier(self):
         """Barrier for all client nodes
@@ -528,29 +436,3 @@ def get_id(self):
             KVClient ID
         """
         return self._client_id
-
-    def _sort_func(self, msg):
-        """Sort function for KVStoreMsg: sort message by rank
-
-        Parameters
-        ----------
-        msg : KVStoreMsg
-            KVstore message
-        """
-        return msg.rank
-
-    def _merge_msg(self, msg_list):
-        """Merge separated message to a big matrix
-
-        Parameters
-        ----------
-        msg_list : list
-            a list of KVStoreMsg
-
-        Return
-        ------
-        tensor (mx.ndarray or torch.tensor)
-            a merged data matrix
-        """
-        msg_list.sort(key=self._sort_func)
-        return F.cat([msg.data for msg in msg_list], 0)
\ No newline at end of file
diff --git a/tests/compute/test_kvstore.py b/tests/compute/test_kvstore.py
index 345d08162465..4325e1e6ce5e 100644
--- a/tests/compute/test_kvstore.py
+++ b/tests/compute/test_kvstore.py
@@ -2,7 +2,7 @@
 import numpy as np
 import scipy as sp
 import dgl
-import torch
+import torch as th
 from dgl import utils
 
 import os
@@ -28,70 +28,38 @@ def start_client():
 
     client.connect()
 
-    client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
-    client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
-    client.init_data(name='embed_2', shape=[11], init_type='zero')
+    # Initialize data on server
+    client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
+    client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
 
-    tensor_id = torch.tensor([0, 1, 2])
-    tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
+    data_0 = th.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
+    data_1 = th.tensor([0., 1., 2.])
 
-    # Push
     for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, torch.tensor([2., 2., 2.]))
+        client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
+        client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
 
-    tensor_id = torch.tensor([6, 7, 8])
-    for i in range(5):
-        client.push('embed_0', tensor_id, tensor_data)
-        client.push('embed_1', tensor_id, tensor_data)
-        client.push('embed_2', tensor_id, torch.tensor([3., 3., 3.]))
-
-    # Pull
-    tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
-    new_tensor_0 = client.pull('embed_0', tensor_id)
-    new_tensor_1 = client.pull('embed_1', tensor_id)
-    new_tensor_2 = client.pull('embed_2', tensor_id)
-
-    target_tensor = torch.tensor(
-        [[ 0., 0., 0.],
-         [ 5., 5., 5.],
-         [10., 10., 10.],
-         [ 0., 0., 0.],
-         [ 5., 5., 5.],
-         [10., 10., 10.]])
-
-    assert torch.equal(new_tensor_0, target_tensor) == True
-    assert torch.equal(new_tensor_1, target_tensor) == True
-
-    target_tensor = tensor.tensor([10., 10., 10., 15., 15., 15.])
-
-    assert torch.equal(new_tensor_2, target_tensor) == True
-
-    client.push_all('embed_0', client.pull_all('embed_0'))
-    client.push_all('embed_1', client.pull_all('embed_1'))
-    client.push_all('embed_2', client.pull_all('embed_2'))
-
-    # Pull
-    tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
-    new_tensor_0 = client.pull('embed_0', tensor_id)
-    new_tensor_1 = client.pull('embed_1', tensor_id)
-    new_tensor_2 = client.pull('embed_2', tensor_id)
-
-    target_tensor = torch.tensor(
-        [[ 0., 0., 0.],
-         [ 10., 10., 10.],
-         [20., 20., 20.],
-         [ 0., 0., 0.],
-         [ 10., 10., 10.],
-         [20., 20., 20.]])
-
-    assert torch.equal(new_tensor_0, target_tensor) == True
-    assert torch.equal(new_tensor_1, target_tensor) == True
-
-    target_tensor = tensor.tensor([20., 20., 20., 30., 30., 30.])
-
-    assert torch.equal(new_tensor_2, target_tensor) == True
+    client.barrier()
+
+    client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
+    server_id, new_tensor = client.pull_wait()
+    assert server_id == 0
+
+    target_tensor = th.tensor(
+        [[ 0.,  0.,  0.],
+         [ 0.,  0.,  0.],
+         [ 5.,  5.,  5.],
+         [ 0.,  0.,  0.],
+         [10., 10., 10.]])
+
+    assert th.equal(new_tensor, target_tensor)
+
+    client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
+    server_id, new_tensor = client.pull_wait()
+
+    target_tensor = th.tensor([0., 0., 5., 0., 10.])
+
+    assert th.equal(new_tensor, target_tensor)
 
     client.shut_down()
diff --git a/third_party/dmlc-core b/third_party/dmlc-core
index 0f3ddbc7240e..7ce90a342b0b 160000
--- a/third_party/dmlc-core
+++ b/third_party/dmlc-core
@@ -1 +1 @@
-Subproject commit 0f3ddbc7240efa05bfffd5bca808ec262ce3630e
+Subproject commit 7ce90a342b0bda9b7f88e707a326496324d60efd
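
Reviewer note: the key change is that ID partitioning moves out of the kvstore. The old
push()/pull() grouped IDs with np.unique(), remapped them with _remap_id() on the server,
and merged replies with _merge_msg(); after this patch every push()/pull() is a single
message to one explicitly chosen server, and pull() is split into an async request plus a
blocking pull_wait(). The sketch below illustrates the intended calling pattern; it is not
part of the patch, and the helper pull_from_all() and its names are hypothetical:

    def pull_from_all(client, name, ids_per_server):
        """Pull rows of `name` from several servers concurrently.

        ids_per_server maps server_id -> 1-D id tensor. At most one pull is
        outstanding per server, so replies can be keyed by server id.
        """
        # Fire off every request first: pull() only sends a message and
        # returns immediately, so all servers can work in parallel.
        for server_id, id_tensor in ids_per_server.items():
            client.pull(name=name, server_id=server_id, id_tensor=id_tensor)
        # Then block on the replies. Replies from different servers may
        # arrive in any order, so key the results by the returned server id.
        result = {}
        for _ in range(len(ids_per_server)):
            server_id, data = client.pull_wait()
            result[server_id] = data
        return result

    # e.g., with `import torch as th`, even rows on server 0 and odd rows on
    # server 1 as in the examples above:
    # parts = pull_from_all(client, 'embed_0',
    #                       {0: th.tensor([0, 2, 4]), 1: th.tensor([1, 3, 5])})

Issuing every pull() before the first pull_wait() lets the servers overlap their work; the
old synchronous pull() also paid the numpy grouping and message-merging cost on every call,
which this patch removes.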