forked from HobbitLong/CMC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
124 lines (94 loc) · 3.17 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import print_function
import numpy as np
from skimage import color
import torch
import torchvision.datasets as datasets
class ImageFolderInstance(datasets.ImageFolder):
    """ImageFolder dataset whose items also carry the sample's index.

    Each item is ``(img, target, index)``. With ``two_crop=True`` the
    transform is applied twice to the same loaded image and the two results
    are concatenated along dim 0 with ``torch.cat`` (assumes the transform
    returns tensors -- e.g. two augmented views stacked channel-wise).
    """

    def __init__(self, root, transform=None, target_transform=None, two_crop=False):
        # Keyword arguments keep this robust to changes in the positional
        # order of torchvision's ImageFolder signature.
        super(ImageFolderInstance, self).__init__(
            root, transform=transform, target_transform=target_transform)
        self.two_crop = two_crop

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, index) where target is class_index of the target class.
                If no transform is set, image is the object returned by ``self.loader``.
        Raises:
            TypeError: if ``two_crop`` is enabled but no transform is set (two
                views cannot be produced or concatenated without one).
        """
        path, target = self.imgs[index]
        image = self.loader(path)

        if self.two_crop and self.transform is None:
            raise TypeError("two_crop=True requires a transform that returns tensors")

        # Fall back to the raw loaded image so ``img`` is always bound; the
        # previous code raised UnboundLocalError when transform was None.
        img = image
        if self.transform is not None:
            img = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.two_crop:
            # Second, independent pass over the same image: with a random
            # transform this yields a different view of the sample.
            img2 = self.transform(image)
            img = torch.cat([img, img2], dim=0)
        return img, target, index
class RGB2Lab(object):
    """Transform an RGB PIL image into a Lab ``numpy.ndarray``.

    The input is coerced to a ``uint8`` array and handed to
    ``skimage.color.rgb2lab``.
    """

    def __call__(self, img):
        rgb = np.asarray(img, dtype=np.uint8)
        return color.rgb2lab(rgb)
class RGB2HSV(object):
    """Transform an RGB PIL image into an HSV ``numpy.ndarray``.

    The input is coerced to ``uint8`` and passed to ``skimage.color.rgb2hsv``.
    """

    def __call__(self, img):
        return color.rgb2hsv(np.asarray(img, dtype=np.uint8))
class RGB2HED(object):
    """Transform an RGB PIL image into an HED ``numpy.ndarray``.

    Delegates to ``skimage.color.rgb2hed`` after a ``uint8`` coercion.
    """

    def __call__(self, img):
        as_uint8 = np.asarray(img, dtype=np.uint8)
        hed = color.rgb2hed(as_uint8)
        return hed
class RGB2LUV(object):
    """Transform an RGB PIL image into an LUV ``numpy.ndarray``.

    The input is coerced to ``uint8`` and passed to ``skimage.color.rgb2luv``.
    """

    def __call__(self, img):
        return color.rgb2luv(np.asarray(img, dtype=np.uint8))
class RGB2YUV(object):
    """Transform an RGB PIL image into a YUV ``numpy.ndarray``.

    Delegates to ``skimage.color.rgb2yuv`` after a ``uint8`` coercion.
    """

    def __call__(self, img):
        pixels = np.asarray(img, dtype=np.uint8)
        return color.rgb2yuv(pixels)
class RGB2XYZ(object):
    """Transform an RGB PIL image into an XYZ ``numpy.ndarray``.

    The input is coerced to ``uint8`` and passed to ``skimage.color.rgb2xyz``.
    """

    def __call__(self, img):
        return color.rgb2xyz(np.asarray(img, dtype=np.uint8))
class RGB2YCbCr(object):
    """Transform an RGB PIL image into a YCbCr ``numpy.ndarray``.

    Delegates to ``skimage.color.rgb2ycbcr`` after a ``uint8`` coercion.
    """

    def __call__(self, img):
        rgb = np.asarray(img, dtype=np.uint8)
        ycbcr = color.rgb2ycbcr(rgb)
        return ycbcr
class RGB2YDbDr(object):
    """Transform an RGB PIL image into a YDbDr ``numpy.ndarray``.

    The input is coerced to ``uint8`` and passed to ``skimage.color.rgb2ydbdr``.
    """

    def __call__(self, img):
        return color.rgb2ydbdr(np.asarray(img, dtype=np.uint8))
class RGB2YPbPr(object):
    """Transform an RGB PIL image into a YPbPr ``numpy.ndarray``.

    Delegates to ``skimage.color.rgb2ypbpr`` after a ``uint8`` coercion.
    """

    def __call__(self, img):
        pixels = np.asarray(img, dtype=np.uint8)
        return color.rgb2ypbpr(pixels)
class RGB2YIQ(object):
    """Transform an RGB PIL image into a YIQ ``numpy.ndarray``.

    The input is coerced to ``uint8`` and passed to ``skimage.color.rgb2yiq``.
    """

    def __call__(self, img):
        return color.rgb2yiq(np.asarray(img, dtype=np.uint8))
class RGB2CIERGB(object):
    """Transform an RGB PIL image into a CIE-RGB ``numpy.ndarray``.

    Delegates to ``skimage.color.rgb2rgbcie`` after a ``uint8`` coercion.
    """

    def __call__(self, img):
        as_uint8 = np.asarray(img, dtype=np.uint8)
        cie_rgb = color.rgb2rgbcie(as_uint8)
        return cie_rgb