[Distributed] Simplify distributed API (dmlc#2775)
* remove num_workers.

* remove num_workers.

* remove num_workers.

* remove num-servers.

* update error message.

* update docstring.

* fix docs.

* fix tests.

* fix test.

* fix.

* print messages in test.

* fix.

* fix test.

* fix.

Co-authored-by: Ubuntu <[email protected]>
zheng-da and Ubuntu authored Mar 30, 2021
1 parent 97863ab commit e36c5db
Showing 15 changed files with 50 additions and 47 deletions.
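
The user-facing effect, condensed from the documentation diffs below (a minimal sketch; the call only does real work inside a job started by `tools/launch.py`, with an `ip_config.txt` in place):

```python
import dgl

# Old call (before DGL 0.7.0): server and sampler counts were passed explicitly.
# dgl.distributed.initialize('ip_config.txt', num_servers=1, num_workers=4)

# New call: only the IP config is passed; the counts are picked up from the
# DGL_NUM_SERVER / DGL_NUM_SAMPLER environment variables set by the launcher.
dgl.distributed.initialize('ip_config.txt')
```

The example scripts drop their `--num_servers`/`--num_workers` flags accordingly.
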
2 changes: 1 addition & 1 deletion docs/source/guide/distributed-apis.rst
@@ -23,7 +23,7 @@ Typically, the initialization APIs should be invoked in the following order:

.. code:: python
dgl.distributed.initialize('ip_config.txt', num_workers=4)
dgl.distributed.initialize('ip_config.txt')
th.distributed.init_process_group(backend='gloo')
**Note**: If the training script contains user-defined functions (UDFs) that have to be invoked on
2 changes: 1 addition & 1 deletion docs/source/guide/distributed.rst
@@ -26,7 +26,7 @@ are the same as :ref:`mini-batch training <guide-minibatch>`.
import dgl
import torch as th
dgl.distributed.initialize('ip_config.txt', num_servers, num_workers)
dgl.distributed.initialize('ip_config.txt')
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph('graph_name', 'part_config.json')
pb = g.get_partition_book()
2 changes: 1 addition & 1 deletion docs/source/guide_cn/distributed-apis.rst
@@ -21,7 +21,7 @@ Initialization of the DGL distributed module

.. code:: python
dgl.distributed.initialize('ip_config.txt', num_workers=4)
dgl.distributed.initialize('ip_config.txt')
th.distributed.init_process_group(backend='gloo')
**Note**: If the training script contains user-defined functions (UDFs) that need to be invoked on the servers (see the DistTensor and DistEmbedding sections below for details),
2 changes: 1 addition & 1 deletion docs/source/guide_cn/distributed.rst
@@ -20,7 +20,7 @@ DGL adopts a fully distributed approach that distributes both data and computation across a group of
import dgl
import torch as th
dgl.distributed.initialize('ip_config.txt', num_servers, num_workers)
dgl.distributed.initialize('ip_config.txt')
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph('graph_name', 'part_config.json')
pb = g.get_partition_book()
6 changes: 3 additions & 3 deletions examples/pytorch/graphsage/experimental/README.md
@@ -124,7 +124,7 @@ python3 ~/workspace/dgl/tools/launch.py \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 30 --batch_size 1000 --num_workers 4"
"python3 train_dist.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 30 --batch_size 1000"
```

To run unsupervised training:
@@ -137,7 +137,7 @@ python3 ~/workspace/dgl/tools/launch.py \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 3 --batch_size 1000 --num_workers 4"
"python3 train_dist_unsupervised.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 3 --batch_size 1000"
```

By default, this code will run on the CPU. If you have GPU support, you can just add a `--num_gpus` argument to the user command:
@@ -150,7 +150,7 @@ python3 ~/workspace/dgl/tools/launch.py \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist.py --graph_name ogb-product --ip_config ip_config.txt --num_servers 1 --num_epochs 30 --batch_size 1000 --num_workers 4 --num_gpus 4"
"python3 train_dist.py --graph_name ogb-product --ip_config ip_config.txt --num_epochs 30 --batch_size 1000 --num_gpus 4"
```

**Note:** if you are using conda or other virtual environments on the remote machines, you need to replace `python3` in the command string (i.e. the last argument) with the path to the Python interpreter in that environment.
9 changes: 1 addition & 8 deletions examples/pytorch/graphsage/experimental/train_dist.py
@@ -251,7 +251,7 @@ def run(args, device, data):
time.time() - start))

def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
dgl.distributed.initialize(args.ip_config)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
@@ -288,7 +288,6 @@ def main(args):
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_clients', type=int, help='The number of clients')
parser.add_argument('--num_servers', type=int, default=1, help='The number of servers')
parser.add_argument('--n_classes', type=int, help='the number of classes')
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
@@ -302,15 +301,9 @@ def main(args):
parser.add_argument('--eval_every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--num_workers', type=int, default=4,
help="Number of sampling processes. Use 0 for no extra process.")
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
args = parser.parse_args()
assert args.num_workers == int(os.environ.get('DGL_NUM_SAMPLER')), \
'The num_workers should be the same value with DGL_NUM_SAMPLER.'
assert args.num_servers == int(os.environ.get('DGL_NUM_SERVER')), \
'The num_servers should be the same value with DGL_NUM_SERVER.'

print(args)
main(args)
examples/pytorch/graphsage/experimental/train_dist_unsupervised.py
@@ -421,7 +421,7 @@ def run(args, device, data):
th.save(pred, 'emb.pt')

def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
dgl.distributed.initialize(args.ip_config)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
@@ -458,7 +458,6 @@ def main(args):
parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_servers', type=int, default=1, help='Server count on each machine.')
parser.add_argument('--n_classes', type=int, help='the number of classes')
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
@@ -472,8 +471,6 @@ def main(args):
parser.add_argument('--eval_every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--num_workers', type=int, default=0,
help="Number of sampling processes. Use 0 for no extra process.")
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
parser.add_argument('--num_negs', type=int, default=1)
@@ -482,10 +479,6 @@ def main(args):
parser.add_argument('--remove_edge', default=False, action='store_true',
help="whether to remove edges during sampling")
args = parser.parse_args()
assert args.num_workers == int(os.environ.get('DGL_NUM_SAMPLER')), \
'The num_workers should be the same value with DGL_NUM_SAMPLER.'
assert args.num_servers == int(os.environ.get('DGL_NUM_SERVER')), \
'The num_servers should be the same value with DGL_NUM_SERVER.'

print(args)
main(args)
2 changes: 1 addition & 1 deletion examples/pytorch/rgcn/experimental/README.md
@@ -120,7 +120,7 @@ python3 ~/workspace/dgl/tools/launch.py \
--num_samplers 4 \
--part_config data/ogbn-mag.json \
--ip_config ip_config.txt \
"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024 --n-hidden 64 --lr 0.01 --eval-batch-size 1024 --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --num-workers 4 --num-servers 1 --sparse-embedding --sparse-lr 0.06 --num_gpus 1"
"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024 --n-hidden 64 --lr 0.01 --eval-batch-size 1024 --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --sparse-embedding --sparse-lr 0.06 --num_gpus 1"
```

We can get the performance score at the second epoch:
6 changes: 1 addition & 5 deletions examples/pytorch/rgcn/experimental/entity_classify_dist.py
@@ -497,7 +497,7 @@ def run(args, device, data):
time.time() - start))

def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
dgl.distributed.initialize(args.ip_config)
if not args.standalone:
th.distributed.init_process_group(backend='gloo')

@@ -532,8 +532,6 @@ def main(args):
parser.add_argument('--id', type=int, help='the partition id')
parser.add_argument('--ip-config', type=str, help='The file for IP configuration')
parser.add_argument('--conf-path', type=str, help='The path to the partition config file')
parser.add_argument('--num-client', type=int, help='The number of clients')
parser.add_argument('--num-servers', type=int, default=1, help='Server count on each machine.')

# rgcn related
parser.add_argument('--num_gpus', type=int, default=-1,
@@ -569,8 +567,6 @@ def main(args):
parser.add_argument("--eval-batch-size", type=int, default=128,
help="Mini-batch size. ")
parser.add_argument('--log-every', type=int, default=20)
parser.add_argument("--num-workers", type=int, default=1,
help="Number of workers for distributed dataloader.")
parser.add_argument("--low-mem", default=False, action='store_true',
help="Whether use low mem RelGraphCov")
parser.add_argument("--sparse-embedding", action='store_true',
13 changes: 11 additions & 2 deletions python/dgl/distributed/dist_context.py
@@ -59,10 +59,10 @@ def initialize(ip_config, num_servers=1, num_workers=0,
ip_config: str
File path of ip_config file
num_servers : int
The number of server processes on each machine
The number of server processes on each machine. This argument is deprecated in DGL 0.7.0.
num_workers: int
Number of worker process on each machine. The worker processes are used
for distributed sampling.
for distributed sampling. This argument is deprecated in DGL 0.7.0.
max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default).
@@ -101,6 +101,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
serv.start()
sys.exit()
else:
if os.environ.get('DGL_NUM_SAMPLER') is not None:
num_workers = int(os.environ.get('DGL_NUM_SAMPLER'))
else:
num_workers = 0
if os.environ.get('DGL_NUM_SERVER') is not None:
num_servers = int(os.environ.get('DGL_NUM_SERVER'))
else:
num_servers = 1

rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
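
A standalone sketch of the fallback that the hunk above adds (the helper name `_counts_from_env` is hypothetical; runnable on its own):

```python
import os

def _counts_from_env():
    # Mirrors the new behavior of initialize(): sampler and server counts
    # now come from environment variables (exported by tools/launch.py),
    # defaulting to 0 sampler processes and 1 server per machine.
    num_workers = int(os.environ.get('DGL_NUM_SAMPLER', 0))
    num_servers = int(os.environ.get('DGL_NUM_SERVER', 1))
    return num_workers, num_servers

print(_counts_from_env())  # (0, 1) unless the launcher exported the variables
```
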
6 changes: 4 additions & 2 deletions tests/distributed/test_dist_graph_store.py
@@ -68,7 +68,8 @@ def rand_init(shape, dtype):

def run_client(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
dgl.distributed.initialize("kv_ip_config.txt", server_count)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph(graph_name, gpb=gpb)
@@ -240,7 +241,8 @@ def check_server_client(shared_mem, num_servers, num_clients):

def run_client_hetero(graph_name, part_id, server_count, num_clients, num_nodes, num_edges):
time.sleep(5)
dgl.distributed.initialize("kv_ip_config.txt", server_count)
os.environ['DGL_NUM_SERVER'] = str(server_count)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name, _, _ = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph(graph_name, gpb=gpb)
10 changes: 5 additions & 5 deletions tests/distributed/test_distributed_sampling.py
@@ -26,7 +26,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
try:
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
@@ -40,7 +40,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_find_edges", gpb=gpb)
try:
u, v = find_edges(dist_graph, eids)
@@ -195,7 +195,7 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
assert 'feat' in dist_graph.nodes['n1'].data
assert 'feat' not in dist_graph.nodes['n2'].data
@@ -302,7 +302,7 @@ def check_standalone_sampling(tmpdir, reshuffle):
num_hops=num_hops, part_method='metis', reshuffle=reshuffle)

os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

@@ -325,7 +325,7 @@ def test_standalone_sampling():

def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
gpb = None
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dgl.distributed.initialize("rpc_ip_config.txt")
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
12 changes: 7 additions & 5 deletions tests/distributed/test_mp_dataloader.py
@@ -45,10 +45,10 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
g.start()


def start_dist_dataloader(rank, tmpdir, num_server, num_workers, drop_last):
def start_dist_dataloader(rank, tmpdir, num_server, drop_last):
import dgl
import torch as th
dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
dgl.distributed.initialize("mp_ip_config.txt")
gpb = None
disable_shared_mem = num_server > 0
if disable_shared_mem:
@@ -122,7 +122,7 @@ def test_standalone(tmpdir):

os.environ['DGL_DIST_MODE'] = 'standalone'
try:
start_dist_dataloader(0, tmpdir, 1, 2, True)
start_dist_dataloader(0, tmpdir, 1, True)
except Exception as e:
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@@ -159,8 +159,9 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):

time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer = ctx.Process(target=start_dist_dataloader, args=(
0, tmpdir, num_server, num_workers, drop_last))
0, tmpdir, num_server, drop_last))
ptrainer.start()
time.sleep(1)

@@ -171,7 +172,7 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers):
def start_node_dataloader(rank, tmpdir, num_server, num_workers):
import dgl
import torch as th
dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
dgl.distributed.initialize("mp_ip_config.txt")
gpb = None
disable_shared_mem = num_server > 1
if disable_shared_mem:
Expand Down Expand Up @@ -252,6 +253,7 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):

time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
ptrainer_list = []
if dataloader_type == 'node':
p = ctx.Process(target=start_node_dataloader, args=(
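
The tests above export the variables in the parent before spawning the trainer, which works because processes created with the `spawn` context inherit the parent's environment. A self-contained sketch of that pattern (the `child` function is hypothetical):

```python
import multiprocessing as mp
import os

def child():
    # The spawned process sees the variables its parent exported, which is
    # how dgl.distributed.initialize() learns the sampler/server counts.
    print(os.environ.get('DGL_NUM_SAMPLER'), os.environ.get('DGL_NUM_SERVER'))

if __name__ == '__main__':
    os.environ['DGL_NUM_SAMPLER'] = '4'
    os.environ['DGL_NUM_SERVER'] = '1'
    ctx = mp.get_context('spawn')
    p = ctx.Process(target=child)
    p.start()
    p.join()  # the child prints: 4 1
```
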
11 changes: 7 additions & 4 deletions tests/distributed/test_new_kvstore.py
@@ -151,7 +151,7 @@ def start_server_mul_role(server_id, num_clients, num_servers):
def start_client(num_clients, num_servers):
os.environ['DGL_DIST_MODE'] = 'distributed'
# Note: connect to server first !
dgl.distributed.initialize(ip_config='kv_ip_config.txt', num_servers=num_servers)
dgl.distributed.initialize(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt', num_servers=num_servers)
kvclient.map_shared_data(partition_book=gpb)
@@ -278,10 +278,10 @@ def start_client(num_clients, num_servers):
data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))

def start_client_mul_role(i, num_workers, num_servers):
def start_client_mul_role(i):
os.environ['DGL_DIST_MODE'] = 'distributed'
# Initialize creates kvstore !
dgl.distributed.initialize(ip_config='kv_ip_mul_config.txt', num_servers=num_servers, num_workers=num_workers)
dgl.distributed.initialize(ip_config='kv_ip_mul_config.txt')
if i == 0: # block one trainer
time.sleep(5)
kvclient = dgl.distributed.kvstore.get_kvstore()
@@ -305,6 +305,7 @@ def test_kv_store():
ctx = mp.get_context('spawn')
pserver_list = []
pclient_list = []
os.environ['DGL_NUM_SERVER'] = str(num_servers)
for i in range(num_servers):
pserver = ctx.Process(target=start_server, args=(i, num_clients, num_servers))
pserver.start()
@@ -332,12 +333,14 @@ def test_kv_multi_role():
ctx = mp.get_context('spawn')
pserver_list = []
pclient_list = []
os.environ['DGL_NUM_SAMPLER'] = str(num_samplers)
os.environ['DGL_NUM_SERVER'] = str(num_servers)
for i in range(num_servers):
pserver = ctx.Process(target=start_server_mul_role, args=(i, num_clients, num_servers))
pserver.start()
pserver_list.append(pserver)
for i in range(num_trainers):
pclient = ctx.Process(target=start_client_mul_role, args=(i, num_samplers, num_servers))
pclient = ctx.Process(target=start_client_mul_role, args=(i,))
pclient.start()
pclient_list.append(pclient)
for i in range(num_trainers):
5 changes: 5 additions & 0 deletions tools/launch.py
@@ -140,6 +140,11 @@ def main():
udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.")
if sys.version_info.major and sys.version_info.minor >= 8:
if args.num_samplers > 0:
print('WARNING! DGL does not support multiple sampler processes in Python>=3.8. '
+ 'Set the number of sampler processes to 0.')
args.num_samplers = 0
submit_jobs(args, udf_command)

def signal_handler(signal, frame):
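
For reference, a guard with the same intent (no extra sampler processes on Python 3.8 or newer) is often written as a tuple comparison; a small standalone sketch, not part of the commit:

```python
import sys

num_samplers = 4  # e.g. the value passed via --num_samplers
if sys.version_info >= (3, 8):
    # Multiprocessing-based sampler workers are turned off on 3.8+,
    # matching the warning printed by tools/launch.py above.
    num_samplers = 0
```
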
