[Bug] Fix dtype mismatch in heterogeneous DataLoader (dmlc#3878)
* fix

* unit test

BarclayII authored Mar 26, 2022
1 parent e963256 commit f758db3
Showing 3 changed files with 30 additions and 7 deletions.
10 changes: 6 additions & 4 deletions python/dgl/dataloading/dataloader.py
```diff
@@ -21,7 +21,7 @@
 from .. import ndarray as nd
 from ..utils import (
     recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
-    create_shared_mem_array, get_shared_mem_array, context_of)
+    create_shared_mem_array, get_shared_mem_array, context_of, dtype_of)
 from ..frame import LazyFeature
 from ..storages import wrap_storage
 from .base import BlockSampler, as_edge_prediction_sampler
@@ -86,9 +86,11 @@ def __next__(self):


 def _get_id_tensor_from_mapping(indices, device, keys):
-    lengths = torch.LongTensor([
-        (indices[k].shape[0] if k in indices else 0) for k in keys]).to(device)
-    type_ids = torch.arange(len(keys), device=device).repeat_interleave(lengths)
+    dtype = dtype_of(indices)
+    lengths = torch.tensor(
+        [(indices[k].shape[0] if k in indices else 0) for k in keys],
+        dtype=dtype, device=device)
+    type_ids = torch.arange(len(keys), dtype=dtype, device=device).repeat_interleave(lengths)
     all_indices = torch.cat([indices[k] for k in keys if k in indices])
     return torch.stack([type_ids, all_indices], 1)

```
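Before this change, `lengths` (via `torch.LongTensor`) and `type_ids` (via the default `torch.arange` dtype) were always int64, so for a graph whose idtype is int32 the stacked (type_id, index) tensor did not keep the indices' dtype. Below is a self-contained sketch of the fixed logic in plain PyTorch, with `dtype_of` inlined as a one-line stand-in for the DGL helper added further down; it assumes, as the real code does, that every tensor in the mapping shares one dtype.

```python
# Standalone sketch (not DGL's code) of the fixed helper.
import torch

def get_id_tensor_from_mapping(indices, device, keys):
    # Assumes all tensors in `indices` share one dtype, as the node ids
    # of a single graph do (int32 or int64).
    dtype = next(iter(indices.values())).dtype
    lengths = torch.tensor(
        [indices[k].shape[0] if k in indices else 0 for k in keys],
        dtype=dtype, device=device)
    type_ids = torch.arange(len(keys), dtype=dtype,
                            device=device).repeat_interleave(lengths)
    all_indices = torch.cat([indices[k] for k in keys if k in indices])
    # Each row is a (type_id, index) pair; the input dtype is preserved
    # instead of being forced to int64.
    return torch.stack([type_ids, all_indices], 1)

mapping = {'user': torch.tensor([0, 1, 3], dtype=torch.int32),
           'game': torch.tensor([2], dtype=torch.int32)}
out = get_id_tensor_from_mapping(mapping, torch.device('cpu'),
                                 ['user', 'game'])
assert out.dtype == torch.int32  # stayed int32; int64 before the fix
```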
4 changes: 4 additions & 0 deletions python/dgl/utils/internal.py
```diff
@@ -1019,4 +1019,8 @@ def context_of(data):
     else:
         return F.context(data)

+def dtype_of(data):
+    """Return the dtype of the data which can be either a tensor or a dict of tensors."""
+    return F.dtype(next(iter(data.values())) if isinstance(data, Mapping) else data)
+
 _init_api("dgl.utils.internal")
```
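`F.dtype` dispatches to the active backend; with the PyTorch backend it amounts to reading `.dtype`. A minimal sketch of the same logic without the backend indirection:

```python
# PyTorch-only sketch of dtype_of: accepts either a bare tensor or a
# Mapping of per-type tensors (the heterogeneous case) and reports the
# dtype of the first value.
from collections.abc import Mapping
import torch

def dtype_of(data):
    return (next(iter(data.values())) if isinstance(data, Mapping) else data).dtype

print(dtype_of(torch.arange(5, dtype=torch.int32)))            # torch.int32
print(dtype_of({'user': torch.arange(3, dtype=torch.int64)}))  # torch.int64
```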
23 changes: 20 additions & 3 deletions tests/pytorch/test_dataloader.py
```diff
@@ -10,6 +10,7 @@
 from collections import defaultdict
 from collections.abc import Iterator, Mapping
 from itertools import product
+from test_utils import parametrize_dtype
 import pytest


@@ -89,6 +90,15 @@ def test_neighbor_nonuniform(num_workers):
     elif seed == 0:
         assert neighbors == {3, 4}

+def _check_dtype(data, dtype, attr_name):
+    if isinstance(data, dict):
+        for k, v in data.items():
+            assert getattr(v, attr_name) == dtype
+    elif isinstance(data, list):
+        for v in data:
+            assert getattr(v, attr_name) == dtype
+    else:
+        assert getattr(data, attr_name) == dtype

 def _check_device(data):
     if isinstance(data, dict):
@@ -100,10 +110,11 @@ def _check_device(data):
     else:
         assert data.device == F.ctx()

+@parametrize_dtype
 @pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
 @pytest.mark.parametrize('pin_graph', [False, True])
-def test_node_dataloader(sampler_name, pin_graph):
-    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
+def test_node_dataloader(idtype, sampler_name, pin_graph):
+    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
     if F.ctx() != F.cpu() and pin_graph:
         g1.create_formats_()
         g1.pin_memory_()
@@ -123,13 +134,16 @@ def test_node_dataloader(sampler_name, pin_graph):
         _check_device(input_nodes)
         _check_device(output_nodes)
         _check_device(blocks)
+        _check_dtype(input_nodes, idtype, 'dtype')
+        _check_dtype(output_nodes, idtype, 'dtype')
+        _check_dtype(blocks, idtype, 'idtype')

     g2 = dgl.heterograph({
         ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
         ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
         ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
         ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
-    })
+    }).astype(idtype)
     for ntype in g2.ntypes:
         g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
     batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
@@ -146,6 +160,9 @@ def test_node_dataloader(sampler_name, pin_graph):
         _check_device(input_nodes)
         _check_device(output_nodes)
         _check_device(blocks)
+        _check_dtype(input_nodes, idtype, 'dtype')
+        _check_dtype(output_nodes, idtype, 'dtype')
+        _check_dtype(blocks, idtype, 'idtype')

     if g1.is_pinned():
         g1.unpin_memory_()
```
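`parametrize_dtype` comes from the test suite's `test_utils`; the diff suggests it parametrizes the test's `idtype` argument over DGL's two integer id types. A sketch of the pattern, with `_check_dtype` reproduced from the diff above — the decorator body here is an assumption for illustration, not the real helper:

```python
# Assumed stand-in for test_utils.parametrize_dtype: parametrize
# `idtype` over int32 and int64.
import pytest
import torch

parametrize_dtype = pytest.mark.parametrize(
    'idtype', [torch.int32, torch.int64])

def _check_dtype(data, dtype, attr_name):
    # From the diff: handles dicts of per-type tensors (heterogeneous
    # outputs), lists (e.g. blocks, checked via 'idtype'), and single
    # objects.
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype

@parametrize_dtype
def test_check_dtype_helper(idtype):
    data = {'user': torch.zeros(3, dtype=idtype),
            'game': torch.zeros(2, dtype=idtype)}
    _check_dtype(data, idtype, 'dtype')
```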
