[Distributed] Fix dataloader (dmlc#1970)
* fix dataloader.

* initialize iterator of DistDataloader correctly.

* update test.
zheng-da authored Aug 8, 2020
1 parent 264d96c commit bab32d5
Showing 2 changed files with 74 additions and 45 deletions.
63 changes: 45 additions & 18 deletions python/dgl/distributed/dist_dataloader.py
@@ -10,24 +10,37 @@
 __all__ = ["DistDataLoader"]
 
 
-def call_collate_fn(next_data):
+def call_collate_fn(name, next_data):
     """Call collate function"""
     try:
-        result = DGL_GLOBAL_COLLATE_FN(next_data)
-        DGL_GLOBAL_MP_QUEUE.put(result)
+        result = DGL_GLOBAL_COLLATE_FNS[name](next_data)
+        DGL_GLOBAL_MP_QUEUES[name].put(result)
     except Exception as e:
         traceback.print_exc()
         print(e)
         raise e
     return 1
 
+DGL_GLOBAL_COLLATE_FNS = {}
+DGL_GLOBAL_MP_QUEUES = {}
 
-def init_fn(collate_fn, queue):
+def init_fn(name, collate_fn, queue):
     """Initialize setting collate function and mp.Queue in the subprocess"""
-    global DGL_GLOBAL_COLLATE_FN
-    global DGL_GLOBAL_MP_QUEUE
-    DGL_GLOBAL_MP_QUEUE = queue
-    DGL_GLOBAL_COLLATE_FN = collate_fn
+    global DGL_GLOBAL_COLLATE_FNS
+    global DGL_GLOBAL_MP_QUEUES
+    DGL_GLOBAL_MP_QUEUES[name] = queue
+    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
+    # sleep here is to ensure this function is executed in all worker processes
+    # probably need better solution in the future
+    time.sleep(1)
+    return 1
+
+def cleanup_fn(name):
+    """Clean up the data of a dataloader in the worker process"""
+    global DGL_GLOBAL_COLLATE_FNS
+    global DGL_GLOBAL_MP_QUEUES
+    del DGL_GLOBAL_MP_QUEUES[name]
+    del DGL_GLOBAL_COLLATE_FNS[name]
     # sleep here is to ensure this function is executed in all worker processes
     # probably need better solution in the future
     time.sleep(1)
@@ -41,6 +54,7 @@ def enable_mp_debug():
     logger = multiprocessing.log_to_stderr()
     logger.setLevel(logging.DEBUG)
 
+DATALOADER_ID = 0
 
 class DistDataLoader:
     """DGL customized multiprocessing dataloader, which is designed for using with DistGraph."""
@@ -90,19 +104,32 @@ def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_las
         if self.pool is None:
             ctx = mp.get_context("spawn")
             self.pool = ctx.Pool(num_workers)
-        results = []
-
-        for _ in range(num_workers):
-            results.append(self.pool.apply_async(
-                init_fn, args=(collate_fn, self.queue)))
-        for res in results:
-            res.get()
 
         self.dataset = F.tensor(dataset)
         self.expected_idxs = len(dataset) // self.batch_size
         if not self.drop_last and len(dataset) % self.batch_size != 0:
             self.expected_idxs += 1
 
+        # We need to have a unique Id for each data loader to identify itself
+        # in the sampler processes.
+        global DATALOADER_ID
+        self.name = "dataloader-" + str(DATALOADER_ID)
+        DATALOADER_ID += 1
+
+        results = []
+        for _ in range(self.num_workers):
+            results.append(self.pool.apply_async(
+                init_fn, args=(self.name, self.collate_fn, self.queue)))
+        for res in results:
+            res.get()
+
+    def __del__(self):
+        results = []
+        for _ in range(self.num_workers):
+            results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
+        for res in results:
+            res.get()
+
     def __next__(self):
         if not self.started:
             for _ in range(self.queue_size):
@@ -113,13 +140,13 @@ def __next__(self):
             self.recv_idxs += 1
             return result
         else:
-            self.recv_idxs = 0
+            self.current_pos = 0
             raise StopIteration
 
     def __iter__(self):
         if self.shuffle:
             self.dataset = F.rand_shuffle(self.dataset)
-        self.recv_idxs = 0
+        self.current_pos = 0
         return self
 
     def _request_next_batch(self):
@@ -128,7 +155,7 @@ def _request_next_batch(self):
             return None
         else:
             async_result = self.pool.apply_async(
-                call_collate_fn, args=(next_data, ))
+                call_collate_fn, args=(self.name, next_data, ))
             return async_result
 
     def _next_data(self):
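The core of the change above is that each DistDataLoader now registers its collate function and result queue in every pool worker under a unique name, and unregisters them in __del__, so several dataloaders can share one spawned worker pool without clobbering each other's globals. Below is a minimal, self-contained sketch of that registration pattern outside DGL; the names (worker_init, square_batch, "loader-0") are invented for illustration, and a managed queue stands in for whatever queue object DGL actually hands to the workers.

import multiprocessing as mp
import time

# Per-process registries, filled in by worker_init inside each pool worker.
WORKER_COLLATE_FNS = {}
WORKER_QUEUES = {}

def worker_init(name, collate_fn, queue):
    """Register one loader's collate function and result queue in this worker."""
    WORKER_COLLATE_FNS[name] = collate_fn
    WORKER_QUEUES[name] = queue
    # Crude way to make it likely that every worker handles one init task,
    # mirroring the sleep in the DGL code above.
    time.sleep(1)

def worker_cleanup(name):
    """Drop a loader's registration when the loader goes away."""
    WORKER_COLLATE_FNS.pop(name, None)
    WORKER_QUEUES.pop(name, None)

def worker_collate(name, batch):
    """Run the collate function registered under `name` and queue its result."""
    WORKER_QUEUES[name].put(WORKER_COLLATE_FNS[name](batch))

def square_batch(batch):
    return [x * x for x in batch]

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    manager = ctx.Manager()
    queue = manager.Queue()       # a managed queue can be passed through apply_async
    with ctx.Pool(2) as pool:
        # One init task per worker, keyed by a unique loader name.
        for r in [pool.apply_async(worker_init, ("loader-0", square_batch, queue))
                  for _ in range(2)]:
            r.get()
        pool.apply_async(worker_collate, ("loader-0", [1, 2, 3])).get()
        print(queue.get())        # -> [1, 4, 9]
        for r in [pool.apply_async(worker_cleanup, ("loader-0",)) for _ in range(2)]:
            r.get()

The per-name dictionaries are what let a second loader call init_fn with a different name instead of overwriting a single module-level collate function and queue.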
56 changes: 29 additions & 27 deletions tests/distributed/test_mp_dataloader.py
@@ -62,32 +62,34 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
     sampler = NeighborSampler(dist_graph, [5, 10],
                               dgl.distributed.sample_neighbors)
 
-    # Create DataLoader for constructing blocks
-    dataloader = DistDataLoader(
-        dataset=train_nid.numpy(),
-        batch_size=batch_size,
-        collate_fn=sampler.sample_blocks,
-        shuffle=False,
-        drop_last=drop_last)
-
-    groundtruth_g = CitationGraphDataset("cora")[0]
-    max_nid = []
-
-    for epoch in range(2):
-        for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
-            block = blocks[-1]
-            o_src, o_dst = block.edges()
-            src_nodes_id = block.srcdata[dgl.NID][o_src]
-            dst_nodes_id = block.dstdata[dgl.NID][o_dst]
-            has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
-            assert np.all(F.asnumpy(has_edges))
-            print(np.unique(np.sort(F.asnumpy(dst_nodes_id))))
-            max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
-            # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
-    if drop_last:
-        assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
-    else:
-        assert np.max(max_nid) == num_nodes_to_sample - 1
+    # We need to test creating DistDataLoader multiple times.
+    for i in range(2):
+        # Create DataLoader for constructing blocks
+        dataloader = DistDataLoader(
+            dataset=train_nid.numpy(),
+            batch_size=batch_size,
+            collate_fn=sampler.sample_blocks,
+            shuffle=False,
+            drop_last=drop_last)
+
+        groundtruth_g = CitationGraphDataset("cora")[0]
+        max_nid = []
+
+        for epoch in range(2):
+            for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
+                block = blocks[-1]
+                o_src, o_dst = block.edges()
+                src_nodes_id = block.srcdata[dgl.NID][o_src]
+                dst_nodes_id = block.dstdata[dgl.NID][o_dst]
+                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
+                assert np.all(F.asnumpy(has_edges))
+                print(np.unique(np.sort(F.asnumpy(dst_nodes_id))))
+                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
+                # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
+        if drop_last:
+            assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
+        else:
+            assert np.max(max_nid) == num_nodes_to_sample - 1
 
     dgl.distributed.exit_client() # this is needed since there's two test here in one process
 
@@ -126,4 +128,4 @@ def test_dist_dataloader(tmpdir, num_server, drop_last):
 if __name__ == "__main__":
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
-        test_dist_dataloader(Path(tmpdirname), 3, True)
\ No newline at end of file
+        test_dist_dataloader(Path(tmpdirname), 3, True)
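The rewritten test creates a DistDataLoader twice and iterates each one for two epochs, which exercises both the per-name worker registration and the counter reset in __iter__/__next__ from the first file. As a hedged illustration of the reset idea alone, here is a toy re-iterable batch loader; the class and its fields are hypothetical, not DGL's API.

# Toy re-iterable batch loader showing the reset discipline from the diff:
# the position counter is cleared in __iter__ and again when StopIteration
# is raised, so the same loader can be iterated for several epochs.
class ToyBatchLoader:
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.current_pos = 0          # index of the next batch in the dataset

    def __iter__(self):
        self.current_pos = 0          # reset so a new epoch starts from the beginning
        return self

    def __next__(self):
        if self.current_pos >= len(self.dataset):
            self.current_pos = 0      # also reset on exhaustion, mirroring the diff
            raise StopIteration
        batch = self.dataset[self.current_pos:self.current_pos + self.batch_size]
        self.current_pos += self.batch_size
        return batch

loader = ToyBatchLoader(list(range(10)), batch_size=4)
for epoch in range(2):                # two epochs over the same loader, like the test
    assert [len(b) for b in loader] == [4, 4, 2]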
