forked from microsoft/CameraTraps
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathct_utils.py
327 lines (238 loc) · 9.48 KB
/
ct_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""
ct_utils.py
Utility functions that don't depend on other things in this repo. Also see
cct_json_utils.
"""
import subprocess
import argparse
import inspect
import json
import math
import os
import jsonpickle
import numpy as np
def truncate_float_array(xs, precision=3):
    """
    Apply truncate_float(...) to every element of a list.

    Args:
        xs (list of float) Values to truncate
        precision (int) The number of significant digits to preserve, should be
                        greater or equal 1

    Returns:
        A new list holding one truncated value per element of xs, in the
        same order
    """
    truncated = []
    for value in xs:
        truncated.append(truncate_float(value, precision=precision))
    return truncated
def truncate_float(x, precision=3):
    """
    Truncate a float scalar to a given number of significant digits.

    For example: truncate_float(0.0003214884) --> 0.000321

    Digits beyond the requested precision are dropped (not rounded), and the
    truncation is toward zero for both positive and negative values, e.g.
    truncate_float(-0.0003214884) --> -0.000321.

    This function is primarily used to achieve a certain float representation
    before exporting to JSON.

    Args:
        x (float) Scalar to truncate
        precision (int) The number of significant digits to preserve, should be
                        greater or equal 1

    Returns:
        The truncated value, or 0 if x is within np.isclose's default
        tolerance of zero.
    """
    assert precision > 0

    # log10 is undefined at zero, so values that are (almost) zero are
    # short-circuited here
    if np.isclose(x, 0):
        return 0

    # Determine the factor which shifts the decimal point of x
    # just behind the last significant digit
    factor = math.pow(10, precision - 1 - math.floor(math.log10(abs(x))))

    # Shift the decimal point by multiplication with factor, drop the
    # fractional part, and shift back by dividing by factor.  math.trunc
    # (rather than math.floor) is used so that negative values are truncated
    # toward zero instead of being pushed away from it.
    return math.trunc(x * factor) / factor
def args_to_object(args: argparse.Namespace, obj: object) -> None:
    """
    Transfer every public field of a Namespace (i.e., the output from
    parse_args) onto another object.

    Members whose names start with an underscore are skipped. No check is made
    that the target already has an attribute of the same name; attributes are
    simply set.

    Args:
        args: argparse.Namespace to read fields from
        obj: class or object whose attributes will be updated in place
    """
    public_members = (
        (name, value)
        for name, value in inspect.getmembers(args)
        if not name.startswith('_')
    )
    for name, value in public_members:
        setattr(obj, name, value)
def pretty_print_object(obj, b_print=True):
    """
    Render an arbitrary object as a .json string (sorted keys, 4-space indent)
    using jsonpickle, optionally printing it.

    To get the string without caring about the return value in an interactive
    session, use:  _ = pretty_print_object(obj)

    Args:
        obj: the object to encode
        b_print: if True, the encoded string is also printed

    Returns:
        The encoded string
    """
    # Note that this changes jsonpickle's 'json' encoder options for the whole
    # process, not just for this call.
    jsonpickle.set_encoder_options('json', sort_keys=True, indent=4)
    s = '{}'.format(jsonpickle.encode(obj))
    if b_print:
        print(s)
    return s
def is_list_sorted(L, reverse=False):
    """
    Check whether an indexable sequence is sorted.

    Args:
        L: sequence supporting len() and integer indexing
        reverse: if False, test for non-decreasing order; if True, test for
                 non-increasing order

    Returns:
        True if every adjacent pair is in the requested order (trivially True
        for sequences with fewer than two elements), otherwise False
    """
    for idx in range(len(L) - 1):
        first, second = L[idx], L[idx + 1]
        in_order = (first >= second) if reverse else (first <= second)
        if not in_order:
            return False
    return True
def write_json(path, content, indent=1):
    """
    Serialize [content] as JSON and write it to the file at [path],
    overwriting any existing file.

    Args:
        path: destination filename
        content: JSON-serializable object
        indent: indentation passed through to json.dump
    """
    with open(path, mode='w') as json_file:
        json.dump(content, json_file, indent=indent)
# Lower-case file extensions (including the leading dot) treated as images
image_extensions = [
    '.jpg',
    '.jpeg',
    '.gif',
    '.png',
]


def is_image_file(s):
    """
    Decide whether a filename looks like an image, based only on whether its
    extension (compared case-insensitively) appears in image_extensions.

    Args:
        s: filename or path

    Returns:
        True if the extension is a known image extension, otherwise False
    """
    _, extension = os.path.splitext(s)
    return extension.lower() in image_extensions
def convert_yolo_to_xywh(yolo_box):
    """
    Convert a center-based YOLO box into a corner-based xywh box.

    Args:
        yolo_box: bounding box of format [x_center, y_center, width_of_box, height_of_box].

    Returns:
        bbox with coordinates represented as [x_min, y_min, width_of_box, height_of_box].
    """
    x_center, y_center, width_of_box, height_of_box = yolo_box

    # The min corner sits half a box-size before the center along each axis
    half_width = width_of_box / 2.0
    half_height = height_of_box / 2.0

    return [x_center - half_width, y_center - half_height, width_of_box, height_of_box]
def convert_xywh_to_tf(api_box):
    """
    Convert an xywh bounding box into the [y_min, x_min, y_max, x_max] layout
    used by the TensorFlow Object Detection API.

    Args:
        api_box: bbox output by the batch processing API [x_min, y_min, width_of_box, height_of_box]

    Returns:
        bbox with coordinates represented as [y_min, x_min, y_max, x_max]
    """
    left, top, box_width, box_height = api_box
    right = left + box_width
    bottom = top + box_height

    # Note the y-before-x ordering
    return [top, left, bottom, right]
def convert_xywh_to_xyxy(api_bbox):
    """
    Convert an xywh bounding box into a two-corner (xyxy) bounding box.

    Note that this is also different from the TensorFlow Object Detection API
    coords format.

    Args:
        api_bbox: bbox output by the batch processing API [x_min, y_min, width_of_box, height_of_box]

    Returns:
        bbox with coordinates represented as [x_min, y_min, x_max, y_max]
    """
    left, top, box_width, box_height = api_bbox
    right = left + box_width
    bottom = top + box_height
    return [left, top, right, bottom]
def get_iou(bb1, bb2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.

    Adapted from: https://stackoverflow.com/questions/25349178/calculating-percentage-of-bounding-box-overlap-for-image-detector-evaluation

    Args:
        bb1: [x_min, y_min, width_of_box, height_of_box]
        bb2: [x_min, y_min, width_of_box, height_of_box]

    Both boxes are converted internally to [x1, y1, x2, y2], where (x1, y1)
    is the min corner and (x2, y2) is the max corner.  Each box must have
    strictly positive width and height.

    Returns:
        intersection_over_union, a float in [0, 1]

    Raises:
        AssertionError: if either box has non-positive width or height, or if
        the computed IoU falls outside [0, 1]
    """
    bb1 = convert_xywh_to_xyxy(bb1)
    bb2 = convert_xywh_to_xyxy(bb2)

    # Each message names the malformed condition, i.e. the max coordinate not
    # being strictly greater than the min coordinate
    assert bb1[0] < bb1[2], 'Malformed bounding box (x2 <= x1)'
    assert bb1[1] < bb1[3], 'Malformed bounding box (y2 <= y1)'

    assert bb2[0] < bb2[2], 'Malformed bounding box (x2 <= x1)'
    assert bb2[1] < bb2[3], 'Malformed bounding box (y2 <= y1)'

    # Determine the coordinates of the intersection rectangle
    x_left = max(bb1[0], bb2[0])
    y_top = max(bb1[1], bb2[1])
    x_right = min(bb1[2], bb2[2])
    y_bottom = min(bb1[3], bb2[3])

    # No overlap along at least one axis
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The intersection of two axis-aligned bounding boxes is always an
    # axis-aligned bounding box
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # Compute the area of both AABBs
    bb1_area = (bb1[2] - bb1[0]) * (bb1[3] - bb1[1])
    bb2_area = (bb2[2] - bb2[0]) * (bb2[3] - bb2[1])

    # Compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the intersection area.
    iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
    assert iou >= 0.0, 'Illegal IOU < 0'
    assert iou <= 1.0, 'Illegal IOU > 1'
    return iou
