fix ddp dataloader in heterogeneous cases (dmlc#3801)
BarclayII authored Mar 7, 2022
1 parent bb6cec2 commit 44638b9
Showing 2 changed files with 9 additions and 7 deletions.
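Why the fix is needed, as a minimal sketch (the node types and counts below are hypothetical): when the seed indices passed to the DDP dataloader form a Mapping from node type to an ID tensor, as they do for heterogeneous graphs, len(indices) counts the node types rather than the seed nodes, so DDPTensorizedDataset computed its per-rank sample count from the wrong total. The patch introduces len_indices, which sums the per-type lengths first.

import math
import torch

# Hypothetical heterogeneous seed-node dict; node types and sizes are made up.
indices = {'user': torch.arange(1000), 'item': torch.arange(500)}
num_replicas = 4  # assumed DDP world size

wrong_total = len(indices)                            # 2: counts node types, not seeds
len_indices = sum(len(v) for v in indices.values())   # 1500: what the patch computes

# Per-rank sample count without drop_last, mirroring the patched __init__:
print(math.ceil(wrong_total / num_replicas))    # 1
print(math.ceil(len_indices / num_replicas))    # 375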
examples/pytorch/graphsage/multi_gpu_node_classification.py (2 additions, 2 deletions)
@@ -36,7 +36,7 @@ def inference(self, g, device, batch_size, num_workers, buffer_device=None):
         # example is that the intermediate results can also benefit from prefetching.
         g.ndata['h'] = g.ndata['feat']
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
             g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
             batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
             persistent_workers=(num_workers > 0))
@@ -77,7 +77,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
         graph, train_idx, sampler,
         device='cuda', batch_size=1000, shuffle=True, drop_last=False,
         num_workers=0, use_ddp=True, use_uva=True)
-    valid_dataloader = dgl.dataloading.NodeDataLoader(
+    valid_dataloader = dgl.dataloading.DataLoader(
         graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
         drop_last=False, num_workers=0, use_uva=True)

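For reference, a minimal sketch of the unified dgl.dataloading.DataLoader call that the example now uses in place of NodeDataLoader (the toy graph and sizes below are assumptions, not the example's actual dataset):

import dgl
import torch

g = dgl.rand_graph(100, 500)            # assumed toy homogeneous graph
g.ndata['feat'] = torch.randn(100, 16)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler,
    batch_size=32, shuffle=False, drop_last=False, num_workers=0)

for input_nodes, output_nodes, blocks in dataloader:
    pass  # blocks[0] is the single-layer message-flow graph for this batch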
python/dgl/dataloading/dataloader.py (7 additions, 5 deletions)
@@ -169,8 +169,10 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
     def __init__(self, indices, batch_size, drop_last, ddp_seed):
         if isinstance(indices, Mapping):
             self._mapping_keys = list(indices.keys())
+            len_indices = sum(len(v) for v in indices.values())
         else:
             self._mapping_keys = None
+            len_indices = len(indices)

         self.rank = dist.get_rank()
         self.num_replicas = dist.get_world_size()
@@ -179,17 +181,17 @@ def __init__(self, indices, batch_size, drop_last, ddp_seed):
         self.batch_size = batch_size
         self.drop_last = drop_last

-        if self.drop_last and len(indices) % self.num_replicas != 0:
-            self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas)
+        if self.drop_last and len_indices % self.num_replicas != 0:
+            self.num_samples = math.ceil((len_indices - self.num_replicas) / self.num_replicas)
         else:
-            self.num_samples = math.ceil(len(indices) / self.num_replicas)
+            self.num_samples = math.ceil(len_indices / self.num_replicas)
         self.total_size = self.num_samples * self.num_replicas
         # If drop_last is True, we create a shared memory array larger than the number
         # of indices since we will need to pad it after shuffling to make it evenly
         # divisible before every epoch. If drop_last is False, we create an array
         # with the same size as the indices so we can trim it later.
-        self.shared_mem_size = self.total_size if not self.drop_last else len(indices)
-        self.num_indices = len(indices)
+        self.shared_mem_size = self.total_size if not self.drop_last else len_indices
+        self.num_indices = len_indices

         if isinstance(indices, Mapping):
             self._device = next(iter(indices.values())).device
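An end-to-end sketch of the heterogeneous DDP case this commit fixes (the graph, node types, port, and single-process gloo group below are assumptions made for illustration; set_epoch follows the usual DDP reshuffling recommendation):

import dgl
import torch
import torch.distributed as dist

# Single-process group just to exercise the DDP code path (assumed setup).
dist.init_process_group(
    'gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)

# Hypothetical heterogeneous graph with two node types.
g = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 2]), torch.tensor([0, 1])),
})

# Dict of seed nodes per node type (the Mapping case the patched code handles).
train_idx = {'user': torch.arange(g.num_nodes('user')),
             'game': torch.arange(g.num_nodes('game'))}

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g, train_idx, sampler,
    batch_size=2, shuffle=True, drop_last=False, use_ddp=True)

for epoch in range(2):
    dataloader.set_epoch(epoch)  # reshuffle per epoch when using DDP
    for input_nodes, output_nodes, blocks in dataloader:
        pass

dist.destroy_process_group()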
