Retain the type of numpy scalars in collate_fn
apaszke authored and soumith committed Apr 11, 2017
1 parent 2087b11 commit 605b3c8
Showing 2 changed files with 49 additions and 2 deletions.
test/test_dataloader.py (30 additions, 0 deletions)
@@ -173,6 +173,36 @@ def check_len(dl, expected):
         check_len(DataLoader(self.dataset, batch_size=2), 50)
         check_len(DataLoader(self.dataset, batch_size=3), 34)
 
+    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
+    def test_numpy_scalars(self):
+        import numpy as np
+
+        class ScalarDataset(torch.utils.data.Dataset):
+            def __init__(self, dtype):
+                self.dtype = dtype
+
+            def __getitem__(self, i):
+                return self.dtype()
+
+            def __len__(self):
+                return 4
+
+        dtypes = {
+            np.float64: torch.DoubleTensor,
+            np.float32: torch.FloatTensor,
+            np.float16: torch.HalfTensor,
+            np.int64: torch.LongTensor,
+            np.int32: torch.IntTensor,
+            np.int16: torch.ShortTensor,
+            np.int8: torch.CharTensor,
+            np.uint8: torch.ByteTensor,
+        }
+        for dt, tt in dtypes.items():
+            dset = ScalarDataset(dt)
+            loader = DataLoader(dset, batch_size=2)
+            batch = next(iter(loader))
+            self.assertIsInstance(batch, tt)
+
 
 class StringDataset(Dataset):
     def __init__(self):
torch/utils/data/dataloader.py (19 additions, 2 deletions)
@@ -60,12 +60,29 @@ def _pin_memory_loop(in_queue, out_queue, done_event):
         out_queue.put((idx, batch))
 
 
+numpy_type_map = {
+    'float64': torch.DoubleTensor,
+    'float32': torch.FloatTensor,
+    'float16': torch.HalfTensor,
+    'int64': torch.LongTensor,
+    'int32': torch.IntTensor,
+    'int16': torch.ShortTensor,
+    'int8': torch.CharTensor,
+    'uint8': torch.ByteTensor,
+}
+
+
 def default_collate(batch):
     "Puts each data field into a tensor with outer dimension batch size"
     if torch.is_tensor(batch[0]):
         return torch.stack(batch, 0)
-    elif type(batch[0]).__module__ == 'numpy' and type(batch[0]).__name__ == 'ndarray':
-        return torch.stack([torch.from_numpy(b) for b in batch], 0)
+    elif type(batch[0]).__module__ == 'numpy':
+        elem = batch[0]
+        if type(elem).__name__ == 'ndarray':
+            return torch.stack([torch.from_numpy(b) for b in batch], 0)
+        if elem.shape == ():  # scalars
+            py_type = float if elem.dtype.name.startswith('float') else int
+            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
     elif isinstance(batch[0], int):
         return torch.LongTensor(batch)
     elif isinstance(batch[0], float):
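For illustration, here is a minimal sketch of the behavior this commit adds. The Float32Scalars dataset is a hypothetical helper mirroring the test above, and the example assumes the 2017-era typed-tensor API (torch.FloatTensor etc.) shown in the diff:

# Minimal sketch of the new collate behavior (assumes torch and numpy are
# installed, and the 2017-era API with typed tensor classes).
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class Float32Scalars(Dataset):
    """Hypothetical dataset that yields bare numpy float32 scalars."""

    def __getitem__(self, i):
        return np.float32(i)

    def __len__(self):
        return 4


loader = DataLoader(Float32Scalars(), batch_size=2)
batch = next(iter(loader))

# default_collate now routes zero-dimensional numpy values through
# numpy_type_map, so the numpy dtype is retained: float32 scalars collate
# into a torch.FloatTensor. Before this commit, non-ndarray numpy values
# skipped the numpy branch and fell through to the generic int/float checks
# (or went unhandled).
print(type(batch))  # expected: <class 'torch.FloatTensor'>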
