[Distributed] Distributed launching script (dmlc#1772)
* update

* fix launch script.

Co-authored-by: Da Zheng <[email protected]>
aksnzhy and zheng-da authored Jul 16, 2020
1 parent 0e92dad commit ca9d321
Showing 2 changed files with 116 additions and 44 deletions.
52 changes: 8 additions & 44 deletions examples/pytorch/graphsage/experimental/README.md
@@ -24,53 +24,17 @@ python3 partition_graph.py --dataset ogb-product --num_parts 4 --balance_train -
When copying data to the cluster, we recommend copying the partitioned data to NFS so that all worker machines
can access it.
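As a minimal sketch of that copy step, assuming the NFS share is mounted at `/mnt/nfs` on every machine and the partitions from Step 2 were written to `data/` (both paths are assumptions, not part of this example):

```bash
# copy the partitioned graph to a directory that every worker machine can see
cp -r data /mnt/nfs/ogb-product
```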

### Step 3: run servers
### Step 3: Launch distributed jobs

To perform actual distributed training (running training jobs on multiple machines), we need to run
a server on each machine. Before running the servers, we need to update `ip_config.txt` with the correct IP addresses.

On each of the machines, set the following environment variables.

```bash
export DGL_ROLE=server
export DGL_IP_CONFIG=ip_config.txt
export DGL_CONF_PATH=data/ogb-product.json
export DGL_NUM_CLIENT=4
```

```bash
# run server on machine 0
export DGL_SERVER_ID=0
python3 train_dist.py

# run server on machine 1
export DGL_SERVER_ID=1
python3 train_dist.py

# run server on machine 2
export DGL_SERVER_ID=2
python3 train_dist.py

# run server on machine 3
export DGL_SERVER_ID=3
python3 train_dist.py
```

### Step 4: run trainers
We run a trainer process on each machine. Here we use PyTorch distributed and launch each trainer process with `torch.distributed.launch`.
PyTorch distributed requires one of the trainer processes to act as the master; here we run the master process on the first machine.
First make sure that the master node has permission to SSH to all the other nodes, then run:

```bash
# set the DistGraph in distributed mode
export DGL_DIST_MODE="distributed"
# run client on machine 0
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 2
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
# run client on machine 3
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --lr 0.1
python3 ~/dgl/tools/launch.py \
--workspace ~/dgl/examples/pytorch/graphsage/experimental \
--num_client 4 \
--conf_path data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --lr 0.1 --num-client 4"
```
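The launch script reads `ip_config.txt` to decide where to start the server and trainer processes. Judging from how `tools/launch.py` parses the file, each line lists a machine's IP address, a port, and the number of server processes to run on that machine. The addresses and port below are placeholders for a four-machine cluster, not values from this example:

```
172.31.16.250 30050 1
172.31.12.101 30050 1
172.31.19.22 30050 1
172.31.23.7 30050 1
```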

## Distributed code runs in the standalone mode
108 changes: 108 additions & 0 deletions tools/launch.py
@@ -0,0 +1,108 @@
"""Launching tool for DGL distributed training"""
import os
import stat
import sys
import subprocess
import argparse
import signal
import logging
import time
from threading import Thread

def execute_remote(cmd, ip, thread_list):
"""execute command line on remote machine via ssh"""
cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\''
# thread func to run the job
def run(cmd):
subprocess.check_call(cmd, shell = True)

thread = Thread(target = run, args=(cmd,))
thread.setDaemon(True)
thread.start()
thread_list.append(thread)

def submit_jobs(args, udf_command):
"""Submit distributed jobs (server and client processes) via ssh"""
hosts = []
thread_list = []
server_count_per_machine = 0
ip_config = args.workspace + '/' + args.ip_config
with open(ip_config) as f:
for line in f:
ip, port, count = line.strip().split(' ')
port = int(port)
count = int(count)
server_count_per_machine = count
hosts.append((ip, port))
assert args.num_client % len(hosts) == 0
client_count_per_machine = int(args.num_client / len(hosts))
# launch server tasks
server_cmd = 'DGL_ROLE=server'
server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client)
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
for i in range(len(hosts)*server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)]
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, thread_list)
# launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client'
client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(args.num_client)
client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.conf_path)
client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
if os.environ.get('OMP_NUM_THREADS') is not None:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS')
if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')

torch_cmd = '-m torch.distributed.launch'
torch_cmd = torch_cmd + ' ' + '--nproc_per_node=' + str(client_count_per_machine)
torch_cmd = torch_cmd + ' ' + '--nnodes=' + str(len(hosts))
torch_cmd = torch_cmd + ' ' + '--node_rank=' + str(0)
torch_cmd = torch_cmd + ' ' + '--master_addr=' + str(hosts[0][0])
torch_cmd = torch_cmd + ' ' + '--master_port=' + str(1234)

for i in range(args.num_client):
node_id = int(i / client_count_per_machine)
ip, _ = hosts[node_id]
new_torch_cmd = torch_cmd.replace('node_rank=0', 'node_rank='+str(node_id))
new_udf_command = udf_command.replace('python3', 'python3 ' + new_torch_cmd)
cmd = client_cmd + ' ' + new_udf_command
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
execute_remote(cmd, ip, thread_list)

for thread in thread_list:
thread.join()

def main():
parser = argparse.ArgumentParser(description='Launch a distributed job')
parser.add_argument('--workspace', type=str,
help='Path of user directory of distributed tasks. \
This is used to specify a destination location where \
the contents of current directory will be rsyncd')
parser.add_argument('--num_client', type=int,
help='Total number of client processes in the cluster')
parser.add_argument('--conf_path', type=str,
help='The path to the partition config file. This path can be \
a remote path like s3 and dgl will download this file automatically')
parser.add_argument('--ip_config', type=str,
help='The file for IP configuration for server processes')
args, udf_command = parser.parse_known_args()
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_client > 0, '--num_client must be a positive number.'
udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launch can only support: python ...")
submit_jobs(args, udf_command)

def signal_handler(signal, frame):
logging.info('Stop launcher')
sys.exit(0)

if __name__ == '__main__':
fmt = '%(asctime)s %(levelname)s %(message)s'
logging.basicConfig(format=fmt, level=logging.INFO)
signal.signal(signal.SIGINT, signal_handler)
main()
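For illustration, here is roughly the command `submit_jobs` would generate for the first trainer process under the README invocation above, assuming a four-machine `ip_config.txt` whose first line is `172.31.16.250` and ignoring the optional `OMP_NUM_THREADS`/`PYTHONPATH` pass-through. This is a sketch reconstructed from the code, not captured output:

```bash
ssh -o StrictHostKeyChecking=no 172.31.16.250 'cd ~/dgl/examples/pytorch/graphsage/experimental; DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_CLIENT=4 DGL_CONF_PATH=data/ogb-product.json DGL_IP_CONFIG=ip_config.txt python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr=172.31.16.250 --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --lr 0.1 --num-client 4'
```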
