data_process.py
'''
Parallel data loading functions
'''
import sys
import time
import theano
import numpy as np
import traceback
from PIL import Image
from six.moves import queue
from multiprocessing import Process, Event
from lib.config import cfg
from lib.data_augmentation import preprocess_img
from lib.data_io import get_voxel_file, get_rendering_file
from lib.binvox_rw import read_as_3d_array

def print_error(func):
    '''Flush out error messages. Mainly used for debugging separate processes.'''

    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            traceback.print_exception(*sys.exc_info())
            sys.stdout.flush()

    return func_wrapper

class DataProcess(Process):

    def __init__(self, data_queue, data_paths, repeat=True):
        '''
        data_queue : multiprocessing queue used to transfer loaded minibatches
        data_paths : list of (data, label) pairs used to load data
        repeat     : if True, keep returning data until exit is set
        '''
        super(DataProcess, self).__init__()
        # Queue to transfer the loaded minibatches
        self.data_queue = data_queue
        self.data_paths = data_paths
        self.num_data = len(data_paths)
        self.repeat = repeat

        # Size of each minibatch
        self.batch_size = cfg.CONST.BATCH_SIZE
        self.exit = Event()
        self.shuffle_db_inds()

    def shuffle_db_inds(self):
        # Randomly permute the data indices for training
        if self.repeat:
            self.perm = np.random.permutation(np.arange(self.num_data))
        else:
            self.perm = np.arange(self.num_data)
        self.cur = 0

    def get_next_minibatch(self):
        if (self.cur + self.batch_size) >= self.num_data and self.repeat:
            self.shuffle_db_inds()
        db_inds = self.perm[self.cur:min(self.cur + self.batch_size, self.num_data)]
        self.cur += self.batch_size
        return db_inds
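
    # Worked example of the index window above: with num_data = 10,
    # batch_size = 4 and repeat = False, successive calls return
    # perm[0:4], perm[4:8], perm[8:10], after which run() stops because
    # self.cur has passed num_data. With repeat = True, the permutation is
    # reshuffled as soon as batch_size or fewer unseen indices remain, so
    # every window is full but the tail of each permutation is skipped
    # (hence "(almost) all data per epoch").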
    def shutdown(self):
        self.exit.set()

    @print_error
    def run(self):
        iteration = 0
        # Run the loop until the exit flag is set
        while not self.exit.is_set() and self.cur <= self.num_data:
            # Ensure that the network sees (almost) all data per epoch
            db_inds = self.get_next_minibatch()

            data_list = []
            label_list = []
            for batch_id, db_ind in enumerate(db_inds):
                datum = self.load_datum(self.data_paths[db_ind])
                label = self.load_label(self.data_paths[db_ind])

                data_list.append(datum)
                label_list.append(label)

            batch_data = np.array(data_list).astype(np.float32)
            batch_label = np.array(label_list).astype(np.float32)

            # The following will block until the queue has a free slot
            self.data_queue.put((batch_data, batch_label), block=True)
            iteration += 1

    def load_datum(self, path):
        raise NotImplementedError('subclasses must implement load_datum')

    def load_label(self, path):
        raise NotImplementedError('subclasses must implement load_label')

class ReconstructionDataProcess(DataProcess):

    def __init__(self, data_queue, category_model_pair, background_imgs=[], repeat=True,
                 train=True):
        self.repeat = repeat
        self.train = train
        self.background_imgs = background_imgs
        super(ReconstructionDataProcess, self).__init__(
            data_queue, category_model_pair, repeat=repeat)

    @print_error
    def run(self):
        # Set up constants
        img_h = cfg.CONST.IMG_H
        img_w = cfg.CONST.IMG_W
        n_vox = cfg.CONST.N_VOX

        # This is the maximum number of views
        n_views = cfg.CONST.N_VIEWS

        while not self.exit.is_set() and self.cur <= self.num_data:
            # Ensure that the network sees (almost) all images per epoch
            db_inds = self.get_next_minibatch()

            # Sample the number of views for this minibatch
            if cfg.TRAIN.RANDOM_NUM_VIEWS:
                curr_n_views = np.random.randint(n_views) + 1
            else:
                curr_n_views = n_views

            # These will be fed into the queue; create a new batch every time
            batch_img = np.zeros(
                (curr_n_views, self.batch_size, 3, img_h, img_w), dtype=theano.config.floatX)
            batch_voxel = np.zeros(
                (self.batch_size, n_vox, 2, n_vox, n_vox), dtype=theano.config.floatX)

            # Load each data instance
            for batch_id, db_ind in enumerate(db_inds):
                category, model_id = self.data_paths[db_ind]
                image_ids = np.random.choice(cfg.TRAIN.NUM_RENDERING, curr_n_views)

                # Load multi-view images
                for view_id, image_id in enumerate(image_ids):
                    im = self.load_img(category, model_id, image_id)
                    # (channel, height, width)
                    batch_img[view_id, batch_id, :, :, :] = \
                        im.transpose((2, 0, 1)).astype(theano.config.floatX)

                voxel = self.load_label(category, model_id)
                voxel_data = voxel.data

                # Two-channel voxel target: channel 0 marks empty cells,
                # channel 1 marks occupied cells
                batch_voxel[batch_id, :, 0, :, :] = voxel_data < 1
                batch_voxel[batch_id, :, 1, :, :] = voxel_data

            # The following will block until the queue has a free slot
            self.data_queue.put((batch_img, batch_voxel), block=True)

        print('Exiting')

    def load_img(self, category, model_id, image_id):
        image_fn = get_rendering_file(category, model_id, image_id)
        im = Image.open(image_fn)

        t_im = preprocess_img(im, self.train)
        return t_im

    def load_label(self, category, model_id):
        voxel_fn = get_voxel_file(category, model_id)
        with open(voxel_fn, 'rb') as f:
            voxel = read_as_3d_array(f)
        return voxel
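
# A minimal sketch of the two-channel voxel encoding built in
# ReconstructionDataProcess.run() above, using a toy binary occupancy grid
# (the array values are illustrative, not taken from the dataset):
#
#   occ = np.array([[[0, 1], [1, 0]], [[1, 1], [0, 0]]], dtype=np.float32)
#   target = np.zeros(occ.shape[:1] + (2,) + occ.shape[1:], dtype=np.float32)
#   target[:, 0, :, :] = occ < 1   # channel 0: empty cells
#   target[:, 1, :, :] = occ       # channel 1: occupied cells
#   assert np.all(target.sum(axis=1) == 1)  # channels are complementary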

def kill_processes(queue, processes):
    print('Signal processes')
    for p in processes:
        p.shutdown()

    # Drain the queue so that workers blocked on put() can observe the
    # exit flag and stop cleanly
    print('Empty queue')
    while not queue.empty():
        time.sleep(0.5)
        queue.get(False)

    print('Kill processes')
    for p in processes:
        p.terminate()

def make_data_processes(queue, data_paths, num_workers, repeat=True, train=True):
    '''
    Make a set of data processes for parallel data loading.
    '''
    processes = []
    for i in range(num_workers):
        process = ReconstructionDataProcess(queue, data_paths, repeat=repeat, train=train)
        process.start()
        processes.append(process)
    return processes

def get_while_running(data_process, data_queue, sleep_time=0):
    '''Generator that yields batches from data_queue until data_process dies.'''
    while True:
        time.sleep(sleep_time)
        try:
            batch_data, batch_label = data_queue.get_nowait()
        except queue.Empty:
            if not data_process.is_alive():
                break
            else:
                continue
        yield batch_data, batch_label

def test_process():
    from multiprocessing import Queue
    from lib.data_io import category_model_id_pair

    cfg.TRAIN.PAD_X = 10
    cfg.TRAIN.PAD_Y = 10

    data_queue = Queue(2)
    category_model_pair = category_model_id_pair(dataset_portion=[0, 0.1])

    data_process = ReconstructionDataProcess(data_queue, category_model_pair)
    data_process.start()
    batch_img, batch_voxel = data_queue.get()

    kill_processes(data_queue, [data_process])


if __name__ == '__main__':
    test_process()
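
# A hedged usage sketch of the full pipeline in a training script. The queue
# size, dataset portion, worker count, and `train_step`/`should_stop` below
# are assumptions for illustration, not part of this module:
#
#   from multiprocessing import Queue
#   from lib.data_io import category_model_id_pair
#
#   data_queue = Queue(maxsize=4)
#   data_pairs = category_model_id_pair(dataset_portion=[0, 0.8])
#   processes = make_data_processes(data_queue, data_pairs, num_workers=2)
#   for batch_img, batch_voxel in get_while_running(processes[0], data_queue):
#       train_step(batch_img, batch_voxel)  # hypothetical training update
#       if should_stop():                   # workers repeat forever, so the
#           break                           # consumer decides when to stop
#   kill_processes(data_queue, processes)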