Skip to content

Commit

Permalink
[KVStore] Remove Freeze flag (dmlc#1605)
Browse files Browse the repository at this point in the history
* remove freeze

* update

* update

* fix lint
  • Loading branch information
aksnzhy authored Jun 9, 2020
1 parent cbe4c28 commit 8eab08d
Showing 1 changed file with 27 additions and 48 deletions.
75 changes: 27 additions & 48 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Define distributed kvstore"""

import os
import time
import random
import numpy as np

Expand Down Expand Up @@ -356,8 +355,6 @@ def process_request(self, server_state):
kv_store.part_policy[name].policy_str)
if len(meta) == 0:
raise RuntimeError('There is no data on kvserver.')
# Freeze data init
kv_store.freeze = True
res = GetSharedDataResponse(meta)
return res

Expand Down Expand Up @@ -451,10 +448,11 @@ def __setstate__(self, state):
def process_request(self, server_state):
kv_store = server_state.kv_store
assert kv_store.is_backup_server()
shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype)
dlpack = shared_data.to_dlpack()
kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
if self.name not in kv_store.data_store:
shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype)
dlpack = shared_data.to_dlpack()
kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res

Expand Down Expand Up @@ -570,8 +568,6 @@ def __init__(self, server_id, ip_config, num_clients):
# push and pull handler
self._push_handler = default_push_handler
self._pull_handler = default_pull_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False

@property
def server_id(self):
Expand All @@ -588,16 +584,6 @@ def barrier_count(self, count):
"""Set barrier count"""
self._barrier_count = count

@property
def freeze(self):
"""Get freeze"""
return self._freeze

@freeze.setter
def freeze(self, freeze):
"""Set freeze"""
self._freeze = freeze

@property
def num_clients(self):
"""Get number of clients"""
Expand Down Expand Up @@ -669,9 +655,6 @@ def init_data(self, name, policy_str, data_tensor=None):
read shared-memory when client invoking get_shared_data().
"""
assert len(name) > 0, 'name cannot be empty.'
if self._freeze:
raise RuntimeError("KVServer cannot create new data \
after client invoking get_shared_data() API.")
if self._data_store.__contains__(name):
raise RuntimeError("Data %s has already exists!" % name)
if data_tensor is not None: # Create shared-tensor
Expand Down Expand Up @@ -764,9 +747,6 @@ def __init__(self, ip_config):
# push and pull handler
self._pull_handler = default_pull_handler
self._push_handler = default_push_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False
random.seed(time.time())

@property
def client_id(self):
Expand Down Expand Up @@ -858,9 +838,7 @@ def init_data(self, name, shape, dtype, policy_str, partition_book, init_func):
assert len(name) > 0, 'name cannot be empty.'
assert len(shape) > 0, 'shape cannot be empty'
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
if self._freeze:
raise RuntimeError("KVClient cannot create new \
data after invoking get_shared_data() API.")
assert name not in self._data_name_list, 'data name: %s already exists.' % name
shape = list(shape)
if self._client_id == 0:
for machine_id in range(self._machine_count):
Expand Down Expand Up @@ -920,27 +898,28 @@ def map_shared_data(self, partition_book):
rpc.send_request(self._main_server_id, request)
response = rpc.recv_response()
for name, meta in response.meta.items():
shape, dtype, policy_str = meta
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
self._data_name_list.add(name)
if name not in self._data_name_list:
shape, dtype, policy_str = meta
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
# Get full data shape across servers
for name, meta in response.meta.items():
shape, _, _ = meta
data_shape = list(shape)
data_shape[0] = 0
request = GetPartShapeRequest(name)
# send request to all main server nodes
for machine_id in range(self._machine_count):
server_id = machine_id * self._group_count
rpc.send_request(server_id, request)
# recv response from all the main server nodes
for _ in range(self._machine_count):
res = rpc.recv_response()
data_shape[0] += res.shape[0]
self._full_data_shape[name] = tuple(data_shape)
if name not in self._data_name_list:
shape, _, _ = meta
data_shape = list(shape)
data_shape[0] = 0
request = GetPartShapeRequest(name)
# send request to all main server nodes
for machine_id in range(self._machine_count):
server_id = machine_id * self._group_count
rpc.send_request(server_id, request)
# recv response from all the main server nodes
for _ in range(self._machine_count):
res = rpc.recv_response()
data_shape[0] += res.shape[0]
self._full_data_shape[name] = tuple(data_shape)
# Send meta data to backup servers
for name, meta in response.meta.items():
shape, dtype, policy_str = meta
Expand All @@ -953,7 +932,7 @@ def map_shared_data(self, partition_book):
for _ in range(self._group_count-1):
response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG
self._freeze = True
self._data_name_list.add(name)

def data_name_list(self):
"""Get all the data name"""
Expand Down

0 comments on commit 8eab08d

Please sign in to comment.