diff --git a/examples/pytorch/graphsage/experimental/README.md b/examples/pytorch/graphsage/experimental/README.md index 42183826cad9..8ed0a5d7550c 100644 --- a/examples/pytorch/graphsage/experimental/README.md +++ b/examples/pytorch/graphsage/experimental/README.md @@ -24,53 +24,17 @@ python3 partition_graph.py --dataset ogb-product --num_parts 4 --balance_train - When copying data to the cluster, we recommend users to copy the partitioned data to NFS so that all worker machines will be able to access the partitioned data. -### Step 3: run servers +### Step 3: Launch distributed jobs -To perform actual distributed training (running training jobs in multiple machines), we need to run -a server on each machine. Before running the servers, we need to update `ip_config.txt` with the right IP addresses. - -On each of the machines, set the following environment variables. - -```bash -export DGL_ROLE=server -export DGL_IP_CONFIG=ip_config.txt -export DGL_CONF_PATH=data/ogb-product.json -export DGL_NUM_CLIENT=4 -``` - -```bash -# run server on machine 0 -export DGL_SERVER_ID=0 -python3 train_dist.py - -# run server on machine 1 -export DGL_SERVER_ID=1 -python3 train_dist.py - -# run server on machine 2 -export DGL_SERVER_ID=2 -python3 train_dist.py - -# run server on machine 3 -export DGL_SERVER_ID=3 -python3 train_dist.py -``` - -### Step 4: run trainers -We run a trainer process on each machine. Here we use Pytorch distributed. We need to use pytorch distributed launch to run each trainer process. -Pytorch distributed requires one of the trainer process to be the master. Here we use the first machine to run the master process. +First make sure that the master node has the right permission to ssh to all the other nodes. Then run script: ```bash -# set the DistGraph in distributed mode -export DGL_DIST_MODE="distributed" -# run client on machine 0 -python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1 -# run client on machine 1 -python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1 -# run client on machine 2 -python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1 -# run client on machine 3 -python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1 +python3 ~/dgl/tools/launch.py \ +--workspace ~/dgl/examples/pytorch/graphsage/experimental \ +--num_client 4 \ +--conf_path data/ogb-product.json \ +--ip_config ip_config.txt \ +"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --lr 0.1 --num-client 4" ``` ## Distributed code runs in the standalone mode diff --git a/tools/launch.py b/tools/launch.py new file mode 100644 index 000000000000..c9606c339df8 --- /dev/null +++ b/tools/launch.py @@ -0,0 +1,108 @@ +"""Launching tool for DGL distributed training""" +import os +import stat +import sys +import subprocess +import argparse +import signal +import logging +import time +from threading import Thread + +def execute_remote(cmd, ip, thread_list): + """execute command line on remote machine via ssh""" + cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\'' + # thread func to run the job + def run(cmd): + subprocess.check_call(cmd, shell = True) + + thread = Thread(target = run, args=(cmd,)) + thread.setDaemon(True) + thread.start() + thread_list.append(thread) + +def submit_jobs(args, udf_command): + """Submit distributed jobs (server and client processes) via ssh""" + hosts = [] + thread_list = [] + server_count_per_machine = 0 + ip_config = args.workspace + '/' + args.ip_config + with open(ip_config) as f: + for line in f: + ip, port, count = line.strip().split(' ') + port = int(port) + count = int(count) + server_count_per_machine = count + hosts.append((ip, port)) + assert args.num_client % len(hosts) == 0 + client_count_per_machine = int(args.num_client / len(hosts)) + # launch server tasks + server_cmd = 'DGL_ROLE=server' + server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) + server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path) + server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) + for i in range(len(hosts)*server_count_per_machine): + ip, _ = hosts[int(i / server_count_per_machine)] + cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i) + cmd = cmd + ' ' + udf_command + cmd = 'cd ' + str(args.workspace) + '; ' + cmd + execute_remote(cmd, ip, thread_list) + # launch client tasks + client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client' + client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client) + client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path) + client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config) + if os.environ.get('OMP_NUM_THREADS') is not None: + client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS') + if os.environ.get('PYTHONPATH') is not None: + client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH') + + torch_cmd = '-m torch.distributed.launch' + torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(client_count_per_machine) + torch_cmd = torch_cmd + ' ' + '--nnodes=' + str(len(hosts)) + torch_cmd = torch_cmd + ' ' + '--node_rank=' + str(0) + torch_cmd = torch_cmd + ' ' + '--master_addr=' + str(hosts[0][0]) + torch_cmd = torch_cmd + ' ' + '--master_port=' + str(1234) + + for i in range(args.num_client): + node_id = int(i / client_count_per_machine) + ip, _ = hosts[node_id] + new_torch_cmd = torch_cmd.replace('node_rank=0', 'node_rank='+str(node_id)) + new_udf_command = udf_command.replace('python3', 'python3 ' + new_torch_cmd) + cmd = client_cmd + ' ' + new_udf_command + cmd = 'cd ' + str(args.workspace) + '; ' + cmd + execute_remote(cmd, ip, thread_list) + + for thread in thread_list: + thread.join() + +def main(): + parser = argparse.ArgumentParser(description='Launch a distributed job') + parser.add_argument('--workspace', type=str, + help='Path of user directory of distributed tasks. \ + This is used to specify a destination location where \ + the contents of current directory will be rsyncd') + parser.add_argument('--num_client', type=int, + help='Total number of client processes in the cluster') + parser.add_argument('--conf_path', type=str, + help='The path to the partition config file. This path can be \ + a remote path like s3 and dgl will download this file automatically') + parser.add_argument('--ip_config', type=str, + help='The file for IP configuration for server processes') + args, udf_command = parser.parse_known_args() + assert len(udf_command) == 1, 'Please provide user command line.' + assert args.num_client > 0, '--num_client must be a positive number.' + udf_command = str(udf_command[0]) + if 'python' not in udf_command: + raise RuntimeError("DGL launch can only support: python ...") + submit_jobs(args, udf_command) + +def signal_handler(signal, frame): + logging.info('Stop launcher') + sys.exit(0) + +if __name__ == '__main__': + fmt = '%(asctime)s %(levelname)s %(message)s' + logging.basicConfig(format=fmt, level=logging.INFO) + signal.signal(signal.SIGINT, signal_handler) + main()