Skip to content

Commit

Permalink
[Feature] long live server for multiple client groups (dmlc#3645)
Browse files Browse the repository at this point in the history
* [Feature] long live server for multiple client groups

* generate globally unique name for DistTensor within DGL automatically
  • Loading branch information
Rhett-Ying authored Jan 26, 2022
1 parent 2b98e76 commit 02e4cd8
Show file tree
Hide file tree
Showing 18 changed files with 541 additions and 165 deletions.
2 changes: 1 addition & 1 deletion python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server
from .rpc_client import connect_to_server, shutdown_servers
from .dist_context import initialize, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
Expand Down
3 changes: 3 additions & 0 deletions python/dgl/distributed/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@

# Maximum size of message queue in bytes
MAX_QUEUE_SIZE = 20*1024*1024*1024

SERVER_EXIT = "server_exit"
SERVER_KEEP_ALIVE = "server_keep_alive"
30 changes: 22 additions & 8 deletions python/dgl/distributed/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import os
import sys
import queue
import gc
from enum import Enum

from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
from .rpc_client import connect_to_server
from .role import init_role
from .. import utils

Expand All @@ -33,13 +34,13 @@ def get_sampler_pool():
return SAMPLER_POOL, NUM_SAMPLER_WORKERS


def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads):
def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_threads, group_id):
''' This init function is called in the worker processes.
'''
try:
utils.set_num_threads(num_threads)
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id)
init_role(role)
init_kvstore(ip_config, num_servers, role)
except Exception as e:
Expand Down Expand Up @@ -227,12 +228,14 @@ def initialize(ip_config, num_servers=1, num_workers=0,
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
rpc.reset()
keep_alive = os.environ.get('DGL_KEEP_ALIVE') is not None
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'),
graph_format=formats)
graph_format=formats,
keep_alive=keep_alive)
serv.start()
sys.exit()
else:
Expand All @@ -244,22 +247,23 @@ def initialize(ip_config, num_servers=1, num_workers=0,
num_servers = int(os.environ.get('DGL_NUM_SERVER'))
else:
num_servers = 1

group_id = int(os.environ.get('DGL_GROUP_ID', 0))
rpc.reset()
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
is_standalone = os.environ.get(
'DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone:
SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads))
net_type, 'sampler', num_worker_threads,
group_id))
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
connect_to_server(ip_config, num_servers, max_queue_size, net_type, group_id=group_id)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')

Expand Down Expand Up @@ -299,6 +303,14 @@ def is_initialized():
return INITIALIZED


def _shutdown_servers():
set_initialized(False)
# send ShutDownRequest to servers
if rpc.get_rank() == 0: # Only client_0 issue this command
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)

def exit_client():
"""Trainer exits
Expand All @@ -311,9 +323,11 @@ def exit_client():
"""
# Only client with rank_0 will send shutdown request to servers.
finalize_worker() # finalize workers should be earilier than barrier, and non-blocking
# collect data such as DistTensor before exit
gc.collect()
if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
rpc.client_barrier()
shutdown_servers()
_shutdown_servers()
finalize_client()
join_finalize_worker()
close_kvstore()
Expand Down
17 changes: 12 additions & 5 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(self, g, ntype=None):
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy)
part_policy=policy, attach=False)

def _get_names(self):
return list(self._data.keys())
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(self, g, etype=None):
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy)
part_policy=policy, attach=False)

