From fc3b788b2e62394750c8339385db7e952c2630a6 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Fri, 17 Jul 2020 22:55:56 -0700 Subject: [PATCH] Clean up. --- efficientdet/keras/eval.py | 19 ++-- efficientdet/keras/postprocess.py | 46 +++++--- efficientdet/keras/postprocess_test.py | 28 +++++ efficientdet/keras/wbf.py | 82 +++++++-------- efficientdet/keras/wbf_test.py | 139 +++++++++++-------------- 5 files changed, 164 insertions(+), 150 deletions(-) diff --git a/efficientdet/keras/eval.py b/efficientdet/keras/eval.py index 43f588562..86984098b 100644 --- a/efficientdet/keras/eval.py +++ b/efficientdet/keras/eval.py @@ -41,6 +41,7 @@ flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file') FLAGS = flags.FLAGS + def main(_): config = hparams_config.get_efficientdet_config(FLAGS.model_name) config.override(FLAGS.hparams) @@ -71,6 +72,7 @@ def main(_): box_outputs, labels['image_scales'], labels['source_ids'], False) + images_flipped = tf.image.flip_left_right(images) cls_outputs_flipped, box_outputs_flipped = model( images_flipped, training=False) @@ -79,19 +81,12 @@ def main(_): labels['image_scales'], labels['source_ids'], True) for d, df in zip(detections, detections_flipped): - combined_detections = wbf.ensemble_boxes(config, tf.concat([d, df], 0)) + combined_detections = wbf.ensemble_detections(config, tf.concat([d, df], + 0)) combined_detections = tf.stack([combined_detections]) - evaluator.update_state(labels['groundtruth_data'].numpy(), - combined_detections.numpy()) - - # print(len(detections[0])) - # print() - - # for d in detections[0][:10]: - # print(d[5].numpy(), d[6].numpy()) - - # print() - # break + evaluator.update_state( + labels['groundtruth_data'].numpy(), + postprocess.transform_detections(combined_detections).numpy()) # compute the final eval results. metric_values = evaluator.result() diff --git a/efficientdet/keras/postprocess.py b/efficientdet/keras/postprocess.py index 48b295cfa..c9951b3b7 100644 --- a/efficientdet/keras/postprocess.py +++ b/efficientdet/keras/postprocess.py @@ -65,17 +65,15 @@ def topk_class_boxes(params, cls_outputs: T, # Due to some issues, top_k is currently slow in graph model. logging.info('use max_nms_inputs for pre-nms topk.') cls_outputs_reshape = tf.reshape(cls_outputs, [batch_size, -1]) - _, cls_topk_indices = tf.math.top_k(cls_outputs_reshape, - k=max_nms_inputs, - sorted=False) + _, cls_topk_indices = tf.math.top_k( + cls_outputs_reshape, k=max_nms_inputs, sorted=False) indices = cls_topk_indices // num_classes classes = cls_topk_indices % num_classes cls_indices = tf.stack([indices, classes], axis=2) cls_outputs_topk = tf.gather_nd(cls_outputs, cls_indices, batch_dims=1) - box_outputs_topk = tf.gather_nd(box_outputs, - tf.expand_dims(indices, 2), - batch_dims=1) + box_outputs_topk = tf.gather_nd( + box_outputs, tf.expand_dims(indices, 2), batch_dims=1) else: logging.info('use max_reduce for pre-nms topk.') # Keep all anchors, but for each anchor, just keep the max probablity for @@ -84,8 +82,8 @@ def topk_class_boxes(params, cls_outputs: T, num_anchors = cls_outputs.shape[1] classes = cls_outputs_idx - indices = tf.tile(tf.expand_dims(tf.range(num_anchors), axis=0), - [batch_size, 1]) + indices = tf.tile( + tf.expand_dims(tf.range(num_anchors), axis=0), [batch_size, 1]) cls_outputs_topk = tf.reduce_max(cls_outputs, -1) box_outputs_topk = box_outputs @@ -354,11 +352,15 @@ def postprocess_per_class(params, cls_outputs, box_outputs, image_scales=None): return per_class_nms(params, boxes, scores, classes, image_scales) -def generate_detections(params, cls_outputs, box_outputs, image_scales, - image_ids, flip = False): +def generate_detections(params, + cls_outputs, + box_outputs, + image_scales, + image_ids, + flip=False): """A legacy interface for generating [id, x, y, w, h, score, class].""" nms_boxes_bs, nms_scores_bs, nms_classes_bs, _ = postprocess_per_class( - params, cls_outputs, box_outputs, img_scales) + params, cls_outputs, box_outputs, image_scales) image_ids_bs = tf.cast(tf.expand_dims(image_ids, -1), nms_scores_bs.dtype) if flip: @@ -368,8 +370,8 @@ def generate_detections(params, cls_outputs, box_outputs, image_scales, image_ids_bs * tf.ones_like(nms_scores_bs), tf.expand_dims(image_scales, -1) * width - nms_boxes_bs[:, :, 3], nms_boxes_bs[:, :, 0], - nms_boxes_bs[:, :, 3] - nms_boxes_bs[:, :, 1], - nms_boxes_bs[:, :, 2] - nms_boxes_bs[:, :, 0], + tf.expand_dims(image_scales, -1) * width - nms_boxes_bs[:, :, 1], + nms_boxes_bs[:, :, 2], nms_scores_bs, nms_classes_bs, ] @@ -378,9 +380,23 @@ def generate_detections(params, cls_outputs, box_outputs, image_scales, image_ids_bs * tf.ones_like(nms_scores_bs), nms_boxes_bs[:, :, 1], nms_boxes_bs[:, :, 0], - nms_boxes_bs[:, :, 3] - nms_boxes_bs[:, :, 1], - nms_boxes_bs[:, :, 2] - nms_boxes_bs[:, :, 0], + nms_boxes_bs[:, :, 3], + nms_boxes_bs[:, :, 2], nms_scores_bs, nms_classes_bs, ] return tf.stack(detections_bs, axis=-1, name='detnections') + + +def transform_detections(detections): + """A transforms detections in [id, x1, y1, x2, y2, score, class] form to [id, x, y, w, h, score, class].""" + return tf.stack([ + detections[:, :, 0], + detections[:, :, 1], + detections[:, :, 2], + detections[:, :, 3] - detections[:, :, 1], + detections[:, :, 4] - detections[:, :, 2], + detections[:, :, 5], + detections[:, :, 6], + ], + axis=-1) diff --git a/efficientdet/keras/postprocess_test.py b/efficientdet/keras/postprocess_test.py index 7aed48276..7b0ab78d9 100644 --- a/efficientdet/keras/postprocess_test.py +++ b/efficientdet/keras/postprocess_test.py @@ -89,6 +89,33 @@ def test_postprocess_per_class(self): box_outputs_list, scales, ids) self.assertAllClose( outputs.numpy(), + [[[0., -1.177383, 1.793507, 8.340945, 4.418388, 0.901576, 2.], + [0., 5.676410, 6.102146, 7.785691, 8.537168, 0.888125, 1.]], + [[1., 5.885427, 13.529362, 11.410081, 14.154047, 0.884544, 1.], + [1., 8.145872, -9.660868, 14.173973, 10.41237, 0.815883, 2.]]]) + + outputs_flipped = postprocess.generate_detections(self.params, + cls_outputs_list, + box_outputs_list, scales, + ids, True) + self.assertAllClose( + outputs_flipped.numpy(), + [[[0., -0.340945, 1.793507, 9.177383, 4.418388, 0.901576, 2.], + [0., 0.214309, 6.102146, 2.32359, 8.537168, 0.888125, 1.]], + [[1., 4.589919, 13.529362, 10.114573, 14.154047, 0.884544, 1.], + [1., 1.826027, -9.660868, 7.854128, 10.41237, 0.815883, 2.]]]) + + def test_transform_detections(self): + corners = tf.constant( + [[[0., -1.177383, 1.793507, 8.340945, 4.418388, 0.901576, 2.], + [0., 5.676410, 6.102146, 7.785691, 8.537168, 0.888125, 1.]], + [[1., 5.885427, 13.529362, 11.410081, 14.154047, 0.884544, 1.], + [1., 8.145872, -9.660868, 14.173973, 10.41237, 0.815883, 2.]]]) + + corner_plus_area = postprocess.transform_detections(corners) + + self.assertAllClose( + corner_plus_area.numpy(), [[[0., -1.177383, 1.793507, 9.518328, 2.624881, 0.901576, 2.], [0., 5.676410, 6.102146, 2.109282, 2.435021, 0.888125, 1.]], [[1., 5.885427, 13.529362, 5.524654, 0.624685, 0.884544, 1.], @@ -117,6 +144,7 @@ def test_postprocess_combined(self): self.assertAllClose(scores.numpy(), [[0.90157586, 0.88812476], [0.88454413, 0.8158828]]) + if __name__ == '__main__': logging.set_verbosity(logging.WARNING) tf.test.main() diff --git a/efficientdet/keras/wbf.py b/efficientdet/keras/wbf.py index 7f18de131..0056fd182 100644 --- a/efficientdet/keras/wbf.py +++ b/efficientdet/keras/wbf.py @@ -1,48 +1,38 @@ from absl import logging import tensorflow as tf -def vectorized_iou(d1, d2): - x1, y1, w1, h1 = tf.split(d1[:, 1:5], 4, axis=1) - x2, y2, w2, h2 = tf.split(d2[:, 1:5], 4, axis=1) - - x11 = x1 - y11 = y1 - x21 = x2 - y21 = y2 - x12 = x1 + w1 - y12 = y1 + h1 - x22 = x2 + w2 - y22 = y2 + h2 +def vectorized_iou(clusters, detection): + """Calculates the ious for box with each element of clusters.""" + x11, y11, x12, y12 = tf.split(clusters[:, 1:5], 4, axis=1) + x21, y21, x22, y22 = tf.split(detection[1:5], 4) - xA = tf.maximum(x11, x21) - yA = tf.maximum(y11, y21) - xB = tf.minimum(x12, x22) - yB = tf.minimum(y12, y22) + xA = tf.maximum(x11, x21) + yA = tf.maximum(y11, y21) + xB = tf.minimum(x12, x22) + yB = tf.minimum(y12, y22) - interArea = tf.maximum((xB - xA), 0) * tf.maximum((yB - yA), 0) + interArea = tf.maximum((xB - xA), 0) * tf.maximum((yB - yA), 0) - boxAArea = (x12 - x11) * (y12 - y11) - boxBArea = (x22 - x21) * (y22 - y21) + boxAArea = (x12 - x11) * (y12 - y11) + boxBArea = (x22 - x21) * (y22 - y21) - iou = interArea / (boxAArea + boxBArea - interArea) + iou = interArea / (boxAArea + boxBArea - interArea) - return iou + return iou -def find_matching_cluster(clusters, box): - if len(clusters) == 0: - return -1 - tiled_boxes = tf.tile(tf.expand_dims(box, axis=0), [len(clusters), 1]) - - ious = vectorized_iou(tf.stack(clusters), tiled_boxes) - ious = tf.concat([tf.constant([0.55]), tf.reshape(ious, [len(clusters)])], axis=0) - best_index = tf.argmax(ious) - - return best_index - 1 +def find_matching_cluster(clusters, detection): + """Returns the index of the highest iou matching cluster for detection. Returns -1 if no iou is higher than 0.55.""" + ious = vectorized_iou(tf.stack(clusters), detection) + ious = tf.reshape(ious, [len(clusters)]) + if tf.math.reduce_max(ious) < 0.55: + return -1 + return tf.argmax(ious) def average_detections(detections): + """Takes a list of detections and returns the average, both in box co-ordinates and confidence.""" detections = tf.stack(detections) return [ detections[0][0], @@ -55,29 +45,27 @@ def average_detections(detections): ] -def ensemble_boxes(params, detections): - # [id, x, y, w, h, score, class] - +def ensemble_detections(params, detections): + """Ensembles a group of detections by clustering the detections and returning the average of the clusters.""" all_clusters = [] - # cluster the detections for cid in range(params['num_classes']): indices = tf.where(tf.equal(detections[:, 6], cid)) if indices.shape[0] == 0: - continue + continue class_detections = tf.gather_nd(detections, indices) - clusters = [] - cluster_averages = [] - for d in class_detections: - cluster_index = find_matching_cluster(cluster_averages, d) - if cluster_index == -1: - clusters.append([d]) - cluster_averages.append(d) - else: - clusters[cluster_index].append(d) - cluster_averages[cluster_index] = average_detections( - clusters[cluster_index]) + clusters = [[class_detections[0]]] + cluster_averages = [class_detections[0]] + for d in class_detections[1:]: + cluster_index = find_matching_cluster(cluster_averages, d) + if cluster_index == -1: + clusters.append([d]) + cluster_averages.append(d) + else: + clusters[cluster_index].append(d) + cluster_averages[cluster_index] = average_detections( + clusters[cluster_index]) all_clusters.extend(cluster_averages) diff --git a/efficientdet/keras/wbf_test.py b/efficientdet/keras/wbf_test.py index 6111c8582..f90229e17 100644 --- a/efficientdet/keras/wbf_test.py +++ b/efficientdet/keras/wbf_test.py @@ -3,117 +3,104 @@ from keras import wbf -tf.enable_eager_execution() class WbfTest(tf.test.TestCase): - - def test_detection_iou_same(self): - d1 = tf.constant([[1, 1, 1, 2, 2, 1, 1]], dtype=tf.float32) - d2 = tf.constant([[1, 1, 1, 2, 2, 1, 1]], dtype=tf.float32) - - iou = wbf.vectorized_iou(d1, d2) - self.assertAllClose(iou[0][0], 1.0) + def test_detection_iou_same(self): + d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) + d2 = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32) - def test_detection_iou_corners(self): - d1 = tf.constant([[1, 1, 1, 2, 2, 1, 1]], dtype=tf.float32) - d2 = tf.constant([[1, 2, 2, 2, 2, 1, 1]], dtype=tf.float32) - - iou = wbf.vectorized_iou(d1, d2) + iou = wbf.vectorized_iou(d1, d2) - self.assertAllClose(iou[0][0], 1.0/7.0) + self.assertAllClose(iou[0][0], 1.0) - def test_detection_iou_ends(self): - d1 = tf.constant([[1, 1, 1, 2, 1, 1, 1]], dtype=tf.float32) - d2 = tf.constant([[1, 2, 1, 2, 1, 1, 1]], dtype=tf.float32) - - iou = wbf.vectorized_iou(d1, d2) + def test_detection_iou_corners(self): + d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) + d2 = tf.constant([1, 2, 2, 4, 4, 1, 1], dtype=tf.float32) - self.assertAllClose(iou[0][0], 1.0/3.0) + iou = wbf.vectorized_iou(d1, d2) - def test_detection_iou_none(self): - d1 = tf.constant([[1, 1, 1, 2, 2, 1, 1]], dtype=tf.float32) - d2 = tf.constant([[1, 3, 3, 2, 2, 1, 1]], dtype=tf.float32) - - iou = wbf.vectorized_iou(d1, d2) + self.assertAllClose(iou[0][0], 1.0 / 7.0) - self.assertAllClose(iou[0][0], 0) + def test_detection_iou_ends(self): + d1 = tf.constant([[1, 1, 1, 3, 2, 1, 1]], dtype=tf.float32) + d2 = tf.constant([1, 2, 1, 4, 2, 1, 1], dtype=tf.float32) - def test_detection_iou_vector(self): - vector_to_match = tf.constant( - [ - [1, 1, 1, 2, 2, 1, 1], - [1, 2, 2, 2, 2, 1, 1], - [1, 3, 3, 2, 2, 1, 1], - ], - dtype=tf.float32, - ) + iou = wbf.vectorized_iou(d1, d2) - detection = tf.constant([[1, 1, 1, 2, 2, 1, 1]], dtype=tf.float32) + self.assertAllClose(iou[0][0], 1.0 / 3.0) - ious = wbf.vectorized_iou(vector_to_match, tf.tile(detection, [3, 1])) - self.assertAllClose(tf.reshape(ious, [3]), [1, 1.0/7.0, 0]) + def test_detection_iou_none(self): + d1 = tf.constant([[1, 1, 1, 3, 3, 1, 1]], dtype=tf.float32) + d2 = tf.constant([1, 3, 3, 5, 5, 1, 1], dtype=tf.float32) + iou = wbf.vectorized_iou(d1, d2) - def test_find_matching_cluster_matches(self): - matching_cluster = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) - non_matching_cluster = tf.constant([1, 3, 3, 2, 2, 1, 1], dtype=tf.float32) + self.assertAllClose(iou[0][0], 0) - box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) + def test_detection_iou_vector(self): + vector_to_match = tf.constant( + [ + [1, 1, 1, 3, 3, 1, 1], + [1, 2, 2, 4, 4, 1, 1], + [1, 3, 3, 5, 5, 1, 1], + ], + dtype=tf.float32, + ) - cluster_index = wbf.find_matching_cluster((matching_cluster, non_matching_cluster), box) + detection = tf.constant([1, 1, 1, 3, 3, 1, 1], dtype=tf.float32) - self.assertAllClose(cluster_index, 0) + ious = wbf.vectorized_iou(vector_to_match, detection) + self.assertAllClose(tf.reshape(ious, [3]), [1, 1.0 / 7.0, 0]) - cluster_index = wbf.find_matching_cluster((non_matching_cluster, matching_cluster), box) + def test_find_matching_cluster_matches(self): + matching_cluster = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) + non_matching_cluster = tf.constant([1, 3, 3, 2, 2, 1, 1], dtype=tf.float32) - self.assertAllClose(cluster_index, 1) + box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) - def test_find_matching_cluster_best_overlap(self): - overlaps = tf.constant([1, 1, 1, 10, 1, 1, 1], dtype=tf.float32) - overlaps_better = tf.constant([1, 2, 1, 10, 1, 1, 1], dtype=tf.float32) + cluster_index = wbf.find_matching_cluster( + (matching_cluster, non_matching_cluster), box) - box = tf.constant([1, 3, 1, 10, 1, 1, 1], dtype=tf.float32) + self.assertAllClose(cluster_index, 0) - cluster_index = wbf.find_matching_cluster((overlaps,), box) + cluster_index = wbf.find_matching_cluster( + (non_matching_cluster, matching_cluster), box) - self.assertAllClose(cluster_index, 0) + self.assertAllClose(cluster_index, 1) - cluster_index = wbf.find_matching_cluster((overlaps, overlaps_better), box) + def test_find_matching_cluster_best_overlap(self): + overlaps = tf.constant([1, 1, 1, 11, 2, 1, 1], dtype=tf.float32) + overlaps_better = tf.constant([1, 2, 1, 12, 2, 1, 1], dtype=tf.float32) - self.assertAllClose(cluster_index, 1) + box = tf.constant([1, 3, 1, 13, 2, 1, 1], dtype=tf.float32) + cluster_index = wbf.find_matching_cluster((overlaps,), box) - def test_average_detections(self): - d1 = tf.constant([1, 1, 1, 2, 2, 0.5, 1], dtype=tf.float32) - d2 = tf.constant([1, 3, 3, 4, 4, 1, 1], dtype=tf.float32) + self.assertAllClose(cluster_index, 0) - averaged = wbf.average_detections((d1, d2)) + cluster_index = wbf.find_matching_cluster((overlaps, overlaps_better), box) - self.assertAllClose(averaged, [1, 2, 2, 3, 3, 0.75, 1]) + self.assertAllClose(cluster_index, 1) - # def test_find_matching_cluster_class_difference(self): - # matching_class = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) - # non_matching_class = tf.constant([1, 1, 1, 2, 2, 1, 2], dtype=tf.float32) + def test_average_detections(self): + d1 = tf.constant([1, 1, 1, 2, 2, 0.5, 1], dtype=tf.float32) + d2 = tf.constant([1, 3, 3, 4, 4, 1, 1], dtype=tf.float32) - # box = tf.constant([1, 1, 1, 2, 2, 1, 1], dtype=tf.float32) + averaged = wbf.average_detections((d1, d2)) - # cluster_index = wbf.find_matching_cluster((matching_class, non_matching_class), box) + self.assertAllClose(averaged, [1, 2, 2, 3, 3, 0.75, 1]) - # self.assertAllClose(cluster_index, 0) + def test_ensemble_boxes(self): + d1 = tf.constant([1, 2, 1, 10, 1, 0.5, 1], dtype=tf.float32) + d2 = tf.constant([1, 3, 1, 10, 1, 1, 1], dtype=tf.float32) + d3 = tf.constant([1, 3, 1, 10, 1, 1, 2], dtype=tf.float32) - # cluster_index = wbf.find_matching_cluster((non_matching_class, matching_class), box) + ensembled = wbf.ensemble_detections({'num_classes': 3}, + tf.stack([d1, d2, d3])) - # self.assertAllClose(cluster_index, 1) - - def test_ensemble_boxes(self): - d1 = tf.constant([1, 2, 1, 10, 1, 0.5, 1], dtype=tf.float32) - d2 = tf.constant([1, 3, 1, 10, 1, 1, 1], dtype=tf.float32) - d3 = tf.constant([1, 3, 1, 10, 1, 1, 2], dtype=tf.float32) - - ensembled = wbf.ensemble_boxes({'num_classes': 3}, tf.stack([d1, d2, d3])) - - self.assertAllClose(ensembled, [[1, 3, 1, 10, 1, 1, 2], [1, 2.5, 1, 10, 1, 0.75, 1]]) + self.assertAllClose(ensembled, + [[1, 3, 1, 10, 1, 1, 2], [1, 2.5, 1, 10, 1, 0.75, 1]]) if __name__ == '__main__':