-
Notifications
You must be signed in to change notification settings - Fork 55
/
data_loader.py
84 lines (65 loc) · 2.73 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf_8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from PIL import Image
def train_data_loader(data_path, img_size, use_augment=False):
    """Build the training dataset together with its transform pipeline.

    Despite the name, this returns a ``torchvision.datasets.ImageFolder``
    dataset rather than a ``DataLoader``; batching is left to the caller.

    Args:
        data_path: Root directory handed straight to ``ImageFolder``.
        img_size: Output size passed to ``RandomResizedCrop``.
        use_augment: When true, random photometric jitter (contrast,
            saturation, hue, brightness) and a random rotation of up to 15
            degrees are applied before the crop.

    Returns:
        The ``ImageFolder`` dataset using the composed transform.
    """
    augmentation_steps = []
    if use_augment:
        augmentation_steps = [
            # Contrast and the saturation/hue pair are applied in random order.
            transforms.RandomOrder([
                transforms.RandomApply([transforms.ColorJitter(contrast=0.5)], .5),
                transforms.Compose([
                    transforms.RandomApply([transforms.ColorJitter(saturation=0.5)], .5),
                    transforms.RandomApply([transforms.ColorJitter(hue=0.1)], .5),
                ]),
            ]),
            transforms.RandomApply([transforms.ColorJitter(brightness=0.125)], .5),
            transforms.RandomApply([transforms.RandomRotation(15)], .5),
        ]

    # Steps shared by both modes: random crop + flip, tensor conversion, then
    # per-channel normalisation with the widely used ImageNet mean/std values.
    shared_steps = [
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]

    pipeline = transforms.Compose(augmentation_steps + shared_steps)
    return datasets.ImageFolder(data_path, pipeline)
def test_data_loader(data_path):
# return full path
queries_path = [os.path.join(data_path, 'query', path) for path in os.listdir(os.path.join(data_path, 'query'))]
references_path = [os.path.join(data_path, 'reference', path) for path in
os.listdir(os.path.join(data_path, 'reference'))]
return queries_path, references_path
def test_data_generator(data_path, img_size):
    """Wrap a list of image paths in a ``TestDataset`` with eval-time transforms.

    Images are resized to a fixed shape (no random crop or flip), converted to
    tensors and normalised with the same mean/std values used for training.

    Args:
        data_path: Sequence of image file paths, such as one of the lists
            returned by ``test_data_loader``. Despite the name, this is not a
            directory: it is passed directly to ``TestDataset``.
        img_size: Target size. A single number ``n`` resizes every image to
            ``(n, n)``; a ``(height, width)`` tuple/list is used as-is.

    Returns:
        A ``TestDataset`` yielding ``(img_path, tensor)`` pairs.
    """
    if isinstance(img_size, (tuple, list)):
        # Already an explicit (height, width); wrapping it again would create
        # a nested tuple that Resize cannot use.
        img_size = tuple(img_size)
    else:
        # A bare number given to Resize would only scale the shorter edge and
        # keep the aspect ratio; expand it so every image gets the same shape.
        img_size = (img_size, img_size)
    data_transforms = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_image_dataset = TestDataset(data_path, data_transforms)
    return test_image_dataset
class TestDataset(Dataset):
    """Map-style dataset over an explicit list of image file paths.

    Each item is an ``(img_path, img)`` pair, so callers can map model outputs
    back to the file they came from.

    Args:
        img_path_list: Sequence of image file paths, indexed directly by
            ``__getitem__``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, img_path_list, transform=None):
        self.img_path_list = img_path_list
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return ``(img_path, img)``."""
        img_path = self.img_path_list[index]
        # Image.open is lazy and keeps the file handle open until the pixels
        # are read. Open the file ourselves inside ``with`` so the handle is
        # closed right away, and ``convert('RGB')`` forces a full load.
        # Converting to RGB also turns grayscale/RGBA/palette images into
        # 3-channel images. This matches the 3-value mean/std normalisation
        # used in this module and how torchvision's default ImageFolder
        # loader reads the training images.
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img_path, img

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_path_list)
if __name__ == '__main__':
    # Manual smoke check: list the query and reference paths found under the
    # current working directory, one list per line.
    query_paths, reference_paths = test_data_loader('./')
    for path_list in (query_paths, reference_paths):
        print(path_list)