Skip to content

Commit

Permalink
fix dist launch script test=develop (PaddlePaddle#17404)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yancey1989 authored May 15, 2019
1 parent 0823a7b commit 266444b
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions python/paddle/distributed/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@
GPUS = 8


def get_gpu_ids(gpus):
if os.getenv("CUDA_VISIBLE_DEVICES"):
ids = [int(i)
for i in os.getenv("CUDA_VISIBLE_DEVICES").split(",")][:gpus]
if gpus > len(ids):
raise EnvironmentError(
"The count of env CUDA_VISIBLE_DEVICES should not greater than the passed gpus: %s"
% gpus)
return ids
else:
return [i for i in range(gpus)]


def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
procs = []
log_fns = []
Expand All @@ -61,8 +74,8 @@ def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
all_nodes_devices_endpoints += "%s:617%d" % (n, i)
nranks = num_nodes * gpus
# ======== for dist training =======

for i in range(gpus):
gpu_ids = get_gpu_ids(gpus)
for i in gpu_ids:
curr_env = {}
curr_env.update(default_envs)
curr_env.update({
Expand Down

0 comments on commit 266444b

Please sign in to comment.