Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasSloan committed Jul 22, 2020
1 parent 86236a4 commit fc3b788
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 150 deletions.
19 changes: 7 additions & 12 deletions efficientdet/keras/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file')
FLAGS = flags.FLAGS


def main(_):
config = hparams_config.get_efficientdet_config(FLAGS.model_name)
config.override(FLAGS.hparams)
Expand Down Expand Up @@ -71,6 +72,7 @@ def main(_):
box_outputs,
labels['image_scales'],
labels['source_ids'], False)

images_flipped = tf.image.flip_left_right(images)
cls_outputs_flipped, box_outputs_flipped = model(
images_flipped, training=False)
Expand All @@ -79,19 +81,12 @@ def main(_):
labels['image_scales'], labels['source_ids'], True)

for d, df in zip(detections, detections_flipped):
combined_detections = wbf.ensemble_boxes(config, tf.concat([d, df], 0))
combined_detections = wbf.ensemble_detections(config, tf.concat([d, df],
0))
combined_detections = tf.stack([combined_detections])
evaluator.update_state(labels['groundtruth_data'].numpy(),
combined_detections.numpy())

# print(len(detections[0]))
# print()

# for d in detections[0][:10]:
# print(d[5].numpy(), d[6].numpy())

# print()
# break
evaluator.update_state(
labels['groundtruth_data'].numpy(),
postprocess.transform_detections(combined_detections).numpy())

# compute the final eval results.
metric_values = evaluator.result()
Expand Down
46 changes: 31 additions & 15 deletions efficientdet/keras/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ def topk_class_boxes(params, cls_outputs: T,
# Due to some issues, top_k is currently slow in graph model.
logging.info('use max_nms_inputs for pre-nms topk.')
cls_outputs_reshape = tf.reshape(cls_outputs, [batch_size, -1])
_, cls_topk_indices = tf.math.top_k(cls_outputs_reshape,
k=max_nms_inputs,
sorted=False)
_, cls_topk_indices = tf.math.top_k(
cls_outputs_reshape, k=max_nms_inputs, sorted=False)
indices = cls_topk_indices // num_classes
classes = cls_topk_indices % num_classes
cls_indices = tf.stack([indices, classes], axis=2)

cls_outputs_topk = tf.gather_nd(cls_outputs, cls_indices, batch_dims=1)
box_outputs_topk = tf.gather_nd(box_outputs,
tf.expand_dims(indices, 2),
batch_dims=1)
box_outputs_topk = tf.gather_nd(
box_outputs, tf.expand_dims(indices, 2), batch_dims=1)
else:
logging.info('use max_reduce for pre-nms topk.')
# Keep all anchors, but for each anchor, just keep the max probablity for
Expand All @@ -84,8 +82,8 @@ def topk_class_boxes(params, cls_outputs: T,
num_anchors = cls_outputs.shape[1]

classes = cls_outputs_idx
indices = tf.tile(tf.expand_dims(tf.range(num_anchors), axis=0),
[batch_size, 1])
indices = tf.tile(
tf.expand_dims(tf.range(num_anchors), axis=0), [batch_size, 1])
cls_outputs_topk = tf.reduce_max(cls_outputs, -1)
box_outputs_topk = box_outputs

Expand Down Expand Up @@ -354,11 +352,15 @@ def postprocess_per_class(params, cls_outputs, box_outputs, image_scales=None):
return per_class_nms(params, boxes, scores, classes, image_scales)


def generate_detections(params, cls_outputs, box_outputs, image_scales,
image_ids, flip = False):
def generate_detections(params,
cls_outputs,
box_outputs,
image_scales,
image_ids,
flip=False):
"""A legacy interface for generating [id, x, y, w, h, score, class]."""
nms_boxes_bs, nms_scores_bs, nms_classes_bs, _ = postprocess_per_class(
params, cls_outputs, box_outputs, img_scales)
params, cls_outputs, box_outputs, image_scales)

image_ids_bs = tf.cast(tf.expand_dims(image_ids, -1), nms_scores_bs.dtype)
if flip:
Expand All @@ -368,8 +370,8 @@ def generate_detections(params, cls_outputs, box_outputs, image_scales,
image_ids_bs * tf.ones_like(nms_scores_bs),
tf.expand_dims(image_scales, -1) * width - nms_boxes_bs[:, :, 3],
nms_boxes_bs[:, :, 0],
nms_boxes_bs[:, :, 3] - nms_boxes_bs[:, :, 1],
nms_boxes_bs[:, :, 2] - nms_boxes_bs[:, :, 0],
tf.expand_dims(image_scales, -1) * width - nms_boxes_bs[:, :, 1],
nms_boxes_bs[:, :, 2],
nms_scores_bs,
nms_classes_bs,
]
Expand All @@ -378,9 +380,23 @@ def generate_detections(params, cls_outputs, box_outputs, image_scales,
image_ids_bs * tf.ones_like(nms_scores_bs),
nms_boxes_bs[:, :, 1],
nms_boxes_bs[:, :, 0],
nms_boxes_bs[:, :, 3] - nms_boxes_bs[:, :, 1],
nms_boxes_bs[:, :, 2] - nms_boxes_bs[:, :, 0],
nms_boxes_bs[:, :, 3],
nms_boxes_bs[:, :, 2],
nms_scores_bs,
nms_classes_bs,
]
return tf.stack(detections_bs, axis=-1, name='detnections')


def transform_detections(detections):
"""A transforms detections in [id, x1, y1, x2, y2, score, class] form to [id, x, y, w, h, score, class]."""
return tf.stack([
detections[:, :, 0],
detections[:, :, 1],
detections[:, :, 2],
detections[:, :, 3] - detections[:, :, 1],
detections[:, :, 4] - detections[:, :, 2],
detections[:, :, 5],
detections[:, :, 6],
],
axis=-1)
28 changes: 28 additions & 0 deletions efficientdet/keras/postprocess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ def test_postprocess_per_class(self):
box_outputs_list, scales, ids)
self.assertAllClose(
outputs.numpy(),
[[[0., -1.177383, 1.793507, 8.340945, 4.418388, 0.901576, 2.],
[0., 5.676410, 6.102146, 7.785691, 8.537168, 0.888125, 1.]],
[[1., 5.885427, 13.529362, 11.410081, 14.154047, 0.884544, 1.],
[1., 8.145872, -9.660868, 14.173973, 10.41237, 0.815883, 2.]]])

