diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 4d5e8a51b4a80..77fb4c9e92959 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -173,6 +173,36 @@ def check_len(dl, expected):
         check_len(DataLoader(self.dataset, batch_size=2), 50)
         check_len(DataLoader(self.dataset, batch_size=3), 34)
 
+    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
+    def test_numpy_scalars(self):
+        import numpy as np
+
+        class ScalarDataset(torch.utils.data.Dataset):
+            def __init__(self, dtype):
+                self.dtype = dtype
+
+            def __getitem__(self, i):
+                return self.dtype()
+
+            def __len__(self):
+                return 4
+
+        dtypes = {
+            np.float64: torch.DoubleTensor,
+            np.float32: torch.FloatTensor,
+            np.float16: torch.HalfTensor,
+            np.int64: torch.LongTensor,
+            np.int32: torch.IntTensor,
+            np.int16: torch.ShortTensor,
+            np.int8: torch.CharTensor,
+            np.uint8: torch.ByteTensor,
+        }
+        for dt, tt in dtypes.items():
+            dset = ScalarDataset(dt)
+            loader = DataLoader(dset, batch_size=2)
+            batch = next(iter(loader))
+            self.assertIsInstance(batch, tt)
+
 
 class StringDataset(Dataset):
     def __init__(self):
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 797205565adda..69305585b9f17 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -60,12 +60,29 @@ def _pin_memory_loop(in_queue, out_queue, done_event):
         out_queue.put((idx, batch))
 
 
+numpy_type_map = {
+    'float64': torch.DoubleTensor,
+    'float32': torch.FloatTensor,
+    'float16': torch.HalfTensor,
+    'int64': torch.LongTensor,
+    'int32': torch.IntTensor,
+    'int16': torch.ShortTensor,
+    'int8': torch.CharTensor,
+    'uint8': torch.ByteTensor,
+}
+
+
 def default_collate(batch):
     "Puts each data field into a tensor with outer dimension batch size"
     if torch.is_tensor(batch[0]):
         return torch.stack(batch, 0)
-    elif type(batch[0]).__module__ == 'numpy' and type(batch[0]).__name__ == 'ndarray':
-        return torch.stack([torch.from_numpy(b) for b in batch], 0)
+    elif type(batch[0]).__module__ == 'numpy':
+        elem = batch[0]
+        if type(elem).__name__ == 'ndarray':
+            return torch.stack([torch.from_numpy(b) for b in batch], 0)
+        if elem.shape == ():  # scalars
+            py_type = float if elem.dtype.name.startswith('float') else int
+            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
     elif isinstance(batch[0], int):
         return torch.LongTensor(batch)
     elif isinstance(batch[0], float):
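
For context, here is a minimal sketch (not part of the patch) of what the new `elem.shape == ()` branch enables, assuming a build with the patched `default_collate` above: a dataset that yields 0-d numpy scalars now collates into the tensor type matched by `numpy_type_map`.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ScalarDataset(Dataset):
    """Yields 0-d numpy scalars, i.e. elements whose shape is ()."""
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return np.float32(i)

loader = DataLoader(ScalarDataset(), batch_size=2)
batch = next(iter(loader))
# type(np.float32(0)).__name__ is not 'ndarray', so the scalar branch
# runs: dtype name 'float32' is looked up in numpy_type_map and the
# batch comes back as a torch.FloatTensor.
print(type(batch))
```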