-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
101 lines (75 loc) · 2.92 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- encoding: utf-8 -*-
# Author : Haitong
import pandas as pd
import os
import torch as t
import numpy as np
import torchvision.transforms.functional as ff
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import cfg
class LabelProcessor:
    """Translate RGB label colours into integer class indices.

    The colour table comes from a CSV file with ``r``, ``g`` and ``b``
    columns; the row position of a colour is its class index.
    """

    def __init__(self, file_path):
        """Read the colour table and build the colour -> class lookup table.

        Args:
            file_path: path (or any source ``pandas.read_csv`` accepts) of the
                class-colour CSV.
        """
        self.colormap = self.read_color_map(file_path)
        self.cm2lbl = self.encode_label_pix(self.colormap)

    @staticmethod
    def read_color_map(file_path):
        """Return the ``[r, g, b]`` triple of every CSV row, in file order.

        Raises:
            KeyError: if the CSV lacks an ``r``, ``g`` or ``b`` column.
        """
        pd_label_color = pd.read_csv(file_path, sep=',')
        return pd_label_color[['r', 'g', 'b']].to_numpy().tolist()

    @staticmethod
    def encode_label_pix(colormap):
        """Build a dense table mapping a packed RGB value to its class index.

        A colour is packed as ``(r * 256 + g) * 256 + b``. Colours absent from
        *colormap* map to class 0. If a colour is listed twice, the later row
        wins.

        Args:
            colormap: sequence of ``[r, g, b]`` triples, components in 0..255.

        Returns:
            1-D integer array of length ``256 ** 3``. It uses the smallest
            unsigned dtype that can hold the largest class index (``uint8``
            for up to 256 classes: 16 MiB instead of 128 MiB for float64).

        Raises:
            ValueError: if a colour component lies outside 0..255.
        """
        dtype = np.min_scalar_type(max(len(colormap) - 1, 0))
        cm2lbl = np.zeros(256 ** 3, dtype=dtype)
        for i, cm in enumerate(colormap):
            r, g, b = (int(c) for c in cm[:3])
            # Out-of-range components would index past the table or,
            # if negative, silently wrap around to another colour's slot.
            if not all(0 <= c <= 255 for c in (r, g, b)):
                raise ValueError(
                    f'colour {[r, g, b]} of class {i} is outside 0..255'
                )
            cm2lbl[(r * 256 + g) * 256 + b] = i
        return cm2lbl

    def encode_label_img(self, img):
        """Convert an RGB label image into a 2-D array of class indices.

        Args:
            img: PIL image or array-like indexed as ``(H, W, channel)``; only
                channels 0-2 (R, G, B) are read.

        Returns:
            ``int64`` array of shape ``(H, W)``.
        """
        # int32 is wide enough: the largest packed key, 256**3 - 1, is < 2**31.
        data = np.asarray(img, dtype=np.int32)
        idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
        return self.cm2lbl[idx].astype(np.int64)
class LoadDataset(Dataset):
    """Segmentation dataset pairing image files with RGB label-mask files.

    Images and labels are matched by position after sorting the full paths
    of each directory. Every directory entry is used, so both directories
    must contain corresponding files whose sorted orders line up.
    """

    def __init__(self, file_path=None, crop_size=None):
        """Index the image and label directories.

        Args:
            file_path: sequence whose first two items are the image directory
                and the label directory.
            crop_size: size passed to torchvision's ``center_crop``; ``None``
                disables cropping.

        Raises:
            ValueError: if *file_path* is missing or has fewer than two items.
        """
        if file_path is None or len(file_path) < 2:
            raise ValueError('file_path must be [image_dir, label_dir]')
        self.img_path = file_path[0]
        self.label_path = file_path[1]
        self.imgs = self.read_file(self.img_path)
        self.labels = self.read_file(self.label_path)
        self.crop_size = crop_size
        # Built once instead of on every sample. The mean/std are the usual
        # ImageNet statistics, presumably to match a pretrained backbone.
        self.transform_img = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def __getitem__(self, index):
        """Load, crop and transform the *index*-th image/label pair.

        Returns:
            ``{'img': image_tensor, 'label': class_index_tensor}``.
        """
        # ``with`` closes the underlying files; convert() returns an
        # independent, fully loaded copy.
        with Image.open(self.imgs[index]) as img_file:
            # RGB so grayscale/RGBA inputs fit the 3-channel Normalize.
            img = img_file.convert('RGB')
        with Image.open(self.labels[index]) as label_file:
            label = label_file.convert('RGB')
        img, label = self.center_crop(img, label, self.crop_size)
        img, label = self.img_transform(img, label)
        return {'img': img, 'label': label}

    def __len__(self):
        """Number of samples, taken from the image directory listing."""
        return len(self.imgs)

    def read_file(self, path):
        """Return the full paths of all entries in *path*, sorted."""
        return sorted(os.path.join(path, name) for name in os.listdir(path))

    def center_crop(self, data, label, crop_size):
        """Centre-crop *data* and *label* to *crop_size*.

        When *crop_size* is ``None``, both are returned unchanged.
        """
        if crop_size is None:
            return data, label
        data = ff.center_crop(data, crop_size)
        label = ff.center_crop(label, crop_size)
        return data, label

    def img_transform(self, img, label):
        """Turn the image into a normalised tensor and the label into class ids.

        Args:
            img: PIL image to convert with ``ToTensor`` + ``Normalize``.
            label: RGB label image or array, indexed as ``(H, W, channel)``.

        Returns:
            ``(img_tensor, label_tensor)``. ``label_tensor`` is the ``(H, W)``
            int64 class-index map produced by the module-level
            ``label_processor``.
        """
        # The uint8 cast is kept; the old numpy -> PIL -> numpy round-trip
        # is dropped because encode_label_img accepts arrays directly.
        label = np.asarray(label).astype('uint8')
        img = self.transform_img(img)
        label = label_processor.encode_label_img(label)  # (H, W, 3) -> (H, W)
        label = t.from_numpy(label)
        return img, label
# Shared colour -> class-index encoder used by LoadDataset.img_transform.
# Created once at import time, so importing this module reads the CSV at
# ``cfg.class_dict_path`` and allocates the lookup table.
label_processor = LabelProcessor(file_path=cfg.class_dict_path)