Skip to content

Commit

Permalink
[Tools] In tools/launch.py, correctly pass all DGL client/server en…
Browse files Browse the repository at this point in the history
…v vars if udf is a multi-command (dmlc#3245)

* Correctly pass all DGL client/server env vars if udf is a multi-command

* Refactor to use wrap_cmd_with_local_envvars() helper fn
  • Loading branch information
erickim555 authored Aug 17, 2021
1 parent 1d4d4f5 commit ac01e88
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 26 deletions.
92 changes: 91 additions & 1 deletion tests/tools/test_launch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from tools.launch import wrap_udf_in_torch_dist_launcher
from tools.launch import wrap_udf_in_torch_dist_launcher, wrap_cmd_with_local_envvars, construct_dgl_server_env_vars, \
construct_dgl_client_env_vars


class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
Expand Down Expand Up @@ -60,5 +61,94 @@ def test_py_versions(self):
self.assertEqual(wrapped_udf_command, expected)


class TestWrapCmdWithLocalEnvvars(unittest.TestCase):
"""wrap_cmd_with_local_envvars()"""

def test_simple(self):
self.assertEqual(
wrap_cmd_with_local_envvars("ls && pwd", "VAR1=value1 VAR2=value2"),
"(export VAR1=value1 VAR2=value2; ls && pwd)"
)


class TestConstructDglServerEnvVars(unittest.TestCase):
"""construct_dgl_server_env_vars()"""
def test_simple(self):
self.assertEqual(
construct_dgl_server_env_vars(
num_samplers=2,
num_server_threads=3,
tot_num_clients=4,
part_config="path/to/part.config",
ip_config="path/to/ip.config",
num_servers=5,
graph_format="csc"
),
(
"DGL_ROLE=server "
"DGL_NUM_SAMPLER=2 "
"OMP_NUM_THREADS=3 "
"DGL_NUM_CLIENT=4 "
"DGL_CONF_PATH=path/to/part.config "
"DGL_IP_CONFIG=path/to/ip.config "
"DGL_NUM_SERVER=5 "
"DGL_GRAPH_FORMAT=csc "
)
)


class TestConstructDglClientEnvVars(unittest.TestCase):
"""construct_dgl_client_env_vars()"""
def test_simple(self):
# with pythonpath
self.assertEqual(
construct_dgl_client_env_vars(
num_samplers=1,
tot_num_clients=2,
part_config="path/to/part.config",
ip_config="path/to/ip.config",
num_servers=3,
graph_format="csc",
num_omp_threads=4,
pythonpath="some/pythonpath/"
),
(
"DGL_DIST_MODE=distributed "
"DGL_ROLE=client "
"DGL_NUM_SAMPLER=1 "
"DGL_NUM_CLIENT=2 "
"DGL_CONF_PATH=path/to/part.config "
"DGL_IP_CONFIG=path/to/ip.config "
"DGL_NUM_SERVER=3 "
"DGL_GRAPH_FORMAT=csc "
"OMP_NUM_THREADS=4 "
"PYTHONPATH=some/pythonpath/ "
)
)
# without pythonpath
self.assertEqual(
construct_dgl_client_env_vars(
num_samplers=1,
tot_num_clients=2,
part_config="path/to/part.config",
ip_config="path/to/ip.config",
num_servers=3,
graph_format="csc",
num_omp_threads=4,
),
(
"DGL_DIST_MODE=distributed "
"DGL_ROLE=client "
"DGL_NUM_SAMPLER=1 "
"DGL_NUM_CLIENT=2 "
"DGL_CONF_PATH=path/to/part.config "
"DGL_IP_CONFIG=path/to/ip.config "
"DGL_NUM_SERVER=3 "
"DGL_GRAPH_FORMAT=csc "
"OMP_NUM_THREADS=4 "
)
)


if __name__ == '__main__':
unittest.main()
186 changes: 161 additions & 25 deletions tools/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,150 @@ def wrap_udf_in_torch_dist_launcher(
return new_udf_command


def construct_dgl_server_env_vars(
num_samplers: int,
num_server_threads: int,
tot_num_clients: int,
part_config: str,
ip_config: str,
num_servers: int,
graph_format: str,
) -> str:
"""Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct
server role.
Convenience function.
Args:
num_samplers:
num_server_threads:
tot_num_clients:
part_config: Partition config.
Relative path to workspace.
ip_config: IP config file containing IP addresses of cluster hosts.
Relative path to workspace.
num_servers:
graph_format:
Returns:
server_env_vars: The server-specific env-vars in a string format, friendly for CLI execution.
"""
server_env_vars_template = (
"DGL_ROLE={DGL_ROLE} "
"DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} "
"OMP_NUM_THREADS={OMP_NUM_THREADS} "
"DGL_NUM_CLIENT={DGL_NUM_CLIENT} "
"DGL_CONF_PATH={DGL_CONF_PATH} "
"DGL_IP_CONFIG={DGL_IP_CONFIG} "
"DGL_NUM_SERVER={DGL_NUM_SERVER} "
"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
)
return server_env_vars_template.format(
DGL_ROLE="server",
DGL_NUM_SAMPLER=num_samplers,
OMP_NUM_THREADS=num_server_threads,
DGL_NUM_CLIENT=tot_num_clients,
DGL_CONF_PATH=part_config,
DGL_IP_CONFIG=ip_config,
DGL_NUM_SERVER=num_servers,
DGL_GRAPH_FORMAT=graph_format,
)


def construct_dgl_client_env_vars(
num_samplers: int,
tot_num_clients: int,
part_config: str,
ip_config: str,
num_servers: int,
graph_format: str,
num_omp_threads: int,
pythonpath: Optional[str] = "",
) -> str:
"""Constructs the DGL client-specific env vars string that are required for DGL code to behave in the correct
client role.
Convenience function.
Args:
num_samplers:
tot_num_clients:
part_config: Partition config.
Relative path to workspace.
ip_config: IP config file containing IP addresses of cluster hosts.
Relative path to workspace.
num_servers:
graph_format:
num_omp_threads:
pythonpath: Optional. If given, this will pass this as PYTHONPATH.
Returns:
client_env_vars: The client-specific env-vars in a string format, friendly for CLI execution.
"""
client_env_vars_template = (
"DGL_DIST_MODE={DGL_DIST_MODE} "
"DGL_ROLE={DGL_ROLE} "
"DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} "
"DGL_NUM_CLIENT={DGL_NUM_CLIENT} "
"DGL_CONF_PATH={DGL_CONF_PATH} "
"DGL_IP_CONFIG={DGL_IP_CONFIG} "
"DGL_NUM_SERVER={DGL_NUM_SERVER} "
"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
"OMP_NUM_THREADS={OMP_NUM_THREADS} "
"{suffix_optional_envvars}"
)
# append optional additional env-vars
suffix_optional_envvars = ""
if pythonpath:
suffix_optional_envvars += f"PYTHONPATH={pythonpath} "
return client_env_vars_template.format(
DGL_DIST_MODE="distributed",
DGL_ROLE="client",
DGL_NUM_SAMPLER=num_samplers,
DGL_NUM_CLIENT=tot_num_clients,
DGL_CONF_PATH=part_config,
DGL_IP_CONFIG=ip_config,
DGL_NUM_SERVER=num_servers,
DGL_GRAPH_FORMAT=graph_format,
OMP_NUM_THREADS=num_omp_threads,
suffix_optional_envvars=suffix_optional_envvars,
)


def wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:
"""Wraps a CLI command with desired env vars with the following properties:
(1) env vars persist for the entire `cmd`, even if it consists of multiple "chained" commands like:
cmd = "ls && pwd && python run/something.py"
(2) env vars don't pollute the environment after `cmd` completes.
Example:
>>> cmd = "ls && pwd"
>>> env_vars = "VAR1=value1 VAR2=value2"
>>> wrap_cmd_with_local_envvars(cmd, env_vars)
"(export VAR1=value1 VAR2=value2; ls && pwd)"
Args:
cmd:
env_vars: A string containing env vars, eg "VAR1=val1 VAR2=val2"
Returns:
cmd_with_env_vars:
"""
# use `export` to persist env vars for entire cmd block. required if udf_command is a chain of commands
# also: wrap in parens to not pollute env:
# https://stackoverflow.com/a/45993803
return f"(export {env_vars}; {cmd})"


def submit_jobs(args, udf_command):
"""Submit distributed jobs (server and client processes) via ssh"""
hosts = []
thread_list = []
server_count_per_machine = 0

# Get the IP addresses of the cluster.
ip_config = args.workspace + '/' + args.ip_config
ip_config = os.path.join(args.workspace, args.ip_config)
with open(ip_config) as f:
for line in f:
result = line.strip().split()
Expand All @@ -286,7 +422,7 @@ def submit_jobs(args, udf_command):
raise RuntimeError("Format error of ip_config.")
server_count_per_machine = args.num_servers
# Get partition info of the graph data
part_config = args.workspace + '/' + args.part_config
part_config = os.path.join(args.workspace, args.part_config)
with open(part_config) as conf_f:
part_metadata = json.load(conf_f)
assert 'num_parts' in part_metadata, 'num_parts does not exist.'
Expand All @@ -296,33 +432,33 @@ def submit_jobs(args, udf_command):

tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
# launch server tasks
server_cmd = 'DGL_ROLE=server DGL_NUM_SAMPLER=' + str(args.num_samplers)
server_cmd = server_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_server_threads)
server_cmd = server_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
server_cmd = server_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
server_cmd = server_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
server_cmd = server_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
server_cmd = server_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
for i in range(len(hosts)*server_count_per_machine):
server_env_vars = construct_dgl_server_env_vars(
num_samplers=args.num_samplers,
num_server_threads=args.num_server_threads,
tot_num_clients=tot_num_clients,
part_config=args.part_config,
ip_config=args.ip_config,
num_servers=args.num_servers,
graph_format=args.graph_format,
)
for i in range(len(hosts) * server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)]
cmd = server_cmd + ' ' + 'DGL_SERVER_ID=' + str(i)
cmd = cmd + ' ' + udf_command
server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))

