Skip to content

Commit

Permalink
Enable inference with dynamic batch size in SSD.
Browse files Browse the repository at this point in the history
* Creates a new batch_decode method in SSD Meta architecture that can handle
  dynamic batch size.
* use combined_shapes in _get_feature_maps_spatial_dims method to handle
  dynamic batch image_size.
* Add dynamic batch size tests to check preprocess, predict and postprocess
  methods in SSD Meta architecture.
  • Loading branch information
derekjchow committed Jul 18, 2017
1 parent 5d5fb7c commit 4f14cb6
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 36 deletions.
3 changes: 1 addition & 2 deletions object_detection/meta_architectures/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ py_library(
srcs = ["ssd_meta_arch.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/core:box_coder",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:model",
"//tensorflow_models/object_detection/core:target_assigner",
"//tensorflow_models/object_detection/utils:variables_helper",
"//tensorflow_models/object_detection/utils:shape_utils",
],
)

Expand Down
36 changes: 30 additions & 6 deletions object_detection/meta_architectures/ssd_meta_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
import re
import tensorflow as tf

from object_detection.core import box_coder as bcoder
from object_detection.core import box_list
from object_detection.core import box_predictor as bpredictor
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import shape_utils

slim = tf.contrib.slim

Expand Down Expand Up @@ -323,7 +323,8 @@ def _get_feature_map_spatial_dims(self, feature_maps):
a list of pairs (height, width) for each feature map in feature_maps
"""
feature_map_shapes = [
feature_map.get_shape().as_list() for feature_map in feature_maps
shape_utils.combined_static_and_dynamic_shape(
feature_map) for feature_map in feature_maps
]
return [(shape[1], shape[2]) for shape in feature_map_shapes]

Expand Down Expand Up @@ -364,8 +365,7 @@ def postprocess(self, prediction_dict):
with tf.name_scope('Postprocessor'):
box_encodings = prediction_dict['box_encodings']
class_predictions = prediction_dict['class_predictions_with_background']
detection_boxes = bcoder.batch_decode(box_encodings, self._box_coder,
self.anchors)
detection_boxes = self._batch_decode(box_encodings)
detection_boxes = tf.expand_dims(detection_boxes, axis=2)

class_predictions_without_background = tf.slice(class_predictions,
Expand Down Expand Up @@ -549,8 +549,7 @@ def _apply_hard_mining(self, location_losses, cls_losses, prediction_dict,
tf.slice(prediction_dict['class_predictions_with_background'],
[0, 0, 1], class_pred_shape), class_pred_shape)

decoded_boxes = bcoder.batch_decode(prediction_dict['box_encodings'],
self._box_coder, self.anchors)
decoded_boxes = self._batch_decode(prediction_dict['box_encodings'])
decoded_box_tensors_list = tf.unstack(decoded_boxes)
class_prediction_list = tf.unstack(class_predictions)
decoded_boxlist_list = []
Expand All @@ -565,6 +564,31 @@ def _apply_hard_mining(self, location_losses, cls_losses, prediction_dict,
decoded_boxlist_list=decoded_boxlist_list,
match_list=match_list)

def _batch_decode(self, box_encodings):
"""Decodes a batch of box encodings with respect to the anchors.
Args:
box_encodings: A float32 tensor of shape
[batch_size, num_anchors, box_code_size] containing box encodings.
Returns:
decoded_boxes: A float32 tensor of shape
[batch_size, num_anchors, 4] containing the decoded boxes.
"""
combined_shape = shape_utils.combined_static_and_dynamic_shape(
box_encodings)
batch_size = combined_shape[0]
tiled_anchor_boxes = tf.tile(
tf.expand_dims(self.anchors.get(), 0), [batch_size, 1, 1])
tiled_anchors_boxlist = box_list.BoxList(
tf.reshape(tiled_anchor_boxes, [-1, self._box_coder.code_size]))
decoded_boxes = self._box_coder.decode(
tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
tiled_anchors_boxlist)
return tf.reshape(decoded_boxes.get(),
tf.stack([combined_shape[0], combined_shape[1],
4]))

def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
Expand Down
83 changes: 58 additions & 25 deletions object_detection/meta_architectures/ssd_meta_arch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,46 @@ def image_resizer_fn(image):
localization_loss_weight, normalize_loss_by_num_matches,
hard_example_miner)

def test_preprocess_preserves_input_shapes(self):
image_shapes = [(3, None, None, 3),
(None, 10, 10, 3),
(None, None, None, 3)]
for image_shape in image_shapes:
image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
preprocessed_inputs = self._model.preprocess(image_placeholder)
self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)

def test_predict_results_have_correct_keys_and_shapes(self):
batch_size = 3
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)

self.assertTrue('box_encodings' in prediction_dict)
self.assertTrue('class_predictions_with_background' in prediction_dict)
self.assertTrue('feature_maps' in prediction_dict)

image_size = 2
input_shapes = [(batch_size, image_size, image_size, 3),
(None, image_size, image_size, 3),
(batch_size, None, None, 3),
(None, None, None, 3)]
expected_box_encodings_shape_out = (
batch_size, self._num_anchors, self._code_size)
expected_class_predictions_with_background_shape_out = (
batch_size, self._num_anchors, self._num_classes+1)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict)

for input_shape in input_shapes:
tf_graph = tf.Graph()
with tf_graph.as_default():
preprocessed_input_placeholder = tf.placeholder(tf.float32,
shape=input_shape)
prediction_dict = self._model.predict(preprocessed_input_placeholder)

self.assertTrue('box_encodings' in prediction_dict)
self.assertTrue('class_predictions_with_background' in prediction_dict)
self.assertTrue('feature_maps' in prediction_dict)

init_op = tf.global_variables_initializer()
with self.test_session(graph=tf_graph) as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict,
feed_dict={
preprocessed_input_placeholder:
np.random.uniform(
size=(batch_size, 2, 2, 3))})
self.assertAllEqual(prediction_out['box_encodings'].shape,
expected_box_encodings_shape_out)
self.assertAllEqual(
Expand All @@ -142,10 +164,11 @@ def test_predict_results_have_correct_keys_and_shapes(self):

def test_postprocess_results_are_correct(self):
batch_size = 2
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)
detections = self._model.postprocess(prediction_dict)
image_size = 2
input_shapes = [(batch_size, image_size, image_size, 3),
(None, image_size, image_size, 3),
(batch_size, None, None, 3),
(None, None, None, 3)]

expected_boxes = np.array([[[0, 0, .5, .5],
[0, .5, .5, 1],
Expand All @@ -163,15 +186,25 @@ def test_postprocess_results_are_correct(self):
[0, 0, 0, 0, 0]])
expected_num_detections = np.array([4, 4])

self.assertTrue('detection_boxes' in detections)
self.assertTrue('detection_scores' in detections)
self.assertTrue('detection_classes' in detections)
self.assertTrue('num_detections' in detections)

init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
detections_out = sess.run(detections)
for input_shape in input_shapes:
tf_graph = tf.Graph()
with tf_graph.as_default():
preprocessed_input_placeholder = tf.placeholder(tf.float32,
shape=input_shape)
prediction_dict = self._model.predict(preprocessed_input_placeholder)
detections = self._model.postprocess(prediction_dict)
self.assertTrue('detection_boxes' in detections)
self.assertTrue('detection_scores' in detections)
self.assertTrue('detection_classes' in detections)
self.assertTrue('num_detections' in detections)
init_op = tf.global_variables_initializer()
with self.test_session(graph=tf_graph) as sess:
sess.run(init_op)
detections_out = sess.run(detections,
feed_dict={
preprocessed_input_placeholder:
np.random.uniform(
size=(batch_size, 2, 2, 3))})
self.assertAllClose(detections_out['detection_boxes'], expected_boxes)
self.assertAllClose(detections_out['detection_scores'], expected_scores)
self.assertAllClose(detections_out['detection_classes'], expected_classes)
Expand Down
1 change: 1 addition & 0 deletions object_detection/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ py_library(
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:matcher",
"//tensorflow_models/object_detection/utils:shape_utils"
],
)

Expand Down
8 changes: 5 additions & 3 deletions object_detection/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from object_detection.core import box_list
from object_detection.core import box_predictor
from object_detection.core import matcher
from object_detection.utils import shape_utils


class MockBoxCoder(box_coder.BoxCoder):
Expand All @@ -45,9 +46,10 @@ def __init__(self, is_training, num_classes):
super(MockBoxPredictor, self).__init__(is_training, num_classes)

def _predict(self, image_features, num_predictions_per_location):
batch_size = image_features.get_shape().as_list()[0]
num_anchors = (image_features.get_shape().as_list()[1]
* image_features.get_shape().as_list()[2])
combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
image_features)
batch_size = combined_feature_shape[0]
num_anchors = (combined_feature_shape[1] * combined_feature_shape[2])
code_size = 4
zero = tf.reduce_sum(0 * image_features)
box_encodings = zero + tf.zeros(
Expand Down

0 comments on commit 4f14cb6

Please sign in to comment.