def _get_max_conf_from_detections(detections):
max_conf = 0.0
if detections is not None and len(detections) > 0:
confidences = [det['conf'] for det in detections]
max_conf = max(confidences)
return max_conf
def get_max_conf(im):
    """
    Compute the maximum detection confidence (over all classes) for an image
    dict in the format used by the batch API.

    Args:
        im: image dict, expected to carry a 'detections' list

    Returns:
        The maximum confidence, or 0.0 (not None) when 'detections' is
        missing, None, or empty (e.g. because processing failed for this image)
    """
    has_detections = (
        'detections' in im
        and im['detections'] is not None
        and len(im['detections']) > 0
    )
    if not has_detections:
        return 0.0
    return _get_max_conf_from_detections(im['detections'])
#%% Functions for running commands as subprocesses
def execute_command(cmd):
    """
    Run [cmd] (a single string) in a shell, yielding each line of output to the caller.

    stderr is merged into stdout, and output is delivered as text lines
    (including their trailing newline) as they become available.

    Based on:
    stackoverflow/questions/4417546/constantly-print-subprocess-output-while-process-is-running

    Args:
        cmd: command line to run; it is passed to the shell as-is, so it must
             not be built from untrusted input

    Yields:
        Each line of combined stdout/stderr output

    Raises:
        subprocess.CalledProcessError: after all output has been consumed, if
        the command exited with a non-zero return code
    """
    popen = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT, shell=True, universal_newlines=True)
    try:
        # readline() returns "" only at end-of-file, which ends the iteration
        for stdout_line in iter(popen.stdout.readline, ""):
            yield stdout_line
    finally:
        # Close the pipe even if the caller abandons the generator early or
        # an exception interrupts reading, so the file handle is not leaked
        popen.stdout.close()
    return_code = popen.wait()
    if return_code:
        raise subprocess.CalledProcessError(return_code, cmd)