# launch client tasks
client_cmd = 'DGL_DIST_MODE="distributed" DGL_ROLE=client DGL_NUM_SAMPLER=' + str(args.num_samplers)
client_cmd = client_cmd + ' ' + 'DGL_NUM_CLIENT=' + str(tot_num_clients)
client_cmd = client_cmd + ' ' + 'DGL_CONF_PATH=' + str(args.part_config)
client_cmd = client_cmd + ' ' + 'DGL_IP_CONFIG=' + str(args.ip_config)
client_cmd = client_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
if os.environ.get('OMP_NUM_THREADS') is not None:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS')
else:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')
client_cmd = client_cmd + ' ' + 'DGL_GRAPH_FORMAT=' + str(args.graph_format)
client_env_vars = construct_dgl_client_env_vars(
num_samplers=args.num_samplers,
tot_num_clients=tot_num_clients,
part_config=args.part_config,
ip_config=args.ip_config,
num_servers=args.num_servers,
graph_format=args.graph_format,
num_omp_threads=os.environ.get("OMP_NUM_THREADS", str(args.num_omp_threads)),
pythonpath=os.environ.get("PYTHONPATH", ""),
)

for node_id, host in enumerate(hosts):
ip, _ = host
Expand All @@ -335,7 +471,7 @@ def submit_jobs(args, udf_command):
master_addr=hosts[0][0],
master_port=1234,
)
cmd = client_cmd + ' ' + torch_dist_udf_command
cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars)
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))

Expand Down

0 comments on commit ac01e88

Please sign in to comment.