[GraphBolt][CUDA] Dataloader feature overlap fix (dmlc#7036)
mfbalin authored Jan 30, 2024
1 parent 6837725 commit b085224
Showing 3 changed files with 40 additions and 36 deletions.

examples/multigpu/graphbolt/node_classification.py (5 changes: 1 addition & 4 deletions)

@@ -139,10 +139,7 @@ def create_dataloader(
     if args.storage_device == "cpu":
         datapipe = datapipe.copy_to(device)
 
-    # Until https://github.com/dmlc/dgl/issues/7008, overlap should be False.
-    dataloader = gb.DataLoader(
-        datapipe, args.num_workers, overlap_feature_fetch=False
-    )
+    dataloader = gb.DataLoader(datapipe, args.num_workers)
 
     # Return the fully-initialized DataLoader object.
     return dataloader
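
Not part of the commit: a minimal sketch of how a GraphBolt datapipe is typically assembled and handed to gb.DataLoader now that the overlap flag no longer needs to be disabled in this example. The itemset, graph, features, and device objects are assumed to be constructed elsewhere, as in the surrounding file.

import dgl.graphbolt as gb

# Assumed to exist: itemset, graph, features, device.
datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(graph, [10, 10])  # two-layer neighbor sampling
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)

# overlap_feature_fetch defaults to True, so on CUDA-capable systems the
# feature fetcher is overlapped with the rest of the pipeline automatically.
dataloader = gb.DataLoader(datapipe, num_workers=0)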

python/dgl/graphbolt/dataloader.py (42 changes: 16 additions & 26 deletions)

@@ -16,17 +16,18 @@
 
 __all__ = [
     "DataLoader",
+    "Awaiter",
+    "Bufferer",
 ]
 
 
-def _find_and_wrap_parent(
-    datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
-):
+def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     """Find parent of target_datapipe and wrap it with ."""
     datapipes = dp_utils.find_dps(
         datapipe_graph,
         target_datapipe,
     )
+    datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
     for datapipe in datapipes:
         datapipe_id = id(datapipe)
         for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
@@ -36,6 +37,7 @@ def _find_and_wrap_parent(
                 parent_datapipe,
                 wrapper(parent_datapipe, **kwargs),
             )
+    return datapipe_graph
 
 
 class EndMarker(dp.iter.IterDataPipe):
@@ -45,8 +47,7 @@ def __init__(self, datapipe):
         self.datapipe = datapipe
 
     def __iter__(self):
-        for data in self.datapipe:
-            yield data
+        yield from self.datapipe
 
 
 class Bufferer(dp.iter.IterDataPipe):
@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
        The data pipeline.
     buffer_size : int, optional
         The size of the buffer which stores the fetched samples. If data coming
-        from datapipe has latency spikes, consider increasing passing a high
-        value. Default is 2.
+        from datapipe has latency spikes, consider setting to a higher value.
+        Default is 1.
     """
 
-    def __init__(self, datapipe, buffer_size=2):
+    def __init__(self, datapipe, buffer_size=1):
         self.datapipe = datapipe
         if buffer_size <= 0:
             raise ValueError(
@@ -180,7 +181,6 @@ def __init__(
 
         datapipe = EndMarker(datapipe)
         datapipe_graph = dp_utils.traverse_dps(datapipe)
-        datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
 
         # (1) Insert minibatch distribution.
         # TODO(BarclayII): Currently I'm using sharding_filter() as a
@@ -198,9 +198,8 @@ def __init__(
             )
 
         # (2) Cut datapipe at FeatureFetcher and wrap.
-        _find_and_wrap_parent(
+        datapipe_graph = _find_and_wrap_parent(
            datapipe_graph,
-            datapipe_adjlist,
             FeatureFetcher,
             MultiprocessingWrapper,
             num_workers=num_workers,
@@ -221,25 +220,16 @@ def __init__(
             )
             for feature_fetcher in feature_fetchers:
                 feature_fetcher.stream = _get_uva_stream()
-            _find_and_wrap_parent(
-                datapipe_graph,
-                datapipe_adjlist,
-                EndMarker,
-                Bufferer,
-                buffer_size=2,
-            )
-            _find_and_wrap_parent(
-                datapipe_graph,
-                datapipe_adjlist,
-                EndMarker,
-                Awaiter,
-            )
+                datapipe_graph = dp_utils.replace_dp(
+                    datapipe_graph,
+                    feature_fetcher,
+                    Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
+                )
 
         # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
         # data pipeline up to the CopyTo operation to run in a separate thread.
-        _find_and_wrap_parent(
+        datapipe_graph = _find_and_wrap_parent(
             datapipe_graph,
-            datapipe_adjlist,
             CopyTo,
             dp.iter.Prefetcher,
             buffer_size=2,
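
The key change above: the Bufferer/Awaiter pair now wraps the FeatureFetcher itself via dp_utils.replace_dp, instead of wrapping the parent of EndMarker, so only the feature-fetching stage is buffered and awaited while later stages keep running. As a rough illustration of the buffering mechanism only (not DGL's actual implementation), a toy pair could look like the following; ToyBufferer, ToyAwaiter, and the assumption that each item exposes a wait() method are illustrative.

from collections import deque

import torchdata.datapipes as dp


class ToyBufferer(dp.iter.IterDataPipe):
    """Keeps up to `buffer_size` items in flight before yielding them."""

    def __init__(self, datapipe, buffer_size=1):
        self.datapipe = datapipe
        self.buffer_size = buffer_size

    def __iter__(self):
        buffer = deque()
        for item in self.datapipe:
            if len(buffer) < self.buffer_size:
                buffer.append(item)
                continue
            # Hand the oldest in-flight item downstream while the newest one,
            # whose asynchronous work was just launched upstream, stays buffered.
            yield buffer.popleft()
            buffer.append(item)
        while buffer:
            yield buffer.popleft()


class ToyAwaiter(dp.iter.IterDataPipe):
    """Blocks on each item's pending work right before it is consumed."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for item in self.datapipe:
            item.wait()  # assumed method: block until the async fetch finishes
            yield item

In the real pipeline, the FeatureFetcher issues its reads on a separate CUDA stream, so keeping one minibatch buffered lets the next fetch proceed while the current minibatch is consumed downstream.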

tests/python/pytorch/graphbolt/test_dataloader.py (29 changes: 23 additions & 6 deletions)

@@ -7,6 +7,8 @@
 import pytest
 import torch
 
+import torchdata.dataloader2.graph as dp_utils
+
 from . import gb_test_utils
 
 
@@ -46,7 +48,8 @@ def test_DataLoader():
     reason="This test requires the GPU.",
 )
 @pytest.mark.parametrize("overlap_feature_fetch", [True, False])
-def test_gpu_sampling_DataLoader(overlap_feature_fetch):
+@pytest.mark.parametrize("enable_feature_fetch", [True, False])
+def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
     N = 40
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
@@ -70,13 +73,27 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
         graph,
         fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
-    datapipe = dgl.graphbolt.FeatureFetcher(
-        datapipe,
-        feature_store,
-        ["a", "b"],
-    )
+    if enable_feature_fetch:
+        datapipe = dgl.graphbolt.FeatureFetcher(
+            datapipe,
+            feature_store,
+            ["a", "b"],
+        )
 
     dataloader = dgl.graphbolt.DataLoader(
         datapipe, overlap_feature_fetch=overlap_feature_fetch
     )
+    bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
+    datapipe = dataloader.dataset
+    datapipe_graph = dp_utils.traverse_dps(datapipe)
+    awaiters = dp_utils.find_dps(
+        datapipe_graph,
+        dgl.graphbolt.Awaiter,
+    )
+    assert len(awaiters) == bufferer_awaiter_cnt
+    bufferers = dp_utils.find_dps(
+        datapipe_graph,
+        dgl.graphbolt.Bufferer,
+    )
+    assert len(bufferers) == bufferer_awaiter_cnt
+    assert len(list(dataloader)) == N // B
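
For context on the new assertions: torchdata's graph utilities work on any datapipe, not only GraphBolt pipelines. A small self-contained sketch (not part of the commit) of traverse_dps() and find_dps():

import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp

# Build a trivial pipeline: wrap -> map -> batch.
pipe = dp.iter.IterableWrapper(range(8)).map(lambda x: x * 2).batch(2)

# traverse_dps() returns the datapipe graph; find_dps() locates every node of
# a given datapipe class inside that graph.
graph = dp_utils.traverse_dps(pipe)
wrappers = dp_utils.find_dps(graph, dp.iter.IterableWrapper)
assert len(wrappers) == 1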
