[Distributed] Use barrier instead of sleep in DistDataloader (dmlc#2086)
* use barrier instead of sleep

* lint
VoVAllen authored Aug 21, 2020
1 parent 6d21298 commit cd204a4
Showing 1 changed file with 9 additions and 11 deletions.
python/dgl/distributed/dist_dataloader.py
@@ -2,7 +2,6 @@
 """Multiprocess dataloader for distributed training"""
 import multiprocessing as mp
 from queue import Queue
-import time
 import traceback
 
 from .dist_context import get_sampler_pool
@@ -25,34 +24,32 @@ def call_collate_fn(name, next_data):
 DGL_GLOBAL_COLLATE_FNS = {}
 DGL_GLOBAL_MP_QUEUES = {}
 
-def init_fn(name, collate_fn, queue):
+def init_fn(barrier, name, collate_fn, queue):
     """Initialize setting collate function and mp.Queue in the subprocess"""
     global DGL_GLOBAL_COLLATE_FNS
     global DGL_GLOBAL_MP_QUEUES
     DGL_GLOBAL_MP_QUEUES[name] = queue
     DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
-    # sleep here is to ensure this function is executed in all worker processes
-    # probably need better solution in the future
-    time.sleep(1)
+    barrier.wait()
     return 1
 
-def cleanup_fn(name):
+def cleanup_fn(barrier, name):
     """Clean up the data of a dataloader in the worker process"""
     global DGL_GLOBAL_COLLATE_FNS
     global DGL_GLOBAL_MP_QUEUES
     del DGL_GLOBAL_MP_QUEUES[name]
     del DGL_GLOBAL_COLLATE_FNS[name]
-    # sleep here is to ensure this function is executed in all worker processes
-    # probably need better solution in the future
-    time.sleep(1)
+    barrier.wait()
     return 1
 
 
 def enable_mp_debug():
     """Print multiprocessing debug information. This is only
     for debug usage"""
     import logging
-    logger = multiprocessing.log_to_stderr()
+    logger = mp.log_to_stderr()
     logger.setLevel(logging.DEBUG)
 
 DATALOADER_ID = 0
@@ -122,6 +119,7 @@ def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_las
         self.current_pos = 0
         if self.pool is not None:
             self.m = mp.Manager()
+            self.barrier = self.m.Barrier(self.num_workers)
             self.queue = self.m.Queue(maxsize=queue_size)
         else:
             self.queue = Queue(maxsize=queue_size)
@@ -145,15 +143,15 @@ def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_las
            results = []
            for _ in range(self.num_workers):
                results.append(self.pool.apply_async(
-                    init_fn, args=(self.name, self.collate_fn, self.queue)))
+                    init_fn, args=(self.barrier, self.name, self.collate_fn, self.queue)))
            for res in results:
                res.get()
 
    def __del__(self):
        if self.pool is not None:
            results = []
            for _ in range(self.num_workers):
-                results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
+                results.append(self.pool.apply_async(cleanup_fn, args=(self.barrier, self.name,)))
            for res in results:
                res.get()
 
@@ -162,7 +160,7 @@ def __next__(self):
        for _ in range(num_reqs):
            self._request_next_batch()
        if self.recv_idxs < self.expected_idxs:
-            result = self.queue.get(timeout=9999)
+            result = self.queue.get(timeout=1800)
            self.recv_idxs += 1
            self.num_pending -= 1
            return result
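The pattern behind the change, as a minimal self-contained sketch (the names init_worker and WORKER_STATE are illustrative, not DGL's API; assumes Python 3.3+, where Manager().Barrier is available): each of the num_workers tasks submitted with apply_async blocks in barrier.wait() until num_workers parties have arrived, so every pool worker is forced to run the initializer exactly once, which is the guarantee the old time.sleep(1) only approximated.

import multiprocessing as mp

# Per-process registry, standing in for DGL_GLOBAL_COLLATE_FNS/_MP_QUEUES.
WORKER_STATE = {}

def init_worker(barrier, name, value):
    """Install per-worker state, then rendezvous with the other workers."""
    WORKER_STATE[name] = value
    # A worker blocked here cannot dequeue another task, so the
    # num_workers pending tasks land on num_workers distinct workers.
    barrier.wait()
    return 1

if __name__ == "__main__":
    num_workers = 4
    manager = mp.Manager()
    # A manager-backed Barrier is a picklable proxy, so it can be passed
    # to pool workers as an argument (a raw mp.Barrier cannot be).
    barrier = manager.Barrier(num_workers)
    with mp.Pool(num_workers) as pool:
        results = [pool.apply_async(init_worker, args=(barrier, "loader-0", 42))
                   for _ in range(num_workers)]
        for res in results:
            res.get()  # returns only after every worker passed the barrier
    print("all workers initialized")

Compared with sleeping, the barrier removes the fixed one-second startup and shutdown cost per dataloader and turns a probabilistic "hopefully every worker ran it" into a deterministic rendezvous; the same pattern covers cleanup_fn in __del__.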
