Skip to content

Commit

Permalink
[Distributed] Move server start code to initialize. (dmlc#2002)
Browse files Browse the repository at this point in the history
* move server start code to initialize.

* fix.

* fix lint.

* fix examples.

* add more checks.
  • Loading branch information
zheng-da authored Aug 11, 2020
1 parent af99098 commit 4efa332
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 40 deletions.
3 changes: 1 addition & 2 deletions examples/pytorch/graphsage/experimental/train_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,9 @@ def run(args, device, data):
print(profiler.output_text(unicode=True, color=True))

def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')

dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def run(args, device, data):
th.save(pred, 'emb.pt')

def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
print('number of edges', g.number_of_edges())
Expand Down Expand Up @@ -452,7 +452,7 @@ def main(args):
parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-servers', type=int, help='Server count on each machine.')
parser.add_argument('--num-servers', type=int, default=1, help='Server count on each machine.')
parser.add_argument('--n-classes', type=int, help='the number of classes')
parser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
Expand Down
19 changes: 0 additions & 19 deletions python/dgl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,3 @@
from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import sample_neighbors, in_subgraph, find_edges

if os.environ.get('DGL_ROLE', 'client') == 'server':
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
SERV = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
SERV.start()
sys.exit()
53 changes: 38 additions & 15 deletions python/dgl/distributed/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import atexit
import time
import os
import sys

from . import rpc
from .constants import MAX_QUEUE_SIZE
Expand Down Expand Up @@ -63,22 +64,44 @@ def initialize(ip_config, num_servers=1, num_workers=0,
num_worker_threads: int
The number of threads in a worker process.
"""
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone:
SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
initargs=(ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads))
if os.environ.get('DGL_ROLE', 'client') == 'server':
from .dist_graph import DistGraphServer
assert os.environ.get('DGL_SERVER_ID') is not None, \
'Please define DGL_SERVER_ID to run DistGraph server'
assert os.environ.get('DGL_IP_CONFIG') is not None, \
'Please define DGL_IP_CONFIG to run DistGraph server'
assert os.environ.get('DGL_NUM_SERVER') is not None, \
'Please define DGL_NUM_SERVER to run DistGraph server'
assert os.environ.get('DGL_NUM_CLIENT') is not None, \
'Please define DGL_NUM_CLIENT to run DistGraph server'
assert os.environ.get('DGL_CONF_PATH') is not None, \
'Please define DGL_CONF_PATH to run DistGraph server'
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
int(os.environ.get('DGL_NUM_CLIENT')),
os.environ.get('DGL_CONF_PATH'))
serv.start()
sys.exit()
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
if num_workers > 0 and not is_standalone:
SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
initargs=(ip_config, num_servers, max_queue_size,
net_type, 'sampler', num_worker_threads))
else:
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = num_workers
if not is_standalone:
assert num_servers is not None and num_servers > 0, \
'The number of servers per machine must be specified with a positive number.'
connect_to_server(ip_config, num_servers, max_queue_size, net_type)
init_role('default')
init_kvstore(ip_config, num_servers, 'default')

def finalize_client():
"""Release resources of this client."""
Expand Down
14 changes: 12 additions & 2 deletions tools/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,18 @@ def main():
the same machine. By default, it is 1.')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers > 0, '--num_trainers must be a positive number.'
assert args.num_samplers >= 0
assert args.num_trainers is not None and args.num_trainers > 0, \
'--num_trainers must be a positive number.'
assert args.num_samplers is not None and args.num_samplers >= 0, \
'--num_samplers must be a non-negative number.'
assert args.num_servers is not None and args.num_servers > 0, \
'--num_servers must be a positive number.'
assert args.num_server_threads > 0, '--num_server_threads must be a positive number.'
assert args.workspace is not None, 'A user has to specify a workspace with --workspace.'
assert args.part_config is not None, \
'A user has to specify a partition configuration file with --part_config.'
assert args.ip_config is not None, \
'A user has to specify an IP configuration file with --ip_config.'
udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.")
Expand Down

0 comments on commit 4efa332

Please sign in to comment.