def _get_names(self):
return list(self._data.keys())
Expand Down Expand Up @@ -308,16 +308,19 @@ class DistGraphServer(KVServer):
Disable shared memory.
graph_format : str or list of str
The graph formats.
keep_alive : bool
Whether to keep server alive when clients exit
'''
def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False,
graph_format=('csc', 'coo')):
graph_format=('csc', 'coo'), keep_alive=False):
super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config,
num_servers=num_servers,
num_clients=num_clients)
self.ip_config = ip_config
self.num_servers = num_servers
self.keep_alive = keep_alive
# Load graph partition data.
if self.is_backup_server():
# The backup server doesn't load the graph partition. It'll initialized afterwards.
Expand Down Expand Up @@ -351,20 +354,24 @@ def __init__(self, server_id, ip_config, num_servers,
data_name = HeteroDataName(True, ntype, feat_name)
self.init_data(name=str(data_name), policy_str=data_name.policy_str,
data_tensor=node_feats[name])
self.orig_data.add(str(data_name))
for name in edge_feats:
# The feature name has the following format: edge_type + "/" + feature_name to avoid
# feature name collision for different edge types.
etype, feat_name = name.split('/')
data_name = HeteroDataName(False, etype, feat_name)
self.init_data(name=str(data_name), policy_str=data_name.policy_str,
data_tensor=edge_feats[name])
self.orig_data.add(str(data_name))

def start(self):
""" Start graph store server.
"""
# start server
server_state = ServerState(kv_store=self, local_g=self.client_g, partition_book=self.gpb)
print('start graph service on server {} for part {}'.format(self.server_id, self.part_id))
server_state = ServerState(kv_store=self, local_g=self.client_g,
partition_book=self.gpb, keep_alive=self.keep_alive)
print('start graph service on server {} for part {}'.format(
self.server_id, self.part_id))
start_server(server_id=self.server_id,
ip_config=self.ip_config,
num_servers=self.num_servers,
Expand Down
40 changes: 35 additions & 5 deletions python/dgl/distributed/dist_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .role import get_role
from .. import utils
from .. import backend as F
from .rpc import get_group_id

def _default_init_data(shape, dtype):
return F.zeros(shape, dtype, F.cpu())
Expand Down Expand Up @@ -80,6 +81,8 @@ def init_func(shape, dtype):
Whether the created tensor lives after the ``DistTensor`` object is destroyed.
is_gdata : bool
Whether the created tensor is a ndata/edata or not.
attach : bool
Whether to attach group ID into name to be globally unique.
Examples
--------
Expand All @@ -102,12 +105,13 @@ def init_func(shape, dtype):
do the same.
'''
def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
persistent=False, is_gdata=True):
persistent=False, is_gdata=True, attach=True):
self.kvstore = get_kvstore()
assert self.kvstore is not None, \
'Distributed module is not initialized. Please call dgl.distributed.initialize.'
self._shape = shape
self._dtype = dtype
self._attach = attach

part_policies = self.kvstore.all_possible_part_policy
# If a user doesn't provide a partition policy, we should find one based on
Expand All @@ -128,7 +132,6 @@ def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
+ 'its first dimension does not match the number of nodes or edges ' \
+ 'of a distributed graph or there does not exist a distributed graph.'

