Commit 95305cd

basic dataloader for high GPU utilization
ljk628 committed Sep 12, 2018
1 parent e15ae71 commit 95305cd
Showing 1 changed file, dataloader.py, with 15 additions and 10 deletions.
@@ -50,18 +50,23 @@ def load_dataset(dataset='cifar10', datapath='cifar10/data', batch_size=128, \
         transforms.ToTensor(),
         normalize,
     ])
 
     trainset = torchvision.datasets.CIFAR10(root=data_folder, train=True,
                                             download=True, transform=transform)
-    indices = torch.tensor(np.arange(len(trainset)))
-    data_num = len(trainset)/data_split
-    ind_start = data_num*split_idx
-    ind_end = min(data_num*(split_idx + 1), len(trainset))
-    train_indices = indices[ind_start:ind_end]
-    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
-    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
-                                               sampler=train_sampler,
-                                               shuffle=False, num_workers=threads)
+    if data_split > 1:
+        indices = torch.tensor(np.arange(len(trainset)))
+        data_num = len(trainset)/data_split
+        ind_start = data_num*split_idx
+        ind_end = min(data_num*(split_idx + 1), len(trainset))
+        train_indices = indices[ind_start:ind_end]
+        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
+        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
+                                                   sampler=train_sampler,
+                                                   shuffle=False, num_workers=threads)
+    else:
+        kwargs = {'num_workers': 2, 'pin_memory': True}
+        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
+                                                   shuffle=False, **kwargs)
 
     testset = torchvision.datasets.CIFAR10(root=data_folder, train=False,
                                            download=False, transform=transform)
     test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
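For reference, below is a minimal, self-contained sketch of the split-or-full-dataset logic this commit adds, assuming Python 3 (the committed lines compute the split bounds with /, which yields integer indices under Python 2 but floats under Python 3, so // is used here). The helper name make_train_loader and the bare ToTensor transform are illustrative only and not part of dataloader.py.

import torch
import torchvision
import torchvision.transforms as transforms


def make_train_loader(trainset, batch_size=128, threads=2, data_split=1, split_idx=0):
    """Return a DataLoader over either the full training set or one split of it."""
    if data_split > 1:
        # Take the split_idx-th contiguous 1/data_split slice of the dataset and
        # sample it in random order with SubsetRandomSampler.
        data_num = len(trainset) // data_split          # '//' so indices stay ints
        ind_start = data_num * split_idx
        ind_end = min(data_num * (split_idx + 1), len(trainset))
        train_indices = list(range(ind_start, ind_end))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
        return torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           sampler=train_sampler, shuffle=False,
                                           num_workers=threads)
    # No split: pin_memory=True keeps batches in page-locked host memory so
    # host-to-GPU copies are faster, which helps keep GPU utilization high.
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                       shuffle=False, num_workers=2, pin_memory=True)


if __name__ == '__main__':
    # Illustrative usage with the same CIFAR-10 layout as the function's defaults.
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = torchvision.datasets.CIFAR10(root='cifar10/data', train=True,
                                            download=True, transform=transform)
    loader = make_train_loader(trainset, batch_size=128, data_split=4, split_idx=0)
    print(len(loader.sampler), 'images in this split,', len(loader), 'batches')

Note that shuffle must stay False when a sampler is supplied; the shuffling comes from the sampler itself.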