def execute_command_and_print(cmd,print_output=True):
    """
    Run [cmd] (a single string) in a shell, capturing and (optionally) printing
    output as it arrives.

    Args:
        cmd: command line to run in a shell
        print_output: if True, echo each output line as it is received

    Returns:
        A dictionary with fields "status" (0 on success, otherwise the
        command's non-zero return code) and "output" (list of output lines,
        each including its trailing newline).
    """
    to_return = {'status':'unknown','output':''}
    output = []
    try:
        for s in execute_command(cmd):
            output.append(s)
            if print_output:
                print(s,end='',flush=True)
        to_return['status'] = 0
    except subprocess.CalledProcessError as cpe:
        # The exception raised by execute_command carries no captured output
        # (cpe.output is None), so report the exception itself, which names
        # the command and its exit status.
        print('Caught error: {}'.format(cpe))
        to_return['status'] = cpe.returncode
    to_return['output'] = output
    return to_return
#%%
# Interactive scratch cells; deliberately never executed on import
if False:

    #%% Test driver for execute_and_print

    execute_command_and_print('echo hello && sleep 1 && echo goodbye')


    #%% Parallel test driver for execute_command_and_print

    from functools import partial
    from multiprocessing.pool import ThreadPool as ThreadPool
    from multiprocessing.pool import Pool as Pool

    n_workers = 8

    # Should we use threads (vs. processes) for parallelization?
    use_threads = True

    # Only relevant if n_workers == 1, i.e. if we're not parallelizing
    quit_on_error = True

    test_data = ['a','b','c','d']

    def process_sample(s):
        # Return the status/output dict so that [results] below holds the
        # per-sample results rather than a list of None
        return execute_command_and_print('echo ' + s,True)

    if n_workers == 1:

        results = []
        for i_sample,sample in enumerate(test_data):
            results.append(process_sample(sample))

    else:

        # No point in starting more workers than there are samples
        n_threads = min(n_workers,len(test_data))

        if use_threads:
            print('Starting parallel thread pool with {} workers'.format(n_threads))
            pool = ThreadPool(n_threads)
        else:
            print('Starting parallel process pool with {} workers'.format(n_threads))
            pool = Pool(n_threads)

        results = list(pool.map(partial(process_sample),test_data))