self._tensor_name = name
self._part_policy = part_policy
assert part_policy.get_size() == shape[0], \
'The partition policy does not match the input shape.'
Expand All @@ -146,6 +149,8 @@ def __init__(self, shape, dtype, name=None, init_func=None, part_policy=None,
name = 'anonymous-' + get_role() + '-' + str(DIST_TENSOR_ID)
DIST_TENSOR_ID += 1
assert isinstance(name, str), 'name {} is type {}'.format(name, type(name))
name = self._attach_group_id(name)
self._tensor_name = name
data_name = part_policy.get_data_name(name)
self._name = str(data_name)
self._persistent = persistent
Expand Down Expand Up @@ -220,7 +225,7 @@ def name(self):
str
The name of the tensor.
'''
return self._name
return self._detach_group_id(self._name)

@property
def tensor_name(self):
Expand All @@ -231,7 +236,7 @@ def tensor_name(self):
str
The name of the tensor.
'''
return self._tensor_name
return self._detach_group_id(self._tensor_name)

def count_nonzero(self):
'''Count and return the number of nonzero value
Expand All @@ -241,4 +246,29 @@ def count_nonzero(self):
int
the number of nonzero value
'''
return self.kvstore.count_nonzero(name=self.name)
return self.kvstore.count_nonzero(name=self._name)

def _attach_group_id(self, name):
"""Attach group ID if needed
Returns
-------
str
new name with group ID attached
"""
if not self._attach:
return name
return "{}_{}".format(name, get_group_id())

def _detach_group_id(self, name):
"""Detach group ID if needed
Returns
-------
str
original name without group ID
"""
if not self._attach:
return name
suffix = "_{}".format(get_group_id())
return name[:-len(suffix)]
27 changes: 20 additions & 7 deletions python/dgl/distributed/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,20 +206,23 @@ class BarrierRequest(rpc.Request):
"""
def __init__(self, role):
self.role = role
self.group_id = rpc.get_group_id()

def __getstate__(self):
return self.role
return self.role, self.group_id

def __setstate__(self, state):
self.role = state
self.role, self.group_id = state

def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
count = kv_store.barrier_count[self.role]
kv_store.barrier_count[self.role] = count + 1
if kv_store.barrier_count[self.role] == len(role[self.role]):
kv_store.barrier_count[self.role] = 0
roles = server_state.roles
role = roles[self.group_id]
barrier_count = kv_store.barrier_count[self.group_id]
count = barrier_count[self.role]
barrier_count[self.role] = count + 1
if barrier_count[self.role] == len(role[self.role]):
barrier_count[self.role] = 0
res_list = []
for client_id, _ in role[self.role]:
res_list.append((client_id, BarrierResponse(BARRIER_MSG)))
Expand Down Expand Up @@ -362,6 +365,9 @@ def process_request(self, server_state):
meta = {}
kv_store = server_state.kv_store
for name, data in kv_store.data_store.items():
if server_state.keep_alive:
if name not in kv_store.orig_data:
continue
meta[name] = (F.shape(data),
F.reverse_data_type_dict[F.dtype(data)],
kv_store.part_policy[name].policy_str)
Expand Down Expand Up @@ -671,6 +677,8 @@ def __init__(self, server_id, ip_config, num_servers, num_clients):
CountLocalNonzeroResponse)
# Store the tensor data with specified data name
self._data_store = {}
# Store original tensor data names when instantiating DistGraphServer
self._orig_data = set()
# Store the partition information with specified data name
self._policy_set = set()
self._part_policy = {}
Expand Down Expand Up @@ -715,6 +723,11 @@ def data_store(self):
"""Get data store"""
return self._data_store

@property
def orig_data(self):
"""Get original data"""
return self._orig_data

@property
def part_policy(self):
"""Get part policy"""
Expand Down
17 changes: 10 additions & 7 deletions python/dgl/distributed/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,22 @@ def __init__(self, client_id, machine_id, role):
self.client_id = client_id
self.machine_id = machine_id
self.role = role
self.group_id = rpc.get_group_id()

def __getstate__(self):
return self.client_id, self.machine_id, self.role
return self.client_id, self.machine_id, self.role, self.group_id

def __setstate__(self, state):
self.client_id, self.machine_id, self.role = state
self.client_id, self.machine_id, self.role, self.group_id = state

def process_request(self, server_state):
kv_store = server_state.kv_store
role = server_state.roles
role = server_state.roles.setdefault(self.group_id, {})
if self.role not in role:
role[self.role] = set()
if kv_store is not None:
kv_store.barrier_count[self.role] = 0
barrier_count = kv_store.barrier_count.setdefault(self.group_id, {})
barrier_count[self.role] = 0
role[self.role].add((self.client_id, self.machine_id))
total_count = 0
for key in role:
Expand Down Expand Up @@ -84,15 +86,16 @@ class GetRoleRequest(rpc.Request):
"""Send a request to get the roles of all client processes."""
def __init__(self):
self.msg = GET_ROLE_MSG
self.group_id = rpc.get_group_id()

def __getstate__(self):
return self.msg
return self.msg, self.group_id

def __setstate__(self, state):
self.msg = state
self.msg, self.group_id = state

def process_request(self, server_state):
return GetRoleResponse(server_state.roles)
return GetRoleResponse(server_state.roles[self.group_id])

# The key is role, the value is a dict of mapping RPC rank to a rank within the role.
PER_ROLE_RANK = {}
Expand Down
Loading

0 comments on commit 02e4cd8

Please sign in to comment.