# dataset_builder.py

import os
import pickle
import sys

from torch.utils import data

from dataset.cityscapes import CityscapesDataSet, CityscapesTrainInform, CityscapesValDataSet, CityscapesTestDataSet
from dataset.camvid import CamVidDataSet, CamVidValDataSet, CamVidTrainInform, CamVidTestDataSet


def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers):
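    """Build the training and validation DataLoaders for `dataset`.

    Args:
        dataset: dataset name, either "cityscapes" or "camvid".
        input_size: crop size passed to the training dataset.
        batch_size: training batch size (validation always uses batch size 1).
        train_type: which split list to train on, e.g. "train" or "trainval".
        random_scale: enable random scaling augmentation.
        random_mirror: enable random horizontal flipping.
        num_workers: number of DataLoader worker processes.

    Returns:
        (datas, trainLoader, valLoader), where `datas` holds the dataset
        statistics (at least the per-channel 'mean') read from the inform file.
    """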
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
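    # Sketch of the pickled dict (only 'mean' is read in this file; the other
    # key names are assumptions based on the comment above):
    #   {'mean': ..., 'std': ..., 'classWeights': ...}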
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % inform_data_file)
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == "camvid":
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository currently supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            sys.exit(-1)
    else:
        print("found file: ", inform_data_file)
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    if dataset == "cityscapes":
        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True,
            drop_last=True)

        return datas, trainLoader, valLoader
    elif dataset == "camvid":
        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

        return datas, trainLoader, valLoader
    else:
        # guard against an unrecognized dataset name when the inform file already exists
        raise NotImplementedError("dataset %s is not supported" % dataset)
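
# A minimal usage sketch for build_dataset_train (the argument values below are
# assumptions for illustration, not necessarily the ones this repository's
# training script passes):
#
#   datas, trainLoader, valLoader = build_dataset_train(
#       'camvid', input_size=(360, 480), batch_size=8, train_type='trainval',
#       random_scale=True, random_mirror=True, num_workers=4)
#   for batch in trainLoader:
#       ...  # training step
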
def build_dataset_test(dataset, num_workers, none_gt=False):
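    """Build the test (or validation) DataLoader for `dataset`.

    For cityscapes, `none_gt=True` selects the ground-truth-free test split,
    while `none_gt=False` falls back to the validation split; camvid always
    evaluates on its test list. Returns (datas, testLoader).
    """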
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % inform_data_file)
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == "camvid":
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository currently supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            sys.exit(-1)
    else:
        print("found file: ", inform_data_file)
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    if dataset == "cityscapes":
        # for cityscapes: to test on the validation set, set none_gt to False;
        # to test on the test set (no ground truth), set none_gt to True
        if none_gt:
            testLoader = data.DataLoader(
                CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader
    elif dataset == "camvid":
        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader
    else:
        # guard against an unrecognized dataset name when the inform file already exists
        raise NotImplementedError("dataset %s is not supported" % dataset)
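
# A minimal usage sketch for build_dataset_test (argument values are assumed
# for illustration):
#
#   datas, testLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=False)
#   for batch in testLoader:
#       ...  # inference / evaluation
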
if __name__ == "__main__":
    # quick sanity check: compute and pickle the CamVid statistics
    dataCollection = CamVidTrainInform("/home/mohamed/RINet/dataset/camvid", classes=11,
                                       train_set_file="camvid_trainval_list.txt",
                                       inform_data_file="inform/camvid_inform.pkl")
    datas = dataCollection.collectDataAndSave()