outputs_flipped = postprocess.generate_detections(self.params,
cls_outputs_list,
box_outputs_list, scales,
ids, True)
self.assertAllClose(
outputs_flipped.numpy(),
[[[0., -0.340945, 1.793507, 9.177383, 4.418388, 0.901576, 2.],
[0., 0.214309, 6.102146, 2.32359, 8.537168, 0.888125, 1.]],
[[1., 4.589919, 13.529362, 10.114573, 14.154047, 0.884544, 1.],
[1., 1.826027, -9.660868, 7.854128, 10.41237, 0.815883, 2.]]])

def test_transform_detections(self):
corners = tf.constant(
[[[0., -1.177383, 1.793507, 8.340945, 4.418388, 0.901576, 2.],
[0., 5.676410, 6.102146, 7.785691, 8.537168, 0.888125, 1.]],
[[1., 5.885427, 13.529362, 11.410081, 14.154047, 0.884544, 1.],
[1., 8.145872, -9.660868, 14.173973, 10.41237, 0.815883, 2.]]])

corner_plus_area = postprocess.transform_detections(corners)

self.assertAllClose(
corner_plus_area.numpy(),
[[[0., -1.177383, 1.793507, 9.518328, 2.624881, 0.901576, 2.],
[0., 5.676410, 6.102146, 2.109282, 2.435021, 0.888125, 1.]],
[[1., 5.885427, 13.529362, 5.524654, 0.624685, 0.884544, 1.],
Expand Down Expand Up @@ -117,6 +144,7 @@ def test_postprocess_combined(self):
self.assertAllClose(scores.numpy(),
[[0.90157586, 0.88812476], [0.88454413, 0.8158828]])


if __name__ == '__main__':
logging.set_verbosity(logging.WARNING)
tf.test.main()
82 changes: 35 additions & 47 deletions efficientdet/keras/wbf.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,38 @@
from absl import logging
import tensorflow as tf

def vectorized_iou(d1, d2):
x1, y1, w1, h1 = tf.split(d1[:, 1:5], 4, axis=1)
x2, y2, w2, h2 = tf.split(d2[:, 1:5], 4, axis=1)

x11 = x1
y11 = y1
x21 = x2
y21 = y2

x12 = x1 + w1
y12 = y1 + h1
x22 = x2 + w2
y22 = y2 + h2
def vectorized_iou(clusters, detection):
"""Calculates the ious for box with each element of clusters."""
x11, y11, x12, y12 = tf.split(clusters[:, 1:5], 4, axis=1)
x21, y21, x22, y22 = tf.split(detection[1:5], 4)

xA = tf.maximum(x11, x21)
yA = tf.maximum(y11, y21)
xB = tf.minimum(x12, x22)
yB = tf.minimum(y12, y22)
xA = tf.maximum(x11, x21)
yA = tf.maximum(y11, y21)
xB = tf.minimum(x12, x22)
yB = tf.minimum(y12, y22)

interArea = tf.maximum((xB - xA), 0) * tf.maximum((yB - yA), 0)
interArea = tf.maximum((xB - xA), 0) * tf.maximum((yB - yA), 0)

boxAArea = (x12 - x11) * (y12 - y11)
boxBArea = (x22 - x21) * (y22 - y21)
boxAArea = (x12 - x11) * (y12 - y11)
boxBArea = (x22 - x21) * (y22 - y21)

iou = interArea / (boxAArea + boxBArea - interArea)
iou = interArea / (boxAArea + boxBArea - interArea)

return iou
return iou


def find_matching_cluster(clusters, box):
if len(clusters) == 0:
return -1
tiled_boxes = tf.tile(tf.expand_dims(box, axis=0), [len(clusters), 1])

ious = vectorized_iou(tf.stack(clusters), tiled_boxes)
ious = tf.concat([tf.constant([0.55]), tf.reshape(ious, [len(clusters)])], axis=0)
best_index = tf.argmax(ious)

return best_index - 1
def find_matching_cluster(clusters, detection):
"""Returns the index of the highest iou matching cluster for detection. Returns -1 if no iou is higher than 0.55."""
ious = vectorized_iou(tf.stack(clusters), detection)
ious = tf.reshape(ious, [len(clusters)])
if tf.math.reduce_max(ious) < 0.55:
return -1
return tf.argmax(ious)


def average_detections(detections):
"""Takes a list of detections and returns the average, both in box co-ordinates and confidence."""
detections = tf.stack(detections)
return [
detections[0][0],
Expand All @@ -55,29 +45,27 @@ def average_detections(detections):
]


def ensemble_boxes(params, detections):
# [id, x, y, w, h, score, class]

def ensemble_detections(params, detections):
"""Ensembles a group of detections by clustering the detections and returning the average of the clusters."""
all_clusters = []

# cluster the detections
for cid in range(params['num_classes']):
indices = tf.where(tf.equal(detections[:, 6], cid))
if indices.shape[0] == 0:
continue
continue
class_detections = tf.gather_nd(detections, indices)

clusters = []
cluster_averages = []
for d in class_detections:
cluster_index = find_matching_cluster(cluster_averages, d)
if cluster_index == -1:
clusters.append([d])
cluster_averages.append(d)
else:
clusters[cluster_index].append(d)
cluster_averages[cluster_index] = average_detections(
clusters[cluster_index])
clusters = [[class_detections[0]]]
cluster_averages = [class_detections[0]]
for d in class_detections[1:]:
cluster_index = find_matching_cluster(cluster_averages, d)
if cluster_index == -1:
clusters.append([d])
cluster_averages.append(d)
else:
clusters[cluster_index].append(d)
cluster_averages[cluster_index] = average_detections(
clusters[cluster_index])

all_clusters.extend(cluster_averages)

Expand Down
Loading

0 comments on commit fc3b788

Please sign in to comment.