
Commit

[Bugfix] Fix graph being duplicated in multi-GPU and CPU dataloader workers (dmlc#3760)

* fix shared memory issue

* oops

* add explanation

* add explanation
BarclayII authored Feb 22, 2022
1 parent 3f138eb commit 4f00d5a
Showing 3 changed files with 39 additions and 19 deletions.
14 changes: 4 additions & 10 deletions examples/pytorch/__temporary__/graphsage/ddp.py
@@ -60,15 +60,10 @@ def inference(self, g, device, batch_size, num_workers, buffer_device=None):
         return y
 
 
-def train(rank, world_size, shared_memory_name, features, num_classes, split_idx):
+def train(rank, world_size, graph, num_classes, split_idx):
     torch.cuda.set_device(rank)
     dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank)
 
-    graph = dgl.hetero_from_shared_memory(shared_memory_name)
-    feat, labels = features
-    graph.ndata['feat'] = feat
-    graph.ndata['label'] = labels
-
     model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda()
     model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
     opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
@@ -132,14 +127,13 @@ def train(rank, world_size, shared_memory_name, features, num_classes, split_idx
 if __name__ == '__main__':
     dataset = DglNodePropPredDataset('ogbn-products')
     graph, labels = dataset[0]
-    shared_memory_name = 'shm' # can be any string
-    feat = graph.ndata['feat']
-    graph = graph.shared_memory(shared_memory_name)
+    graph.ndata['label'] = labels
+    graph.create_formats_()     # must be called before mp.spawn().
     split_idx = dataset.get_idx_split()
     num_classes = dataset.num_classes
     n_procs = 4
 
     # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs
     # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples.
     import torch.multiprocessing as mp
-    mp.spawn(train, args=(n_procs, shared_memory_name, (feat, labels), num_classes, split_idx), nprocs=n_procs)
+    mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)
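
The updated example boils down to a simple pattern: build every sparse format once in the parent process, then hand the graph object itself to the spawned workers. A minimal sketch of that pattern (`train` and the argument names mirror the example above; `main` is illustrative):

    import torch.multiprocessing as mp

    def main(graph, num_classes, split_idx, n_procs=4):
        # Materialize all sparse formats (COO/CSR/CSC) in the parent process so
        # every spawned worker shares the same storage instead of rebuilding,
        # and thereby duplicating, the graph.
        graph.create_formats_()  # must happen before mp.spawn()
        mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs)
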
30 changes: 22 additions & 8 deletions python/dgl/dataloading/dataloader.py
Expand Up @@ -33,7 +33,7 @@ def _set_python_exit_flag():
     PYTHON_EXIT_STATUS = True
 atexit.register(_set_python_exit_flag)
 
-prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '10'))
+prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '30'))
 
 class _TensorizedDatasetIter(object):
     def __init__(self, dataset, batch_size, drop_last, mapping_keys):
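
For reference, the prefetcher timeout is read once at import time, so an override has to be in the environment before `dgl.dataloading` is imported. A minimal sketch, assuming the standard `os.environ` route:

    import os
    os.environ['DGL_PREFETCHER_TIMEOUT'] = '60'  # seconds; set before the import below

    import dgl.dataloading  # picks up DGL_PREFETCHER_TIMEOUT at import time
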
@@ -615,9 +615,11 @@ def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
                 raise ValueError(
                     'Expect graph and indices to be on the same device. '
                     'If you wish to use UVA sampling, please set use_uva=True.')
-            if self.graph.device.type == 'cuda':
-                if num_workers > 0:
-                    raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
+            if self.graph.device.type == 'cuda' and num_workers > 0:
+                raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
+            if self.graph.device.type == 'cpu' and num_workers > 0:
+                # Instantiate all the formats if the number of workers is greater than 0.
+                self.graph.create_formats_()
 
         # Check pin_prefetcher and use_prefetch_thread - should be only effective
         # if performing CPU sampling but output device is CUDA
@@ -666,10 +668,6 @@ def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False,
         self.use_prefetch_thread = use_prefetch_thread
         worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))
 
-        # Instantiate all the formats if the number of workers is greater than 0.
-        if num_workers > 0 and hasattr(self.graph, 'create_formats_'):
-            self.graph.create_formats_()
-
         self.other_storages = {}
 
         super().__init__(
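
Taken together, the two hunks above replace a late, unconditional `create_formats_()` call with an early, device-aware check. The resulting rule, restated as a free function for clarity (the name `_validate_workers` is illustrative, not the actual method):

    def _validate_workers(graph, num_workers):
        # CUDA-resident graphs cannot be sampled by dataloader worker processes.
        if graph.device.type == 'cuda' and num_workers > 0:
            raise ValueError('num_workers must be 0 if graph and indices are on CUDA.')
        # CPU graphs sampled by workers: build every sparse format once, up front,
        # so forked workers share the storage instead of each creating its own copy.
        if graph.device.type == 'cpu' and num_workers > 0:
            graph.create_formats_()
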
@@ -732,6 +730,14 @@ class NodeDataLoader(DataLoader):
         graph and feature tensors into pinned memory.
         Default: False.
+
+        .. warning::
+
+            Using UVA with multiple GPUs may crash with device mismatch errors on
+            older CUDA drivers. We have confirmed that CUDA driver 450.142 will
+            crash while 465.19 will work, so we recommend upgrading your CUDA
+            driver if you wish to use UVA with multiple GPUs.
+
     use_prefetch_thread : bool, optional
         (Advanced option)
         Spawns a new Python thread to perform feature slicing
@@ -916,6 +922,14 @@ class EdgeDataLoader(DataLoader):
         graph and feature tensors into pinned memory.
         Default: False.
+
+        .. warning::
+
+            Using UVA with multiple GPUs may crash with device mismatch errors on
+            older CUDA drivers. We have confirmed that CUDA driver 450.142 will
+            crash while 465.19 will work, so we recommend upgrading your CUDA
+            driver if you wish to use UVA with multiple GPUs.
+
     batch_size : int, optional
     drop_last : bool, optional
     shuffle : bool, optional
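
For context, UVA sampling is requested through the `use_uva` flag on these loaders. A minimal usage sketch (assuming a CPU graph `g` and training node IDs `train_nids`; argument names follow the DGL 0.8 API as we understand it):

    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])
    loader = dgl.dataloading.NodeDataLoader(
        g, train_nids, sampler,
        device='cuda',   # sampled blocks and features land on the GPU
        use_uva=True,    # sample from pinned host memory via unified virtual addressing
        batch_size=1024, shuffle=True, num_workers=0)
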
14 changes: 13 additions & 1 deletion python/dgl/heterograph_index.py
@@ -1370,9 +1370,21 @@ def _forking_rebuild(pk_state):
     meta, arrays = pk_state
     arrays = [F.to_dgl_nd(arr) for arr in arrays]
     states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays)
-    return _CAPI_DGLHeteroForkingUnpickle(states)
+    graph_index = _CAPI_DGLHeteroForkingUnpickle(states)
+    graph_index._forking_pk_state = pk_state
+    return graph_index
 
 def _forking_reduce(graph_index):
+    # Because F.from_dgl_nd(F.to_dgl_nd(x)) loses the shared memory file descriptor
+    # (DLPack does not keep it), PyTorch would allocate a separate shared memory
+    # region for every single worker if we did not cache the tensors here.
+    # The downside is that if a graph_index is shared by forking and new formats are
+    # created afterwards, sharing it again will not bring along the new formats. This
+    # case should be rare though, because (1) DataLoader will create all the formats
+    # if num_workers > 0 anyway, and (2) we require users to explicitly create all
+    # formats before calling mp.spawn().
+    if hasattr(graph_index, '_forking_pk_state'):
+        return _forking_rebuild, (graph_index._forking_pk_state,)
     states = _CAPI_DGLHeteroForkingPickle(graph_index)
     arrays = [F.from_dgl_nd(arr) for arr in states.arrays]
     # Similar to what is mentioned in HeteroGraphIndex.__getstate__, we need to save
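
The fix is an instance of a general forking-pickle pattern: cache the serialized state on the rebuilt object, so re-sharing that object across forks hands out the same shared-memory state instead of allocating a fresh region per worker. A generic sketch of the pattern (illustrative class and names, not DGL internals):

    from multiprocessing.reduction import ForkingPickler

    class Shareable:
        def __init__(self, payload):
            self.payload = payload

    def _rebuild(state):
        obj = Shareable(state)
        obj._forking_pk_state = state  # cache the state for the next fork
        return obj

    def _reduce(obj):
        # If this object was itself rebuilt from a forked state, reuse the cached
        # state rather than re-serializing, which would allocate a new
        # shared-memory region in every worker.
        if hasattr(obj, '_forking_pk_state'):
            return _rebuild, (obj._forking_pk_state,)
        return _rebuild, (obj.payload,)

    ForkingPickler.register(Shareable, _reduce)
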
