-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
558 lines (470 loc) · 25.4 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
# coding=utf-8
# for better understanding about yolov3 architecture, refer to this website (in Chinese):
# https://blog.csdn.net/leviopku/article/details/82660381
from __future__ import division, print_function
import tensorflow as tf
slim = tf.contrib.slim
from utils.layer_utils import conv2d, darknet53_body, yolo_block, upsample_layer
class yolov3(object):
    '''
    YOLOv3 graph builder for TensorFlow 1.x (tf.contrib.slim).

    Builds the Darknet-53 backbone and the three-scale YOLOv3 head (`forward`),
    decodes head outputs into boxes (`reorg_layer`, `predict`), and computes the
    training losses (`loss_layer*`, `compute_loss*`).

    NOTE: `forward` must be called before any other method that decodes or
    scores feature maps, because it stores `self.img_size`, which
    `reorg_layer` and the loss layers read.
    '''

    def __init__(self, class_num, anchors, use_label_smooth=False, use_focal_loss=False, batch_norm_decay=0.999,
                 weight_decay=5e-4, use_static_shape=True):
        '''
        param:
            class_num: number of object classes.
            anchors: the 9 anchors in [w, h] format on the input-image scale, e.g.
                [[10, 13], [16, 30], [33, 23],
                 [30, 61], [62, 45], [59, 119],
                 [116, 90], [156, 198], [373, 326]]
                anchors[6:9] are used for feature_map_1, anchors[3:6] for
                feature_map_2 and anchors[0:3] for feature_map_3.
            use_label_smooth: smooth the class targets in the loss.
            use_focal_loss: scale the confidence loss by a focal-loss factor.
            batch_norm_decay: decay of the batch-norm moving averages.
            weight_decay: L2 regularization strength for conv weights.
            use_static_shape: if True, use tensor.get_shape() (static shape);
                otherwise use tf.shape(tensor). Static shape is slightly faster
                at inference.
        '''
        self.class_num = class_num
        self.anchors = anchors
        self.batch_norm_decay = batch_norm_decay
        self.use_label_smooth = use_label_smooth
        self.use_focal_loss = use_focal_loss
        self.weight_decay = weight_decay
        # inference speed optimization
        # if `use_static_shape` is True, use tensor.get_shape(), otherwise use tf.shape(tensor)
        # static_shape is slightly faster
        self.use_static_shape = use_static_shape

    def forward(self, inputs, is_training=False, reuse=False):
        '''
        Build the backbone and the three detection heads.

        param:
            inputs: image batch; axes 1:3 are read as [height, width]
                (channels-last is assumed, matching the axis=3 concats below).
            is_training: passed to batch norm.
            reuse: whether conv / batch-norm variables are reused.
        return:
            feature_map_1, feature_map_2, feature_map_3: raw head outputs, each
            with 3 * (5 + class_num) channels. They are computed from route_3,
            route_2 and route_1 respectively; for a 416x416 input,
            feature_map_1 is 13x13.
        '''
        # the input img_size, form: [height, width]
        # it is read later by reorg_layer and the loss layers
        self.img_size = tf.shape(inputs)[1:3]
        # set batch norm params
        batch_norm_params = {
            'decay': self.batch_norm_decay,
            'epsilon': 1e-05,
            'scale': True,
            'is_training': is_training,
            'fused': None,  # Use fused batch norm if possible.
        }
        # per anchor (3 per scale): 4 box terms + 1 objectness logit + class_num class logits
        head_channels = 3 * (5 + self.class_num)

        with slim.arg_scope([slim.conv2d, slim.batch_norm], reuse=reuse):
            with slim.arg_scope([slim.conv2d],
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=batch_norm_params,
                                biases_initializer=None,
                                activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=0.1),
                                weights_regularizer=slim.l2_regularizer(self.weight_decay)):
                with tf.variable_scope('darknet53_body'):
                    route_1, route_2, route_3 = darknet53_body(inputs)

                with tf.variable_scope('yolov3_head'):
                    inter1, net = yolo_block(route_3, 512)
                    # the output conv is linear and has a bias, no batch norm
                    feature_map_1 = slim.conv2d(net, head_channels, 1,
                                                stride=1, normalizer_fn=None,
                                                activation_fn=None, biases_initializer=tf.zeros_initializer())
                    feature_map_1 = tf.identity(feature_map_1, name='feature_map_1')

                    inter1 = conv2d(inter1, 256, 1)
                    inter1 = upsample_layer(inter1,
                                            route_2.get_shape().as_list() if self.use_static_shape else tf.shape(
                                                route_2))
                    concat1 = tf.concat([inter1, route_2], axis=3)

                    inter2, net = yolo_block(concat1, 256)
                    feature_map_2 = slim.conv2d(net, head_channels, 1,
                                                stride=1, normalizer_fn=None,
                                                activation_fn=None, biases_initializer=tf.zeros_initializer())
                    feature_map_2 = tf.identity(feature_map_2, name='feature_map_2')

                    inter2 = conv2d(inter2, 128, 1)
                    inter2 = upsample_layer(inter2,
                                            route_1.get_shape().as_list() if self.use_static_shape else tf.shape(
                                                route_1))
                    concat2 = tf.concat([inter2, route_1], axis=3)

                    _, feature_map_3 = yolo_block(concat2, 128)
                    feature_map_3 = slim.conv2d(feature_map_3, head_channels, 1,
                                                stride=1, normalizer_fn=None,
                                                activation_fn=None, biases_initializer=tf.zeros_initializer())
                    feature_map_3 = tf.identity(feature_map_3, name='feature_map_3')

        return feature_map_1, feature_map_2, feature_map_3

    def reorg_layer(self, feature_map, anchors):
        '''
        Decode one raw head output into boxes on the input-image scale.

        param:
            feature_map: a feature_map from [feature_map_1, feature_map_2, feature_map_3]
                returned from `forward` function.
            anchors: shape: [3, 2], in [w, h] format on the input-image scale.
        return (shapes for a 416x416 input and a 13x13 feature_map):
            x_y_offset: [13, 13, 1, 2], the (x, y) index of every grid cell.
            boxes: [N, 13, 13, 3, 4], (center_x, center_y, w, h) on the input-image scale.
            conf_logits: [N, 13, 13, 3, 1], raw objectness logits.
            prob_logits: [N, 13, 13, 3, class_num], raw class logits.
        '''
        # NOTE: size in [h, w] format! don't get messed up!
        grid_size = feature_map.get_shape().as_list()[1:3] if self.use_static_shape else tf.shape(feature_map)[
            1:3]  # [13, 13]
        # the downscale ratio in height and width
        ratio = tf.cast(self.img_size / grid_size, tf.float32)
        # rescale the anchors to the feature_map
        # NOTE: the anchor is in [w, h] format!
        rescaled_anchors = [(anchor[0] / ratio[1], anchor[1] / ratio[0]) for anchor in anchors]

        feature_map = tf.reshape(feature_map, [-1, grid_size[0], grid_size[1], 3, 5 + self.class_num])

        # split the feature_map along the last dimension
        # shape info: take 416x416 input image and the 13*13 feature_map for example:
        # box_centers: [N, 13, 13, 3, 2] last_dimension: [center_x, center_y]
        # box_sizes: [N, 13, 13, 3, 2] last_dimension: [width, height]
        # conf_logits: [N, 13, 13, 3, 1]
        # prob_logits: [N, 13, 13, 3, class_num]
        box_centers, box_sizes, conf_logits, prob_logits = tf.split(feature_map, [2, 2, 1, self.class_num], axis=-1)
        box_centers = tf.nn.sigmoid(box_centers)

        # use some broadcast tricks to get the mesh coordinates
        grid_x = tf.range(grid_size[1], dtype=tf.int32)
        grid_y = tf.range(grid_size[0], dtype=tf.int32)
        grid_x, grid_y = tf.meshgrid(grid_x, grid_y)
        x_offset = tf.reshape(grid_x, (-1, 1))
        y_offset = tf.reshape(grid_y, (-1, 1))
        x_y_offset = tf.concat([x_offset, y_offset], axis=-1)
        # shape: [13, 13, 1, 2]
        x_y_offset = tf.cast(tf.reshape(x_y_offset, [grid_size[0], grid_size[1], 1, 2]), tf.float32)

        # get the absolute box coordinates on the feature_map
        box_centers = box_centers + x_y_offset
        # rescale to the original image scale; ratio is [h, w], so reverse it to [w, h]
        box_centers = box_centers * ratio[::-1]

        # NOTE(review): tf.exp is unclipped here, so box sizes may overflow/underflow;
        # the loss-side divisions guard against the resulting zeros.
        box_sizes = tf.exp(box_sizes) * rescaled_anchors
        # rescale to the original image scale
        box_sizes = box_sizes * ratio[::-1]

        # shape: [N, 13, 13, 3, 4]
        # last dimension: (center_x, center_y, w, h)
        boxes = tf.concat([box_centers, box_sizes], axis=-1)

        return x_y_offset, boxes, conf_logits, prob_logits

    def predict(self, feature_maps):
        '''
        Decode the feature_maps returned by `forward` and produce the output
        predictions at the test stage.

        param:
            feature_maps: (feature_map_1, feature_map_2, feature_map_3).
        return (take 416*416 input image for example, M = (13*13+26*26+52*52)*3):
            boxes: [N, M, 4], (x_min, y_min, x_max, y_max) on the input-image scale.
            confs: [N, M, 1], sigmoid objectness scores.
            probs: [N, M, class_num], sigmoid class scores.
        '''
        feature_map_1, feature_map_2, feature_map_3 = feature_maps

        feature_map_anchors = [(feature_map_1, self.anchors[6:9]),
                               (feature_map_2, self.anchors[3:6]),
                               (feature_map_3, self.anchors[0:3])]
        reorg_results = [self.reorg_layer(feature_map, anchors) for (feature_map, anchors) in feature_map_anchors]

        def _reshape(result):
            # flatten the grid and anchor axes of one scale into a single box axis
            x_y_offset, boxes, conf_logits, prob_logits = result
            grid_size = x_y_offset.get_shape().as_list()[:2] if self.use_static_shape else tf.shape(x_y_offset)[:2]
            boxes = tf.reshape(boxes, [-1, grid_size[0] * grid_size[1] * 3, 4])
            conf_logits = tf.reshape(conf_logits, [-1, grid_size[0] * grid_size[1] * 3, 1])
            prob_logits = tf.reshape(prob_logits, [-1, grid_size[0] * grid_size[1] * 3, self.class_num])
            # shape: (take 416*416 input image and feature_map_1 for example)
            # boxes: [N, 13*13*3, 4]
            # conf_logits: [N, 13*13*3, 1]
            # prob_logits: [N, 13*13*3, class_num]
            return boxes, conf_logits, prob_logits

        boxes_list, confs_list, probs_list = [], [], []
        for result in reorg_results:
            boxes, conf_logits, prob_logits = _reshape(result)
            boxes_list.append(boxes)
            confs_list.append(tf.sigmoid(conf_logits))
            probs_list.append(tf.sigmoid(prob_logits))

        # collect results on three scales
        # take 416*416 input image for example:
        # shape: [N, (13*13+26*26+52*52)*3, 4]
        boxes = tf.concat(boxes_list, axis=1)
        # shape: [N, (13*13+26*26+52*52)*3, 1]
        confs = tf.concat(confs_list, axis=1)
        # shape: [N, (13*13+26*26+52*52)*3, class_num]
        probs = tf.concat(probs_list, axis=1)

        # convert (center_x, center_y, w, h) to corner form
        center_x, center_y, width, height = tf.split(boxes, [1, 1, 1, 1], axis=-1)
        x_min = center_x - width / 2
        y_min = center_y - height / 2
        x_max = center_x + width / 2
        y_max = center_y + height / 2

        boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1)

        return boxes, confs, probs

    def _ignore_mask(self, y_true, object_mask, pred_boxes, N):
        '''
        Build the ignore mask used by the negative confidence loss.

        For each image, a predicted box whose best IoU with any ground-truth box
        of that image is >= 0.5 gets 0, so it is not penalized as a negative.
        All other boxes get 1. The calculation is taken from
        https://github.com/pjreddie/darknet/blob/master/src/yolo_layer.c#L179

        param:
            y_true: [N, 13, 13, 3, 5 + num_class + 1] etc.
            object_mask: [N, 13, 13, 3, 1] etc.
            pred_boxes: [N, 13, 13, 3, 4] etc., (center_x, center_y, w, h).
            N: batch size as a float32 scalar tensor.
        return:
            [N, 13, 13, 3, 1] float32 mask.
        '''
        ignore_mask = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

        def loop_cond(idx, ignore_mask):
            return tf.less(idx, tf.cast(N, tf.int32))

        def loop_body(idx, ignore_mask):
            # shape: [13, 13, 3, 4] & [13, 13, 3] ==> [V, 4]
            # V: num of true gt box of each image in a batch
            valid_true_boxes = tf.boolean_mask(y_true[idx, ..., 0:4], tf.cast(object_mask[idx, ..., 0], 'bool'))
            # shape: [13, 13, 3, 4] & [V, 4] ==> [13, 13, 3, V]
            iou = self.box_iou(pred_boxes[idx], valid_true_boxes)
            # shape: [13, 13, 3]
            best_iou = tf.reduce_max(iou, axis=-1)
            # shape: [13, 13, 3]
            ignore_mask_tmp = tf.cast(best_iou < 0.5, tf.float32)
            # finally will be shape: [N, 13, 13, 3]
            ignore_mask = ignore_mask.write(idx, ignore_mask_tmp)
            return idx + 1, ignore_mask

        _, ignore_mask = tf.while_loop(cond=loop_cond, body=loop_body, loop_vars=[0, ignore_mask])
        ignore_mask = ignore_mask.stack()
        # shape: [N, 13, 13, 3, 1]
        return tf.expand_dims(ignore_mask, -1)

    def _scale_losses(self, feature_map_i, y_true, anchors, with_giou):
        '''
        Shared implementation of `loss_layer` and `loss_layer_giou` for one scale.

        return:
            (xy_loss, wh_loss, conf_loss, class_loss), with giou_loss appended when
            `with_giou` is True. Each term is summed over the scale and divided by
            the batch size.
        '''
        # size in [h, w] format! don't get messed up!
        grid_size = tf.shape(feature_map_i)[1:3]
        # the downscale ratio in height and width
        ratio = tf.cast(self.img_size / grid_size, tf.float32)
        # N: batch_size
        N = tf.cast(tf.shape(feature_map_i)[0], tf.float32)

        x_y_offset, pred_boxes, pred_conf_logits, pred_prob_logits = self.reorg_layer(feature_map_i, anchors)

        ###########
        # get mask
        ###########
        # shape: take 416x416 input image and 13*13 feature_map for example:
        # [N, 13, 13, 3, 1]
        object_mask = y_true[..., 4:5]
        # shape: [N, 13, 13, 3, 1]
        ignore_mask = self._ignore_mask(y_true, object_mask, pred_boxes, N)

        # shape: [N, 13, 13, 3, 2]
        pred_box_xy = pred_boxes[..., 0:2]
        pred_box_wh = pred_boxes[..., 2:4]

        # get xy coordinates in one cell from the feature_map
        # numerical range: 0 ~ 1 for cells that hold an object
        # shape: [N, 13, 13, 3, 2]
        true_xy = y_true[..., 0:2] / ratio[::-1] - x_y_offset
        pred_xy = pred_box_xy / ratio[::-1] - x_y_offset

        # get_tw_th: log of the box size relative to its anchor
        # shape: [N, 13, 13, 3, 2]
        true_tw_th = y_true[..., 2:4] / anchors
        pred_tw_th = pred_box_wh / anchors
        # for numerical stability: a zero ratio (e.g. an empty cell) becomes 1, i.e. log(1) = 0
        true_tw_th = tf.where(condition=tf.equal(true_tw_th, 0),
                              x=tf.ones_like(true_tw_th), y=true_tw_th)
        pred_tw_th = tf.where(condition=tf.equal(pred_tw_th, 0),
                              x=tf.ones_like(pred_tw_th), y=pred_tw_th)
        true_tw_th = tf.log(tf.clip_by_value(true_tw_th, 1e-9, 1e9))
        pred_tw_th = tf.log(tf.clip_by_value(pred_tw_th, 1e-9, 1e9))

        # box size punishment:
        # box with smaller area has bigger weight. This is taken from the yolo darknet C source code.
        # shape: [N, 13, 13, 3, 1]
        box_loss_scale = 2. - (y_true[..., 2:3] / tf.cast(self.img_size[1], tf.float32)) * (
            y_true[..., 3:4] / tf.cast(self.img_size[0], tf.float32))

        ############
        # loss_part
        ############
        # mix_up weight
        # [N, 13, 13, 3, 1]
        mix_w = y_true[..., -1:]
        # shape: [N, 13, 13, 3, 1]
        xy_loss = tf.reduce_sum(tf.square(true_xy - pred_xy) * object_mask * box_loss_scale * mix_w) / N
        wh_loss = tf.reduce_sum(tf.square(true_tw_th - pred_tw_th) * object_mask * box_loss_scale * mix_w) / N

        # shape: [N, 13, 13, 3, 1]
        conf_pos_mask = object_mask
        conf_neg_mask = (1 - object_mask) * ignore_mask
        # the same cross-entropy serves both the positive and the negative term
        conf_ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=object_mask, logits=pred_conf_logits)
        # TODO: may need to balance the pos-neg by multiplying some weights
        conf_loss = conf_pos_mask * conf_ce + conf_neg_mask * conf_ce
        if self.use_focal_loss:
            alpha = 1.0
            gamma = 2.0
            # TODO: alpha should be a mask array if needed
            focal_mask = alpha * tf.pow(tf.abs(object_mask - tf.sigmoid(pred_conf_logits)), gamma)
            conf_loss *= focal_mask
        conf_loss = tf.reduce_sum(conf_loss * mix_w) / N

        # shape: [N, 13, 13, 3, 1]
        # whether to use label smooth
        if self.use_label_smooth:
            delta = 0.01
            label_target = (1 - delta) * y_true[..., 5:-1] + delta * 1. / self.class_num
        else:
            label_target = y_true[..., 5:-1]
        class_loss = object_mask * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_target,
                                                                           logits=pred_prob_logits) * mix_w
        class_loss = tf.reduce_sum(class_loss) / N

        if not with_giou:
            return xy_loss, wh_loss, conf_loss, class_loss

        # only the box part (center_x, center_y, w, h) of y_true is a box; passing the
        # whole y_true would broadcast a 2-channel tensor against a (4 + C)-channel one
        giou = tf.expand_dims(self.bbox_giou(pred_boxes, y_true[..., 0:4]), axis=-1)
        # weighted by mix_w like every other positive-sample term
        giou_loss = object_mask * box_loss_scale * (1. - giou) * mix_w
        giou_loss = tf.reduce_sum(giou_loss) / N

        return xy_loss, wh_loss, conf_loss, class_loss, giou_loss

    def loss_layer(self, feature_map_i, y_true, anchors):
        '''
        calc loss function from a certain scale
        input:
            feature_map_i: feature maps of a certain scale. shape: [N, 13, 13, 3*(5 + num_class)] etc.
            y_true: y_true from a certain scale. shape: [N, 13, 13, 3, 5 + num_class + 1] etc.
                last dimension: (center_x, center_y, w, h, objectness, one-hot classes..., mixup weight)
            anchors: the 3 anchors of this scale, shape [3, 2]
        return:
            xy_loss, wh_loss, conf_loss, class_loss (scalars, averaged over the batch)
        '''
        return self._scale_losses(feature_map_i, y_true, anchors, with_giou=False)

    def loss_layer_giou(self, feature_map_i, y_true, anchors):
        '''
        calc loss function from a certain scale, including the GIoU box loss
        input:
            feature_map_i: feature maps of a certain scale. shape: [N, 13, 13, 3*(5 + num_class)] etc.
            y_true: y_true from a certain scale. shape: [N, 13, 13, 3, 5 + num_class + 1] etc.
                last dimension: (center_x, center_y, w, h, objectness, one-hot classes..., mixup weight)
            anchors: the 3 anchors of this scale, shape [3, 2]
        return:
            xy_loss, wh_loss, conf_loss, class_loss, giou_loss (scalars, averaged over the batch)
        '''
        return self._scale_losses(feature_map_i, y_true, anchors, with_giou=True)

    def box_iou(self, pred_boxes, valid_true_boxes):
        '''
        IoU between every predicted box and every valid ground-truth box of one image.

        param:
            pred_boxes: [13, 13, 3, 4], (center_x, center_y, w, h)
            valid_true_boxes: [V, 4], (center_x, center_y, w, h)
        return:
            iou: [13, 13, 3, V]
        '''
        # [13, 13, 3, 2]
        pred_box_xy = pred_boxes[..., 0:2]
        pred_box_wh = pred_boxes[..., 2:4]

        # shape: [13, 13, 3, 1, 2]
        pred_box_xy = tf.expand_dims(pred_box_xy, -2)
        pred_box_wh = tf.expand_dims(pred_box_wh, -2)

        # [V, 2]
        true_box_xy = valid_true_boxes[:, 0:2]
        true_box_wh = valid_true_boxes[:, 2:4]

        # [13, 13, 3, 1, 2] & [V, 2] ==> [13, 13, 3, V, 2]
        intersect_mins = tf.maximum(pred_box_xy - pred_box_wh / 2.,
                                    true_box_xy - true_box_wh / 2.)
        intersect_maxs = tf.minimum(pred_box_xy + pred_box_wh / 2.,
                                    true_box_xy + true_box_wh / 2.)
        intersect_wh = tf.maximum(intersect_maxs - intersect_mins, 0.)

        # shape: [13, 13, 3, V]
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        # shape: [13, 13, 3, 1]
        pred_box_area = pred_box_wh[..., 0] * pred_box_wh[..., 1]
        # shape: [V]
        true_box_area = true_box_wh[..., 0] * true_box_wh[..., 1]
        # shape: [1, V]
        true_box_area = tf.expand_dims(true_box_area, axis=0)

        # [13, 13, 3, V]; the epsilon avoids 0/0 for degenerate boxes
        iou = intersect_area / (pred_box_area + true_box_area - intersect_area + 1e-10)

        return iou

    def bbox_giou(self, boxes1, boxes2):
        '''
        Element-wise (broadcasting) generalized IoU between two sets of boxes.

        param:
            boxes1, boxes2: [..., 4] tensors in (center_x, center_y, w, h) format.
        return:
            giou: [...] tensor of GIoU values.
        '''
        # guards the divisions below; without it a zero-area union/enclosure
        # yields NaN, which a later multiplication by a 0 mask does not remove
        eps = 1e-10

        # (center_x, center_y, w, h) -> (x_min, y_min, x_max, y_max)
        boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                            boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
        boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                            boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

        # make sure min corner <= max corner even for negative widths/heights
        boxes1 = tf.concat([tf.minimum(boxes1[..., :2], boxes1[..., 2:]),
                            tf.maximum(boxes1[..., :2], boxes1[..., 2:])], axis=-1)
        boxes2 = tf.concat([tf.minimum(boxes2[..., :2], boxes2[..., 2:]),
                            tf.maximum(boxes2[..., :2], boxes2[..., 2:])], axis=-1)

        boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
        boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

        left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
        right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])

        inter_section = tf.maximum(right_down - left_up, 0.0)
        inter_area = inter_section[..., 0] * inter_section[..., 1]
        union_area = boxes1_area + boxes2_area - inter_area
        iou = inter_area / (union_area + eps)

        # smallest box enclosing both boxes
        enclose_left_up = tf.minimum(boxes1[..., :2], boxes2[..., :2])
        enclose_right_down = tf.maximum(boxes1[..., 2:], boxes2[..., 2:])
        enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0)
        enclose_area = enclose[..., 0] * enclose[..., 1]

        giou = iou - 1.0 * (enclose_area - union_area) / (enclose_area + eps)

        return giou

    def compute_loss(self, y_pred, y_true):
        '''
        param:
            y_pred: returned feature_map list by `forward` function: [feature_map_1, feature_map_2, feature_map_3]
            y_true: input y_true by the tf.data pipeline, one tensor per scale in the same order
        return:
            [total_loss, loss_xy, loss_wh, loss_conf, loss_class], each summed over the 3 scales
        '''
        loss_xy, loss_wh, loss_conf, loss_class = 0., 0., 0., 0.
        anchor_group = [self.anchors[6:9], self.anchors[3:6], self.anchors[0:3]]

        # calc loss in 3 scales
        for feature_map_i, y_true_i, anchors_i in zip(y_pred, y_true, anchor_group):
            xy, wh, conf, cls = self.loss_layer(feature_map_i, y_true_i, anchors_i)
            loss_xy += xy
            loss_wh += wh
            loss_conf += conf
            loss_class += cls
        total_loss = loss_xy + loss_wh + loss_conf + loss_class
        return [total_loss, loss_xy, loss_wh, loss_conf, loss_class]

    def compute_loss_giou(self, y_pred, y_true):
        '''
        param:
            y_pred: returned feature_map list by `forward` function: [feature_map_1, feature_map_2, feature_map_3]
            y_true: input y_true by the tf.data pipeline, one tensor per scale in the same order
        return:
            [total_loss, loss_conf, loss_class, loss_giou], each summed over the 3 scales.
            The GIoU term replaces the xy/wh box regression terms in total_loss.
        '''
        loss_conf, loss_class, loss_giou = 0., 0., 0.
        anchor_group = [self.anchors[6:9], self.anchors[3:6], self.anchors[0:3]]

        # calc loss in 3 scales
        for feature_map_i, y_true_i, anchors_i in zip(y_pred, y_true, anchor_group):
            _, _, conf, cls, giou = self.loss_layer_giou(feature_map_i, y_true_i, anchors_i)
            loss_conf += conf
            loss_class += cls
            loss_giou += giou
        total_loss = loss_conf + loss_class + loss_giou
        return [total_loss, loss_conf, loss_class, loss_giou]