
Commit 7e2ed9f

Merge branch 'master' into dist_part
jermainewang authored Aug 22, 2022
2 parents 2cf4bd0 + ee672c0 commit 7e2ed9f
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions tools/launch.py
@@ -9,6 +9,7 @@
 import time
 import json
 import multiprocessing
+import queue
 import re
 from functools import partial
 from threading import Thread
@@ -74,6 +75,7 @@ def get_killed_pids(ip, port, killed_pids):

 def execute_remote(
     cmd: str,
+    state_q: queue.Queue,
     ip: str,
     port: int,
     username: Optional[str] = ""
@@ -82,6 +84,7 @@
     Args:
         cmd: User-defined command (udf) to execute on the remote host.
+        state_q: A queue collecting Thread exit states.
         ip: The ip-address of the host to run the command on.
         port: Port number that the host is listening on.
         thread_list:
@@ -105,10 +108,17 @@
     )

     # thread func to run the job
-    def run(ssh_cmd):
-        subprocess.check_call(ssh_cmd, shell=True)
+    def run(ssh_cmd, state_q):
+        try:
+            subprocess.check_call(ssh_cmd, shell=True)
+            state_q.put(0)
+        except subprocess.CalledProcessError as err:
+            print(f"Called process error {err}")
+            state_q.put(err.returncode)
+        except Exception:
+            state_q.put(-1)

-    thread = Thread(target=run, args=(ssh_cmd,))
+    thread = Thread(target=run, args=(ssh_cmd, state_q,))
     thread.setDaemon(True)
     thread.start()
     # sleep for a while in case of ssh is rejected by peer due to busy connection
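With this change, the run wrapper maps every outcome of the remote command to an integer state on the queue: 0 on success, the child's return code on CalledProcessError, and -1 for any other exception. A minimal standalone sketch of the pattern (not part of the commit; "exit 3" is a hypothetical stand-in for an ssh command):

    import queue
    import subprocess
    from threading import Thread

    def run(ssh_cmd, state_q):
        # Map every outcome of the command to an integer state.
        try:
            subprocess.check_call(ssh_cmd, shell=True)
            state_q.put(0)                   # command succeeded
        except subprocess.CalledProcessError as err:
            state_q.put(err.returncode)      # command ran but exited non-zero
        except Exception:
            state_q.put(-1)                  # spawn/ssh failure, etc.

    state_q = queue.Queue()
    thread = Thread(target=run, args=("exit 3", state_q), daemon=True)
    thread.start()
    thread.join()
    print(state_q.get())                     # prints 3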
@@ -535,6 +545,7 @@ def submit_jobs(args, udf_command, dry_run=False):
     assert part_metadata['num_parts'] == len(hosts), \
         'The number of graph partitions has to match the number of machines in the cluster.'

+    state_q = queue.Queue()
     tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
     # launch server tasks
     if not has_alive_servers(args):
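For background (not part of the commit): a single queue.Queue can be shared by all server and client launch threads because its put/get operations are thread-safe, so no extra locking is needed. A tiny sketch:

    import queue
    from threading import Thread

    q = queue.Queue()  # thread-safe FIFO
    threads = [Thread(target=q.put, args=(i,), daemon=True) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(sorted(q.get() for _ in range(4)))  # -> [0, 1, 2, 3]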
@@ -557,7 +568,7 @@
             cmd = 'cd ' + str(args.workspace) + '; ' + cmd
             servers_cmd.append(cmd)
             if not dry_run:
-                thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+                thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))
     else:
         print(f"Use running server {args.server_name}.")
@@ -592,7 +603,7 @@
         cmd = 'cd ' + str(args.workspace) + '; ' + cmd
         clients_cmd.append(cmd)
         if not dry_run:
-            thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
+            thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username))

     # return commands of clients/servers directly if in dry run mode
     if dry_run:
@@ -612,12 +623,21 @@ def signal_handler(signal, frame):
         sys.exit(0)
     signal.signal(signal.SIGINT, signal_handler)

+    err = 0
     for thread in thread_list:
         thread.join()
+        err_code = state_q.get()
+        if err_code != 0:
+            # Record err_code
+            # We record one of the error if there are multiple
+            err = err_code

     # The training processes complete. We should tell the cleanup process to exit.
     conn2.send('exit')
     process.join()

+    if err != 0:
+        print("Task failed")
+        sys.exit(-1)

 def main():
     parser = argparse.ArgumentParser(description='Launch a distributed job')
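Taken together, the launcher now exits non-zero whenever any launch thread reported a failure. A self-contained sketch of the end-to-end flow (the shell commands "true" and "false" are hypothetical stand-ins for the ssh invocations):

    import queue
    import subprocess
    import sys
    from threading import Thread

    def run(cmd, state_q):
        try:
            subprocess.check_call(cmd, shell=True)
            state_q.put(0)
        except subprocess.CalledProcessError as err:
            state_q.put(err.returncode)

    state_q = queue.Queue()
    thread_list = []
    for cmd in ("true", "false", "true"):
        thread = Thread(target=run, args=(cmd, state_q), daemon=True)
        thread.start()
        thread_list.append(thread)

    err = 0
    for thread in thread_list:
        thread.join()
        err_code = state_q.get()  # each thread put exactly one state
        if err_code != 0:
            err = err_code        # keeps one error if there are several

    if err != 0:
        print("Task failed")
        sys.exit(-1)

Note that sys.exit(-1) surfaces to the shell as exit status 255 on POSIX systems; the point is simply that it is non-zero, so wrappers and CI can detect the failure.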
