Skip to content

Commit

Permalink
[Distributed] Automatically setting the number of OMP threads for tra…
Browse files Browse the repository at this point in the history
…iners (dmlc#2812)

* set omp thread.

* add comment.

* fix.
  • Loading branch information
zheng-da authored Apr 4, 2021
1 parent cba5af2 commit 86229d4
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tools/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import time
import json
import multiprocessing
from threading import Thread

DEFAULT_PORT = 30050
Expand Down Expand Up @@ -77,6 +78,8 @@ def submit_jobs(args, udf_command):
client_cmd = client_cmd + ' ' + 'DGL_NUM_SERVER=' + str(args.num_servers)
if os.environ.get('OMP_NUM_THREADS') is not None:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + os.environ.get('OMP_NUM_THREADS')
else:
client_cmd = client_cmd + ' ' + 'OMP_NUM_THREADS=' + str(args.num_omp_threads)
if os.environ.get('PYTHONPATH') is not None:
client_cmd = client_cmd + ' ' + 'PYTHONPATH=' + os.environ.get('PYTHONPATH')

Expand Down Expand Up @@ -111,6 +114,8 @@ def main():
the contents of current directory will be rsyncd')
parser.add_argument('--num_trainers', type=int,
help='The number of trainer processes per machine')
parser.add_argument('--num_omp_threads', type=int,
help='The number of OMP threads per trainer')
parser.add_argument('--num_samplers', type=int, default=0,
help='The number of sampler processes per trainer process')
parser.add_argument('--num_servers', type=int,
Expand All @@ -137,6 +142,12 @@ def main():
'A user has to specify a partition configuration file with --part_config.'
assert args.ip_config is not None, \
'A user has to specify an IP configuration file with --ip_config.'
if args.num_omp_threads is None:
# Here we assume all machines have the same number of CPU cores as the machine
# where the launch script runs.
args.num_omp_threads = max(multiprocessing.cpu_count() // 2 // args.num_trainers, 1)
print('The number of OMP threads per trainer is set to', args.num_omp_threads)

udf_command = str(udf_command[0])
if 'python' not in udf_command:
raise RuntimeError("DGL launching script can only support Python executable file.")
Expand Down

0 comments on commit 86229d4

Please sign in to comment.