From 894a074bfaa61a01107d0f5b579df5b00c95aa3d Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 17 Jun 2019 00:57:56 +0800 Subject: [PATCH] Support batch input for median_filter2d (#288) --- tensorflow_addons/image/filters.py | 162 +++++++++---------- tensorflow_addons/image/filters_test.py | 198 ++++++++++++++---------- 2 files changed, 193 insertions(+), 167 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index b567d538b6..503d4657c9 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -21,20 +21,6 @@ from tensorflow_addons.utils import keras_utils -@tf.function -def _normalize(li, ma): - one = tf.convert_to_tensor(1.0) - two = tf.convert_to_tensor(255.0) - - def func1(): - return li - - def func2(): - return tf.math.truediv(li, two) - - return tf.cond(tf.math.greater(ma, one), func2, func1) - - def _pad(image, filter_shape, mode="CONSTANT", constant_values=0): """Explicitly pad a 4-D image. @@ -142,80 +128,84 @@ def mean_filter2d(image, @tf.function -def median_filter2d(image, filter_shape=(3, 3), name=None): - """This method performs Median Filtering on image. Filter shape can be user - given. +def median_filter2d(image, + filter_shape=(3, 3), + padding="REFLECT", + constant_values=0, + name=None): + """Perform median filtering on image(s). - This method takes both kind of images where pixel values lie between 0 to - 255 and where it lies between 0.0 and 1.0 Args: - image: A 3D `Tensor` of type `float32` or 'int32' or 'float64' or - 'int64 and of shape`[rows, columns, channels]` - - filter_shape: Optional Argument. A tuple of 2 integers (R,C). - R is the first value is the number of rows in the filter and - C is the second value in the filter is the number of columns - in the filter. This creates a filter of shape (R,C) or RxC - filter. Default value = (3,3) - name: The name of the op. - - Returns: - A 3D median filtered image tensor of shape [rows,columns,channels] and - type 'int32'. Pixel value of returned tensor ranges between 0 to 255 + image: Either a 3-D `Tensor` of shape `[height, width, channels]`, + or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`. + filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying + the height and width of the 2-D median filter. Can be a single integer + to specify the same value for all spatial dimensions. + padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". + The type of padding algorithm to use, which is compatible with + `mode` argument in `tf.pad`. For more details, please refer to + https://www.tensorflow.org/api_docs/python/tf/pad. + constant_values: A `scalar`, the pad value to use in "CONSTANT" + padding mode. + name: A name for this operation (optional). + Returns: + 3-D or 4-D `Tensor` of the same dtype as input. + Raises: + ValueError: If `image` is not 3 or 4-dimensional, + if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC", + or if `filter_shape` is invalid. """ - with tf.name_scope(name or "median_filter2d"): - if not isinstance(filter_shape, tuple): - raise TypeError('Filter shape must be a tuple') - if len(filter_shape) != 2: - raise ValueError('Filter shape must be a tuple of 2 integers. ' - 'Got %s values in tuple' % len(filter_shape)) - filter_shapex = filter_shape[0] - filter_shapey = filter_shape[1] - if not isinstance(filter_shapex, int) or not isinstance( - filter_shapey, int): - raise TypeError('Size of the filter must be Integers') - (row, col, ch) = (image.shape[0], image.shape[1], image.shape[2]) - if row != None and col != None and ch != None: - (row, col, ch) = (int(row), int(col), int(ch)) - else: - raise TypeError('All the Dimensions of the input image ' - 'tensor must be Integers.') - if row < filter_shapex or col < filter_shapey: + image = tf.convert_to_tensor(image, name="image") + + rank = image.shape.rank + if rank != 3 and rank != 4: + raise ValueError("image should be either 3 or 4-dimensional.") + + if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]: raise ValueError( - 'Number of Pixels in each dimension of the image should be \ - more than the filter size. Got filter_shape (%sx' % - filter_shape[0] + '%s).' % filter_shape[1] + - ' Image Shape (%s)' % image.shape) - if filter_shapex % 2 == 0 or filter_shapey % 2 == 0: - raise ValueError('Filter size should be odd. Got filter_shape ' - '(%sx%s)' % (filter_shape[0], filter_shape[1])) - image = tf.cast(image, tf.float32) - tf_i = tf.reshape(image, [row * col * ch]) - ma = tf.math.reduce_max(tf_i) - image = _normalize(image, ma) - - # k and l is the Zero-padding size - - listi = [] - for a in range(ch): - img = image[:, :, a:a + 1] - img = tf.reshape(img, [1, row, col, 1]) - slic = tf.image.extract_patches( - img, [1, filter_shapex, filter_shapey, 1], [1, 1, 1, 1], - [1, 1, 1, 1], - padding='SAME') - mid = int(filter_shapex * filter_shapey / 2 + 1) - top = tf.nn.top_k(slic, mid, sorted=True) - li = tf.slice(top[0], [0, 0, 0, mid - 1], [-1, -1, -1, 1]) - li = tf.reshape(li, [row, col, 1]) - listi.append(li) - y = tf.concat(listi[0], 2) - - for i in range(len(listi) - 1): - y = tf.concat([y, listi[i + 1]], 2) - - y *= 255 - y = tf.cast(y, tf.int32) - - return y + "padding should be one of \"REFLECT\", \"CONSTANT\", or " + "\"SYMMETRIC\".") + + filter_shape = keras_utils.conv_utils.normalize_tuple( + filter_shape, 2, "filter_shape") + + # Expand to a 4-D tensor + if rank == 3: + image = tf.expand_dims(image, axis=0) + + # Explicitly pad the image + image = _pad( + image, filter_shape, mode=padding, constant_values=constant_values) + + floor = (filter_shape[0] * filter_shape[1] + 1) // 2 + ceil = (filter_shape[0] * filter_shape[1]) // 2 + 1 + + def _median_filter2d_single_channel(x): + x = tf.expand_dims(x, axis=-1) + patches = tf.image.extract_patches( + x, + sizes=[1, filter_shape[0], filter_shape[1], 1], + strides=[1, 1, 1, 1], + rates=[1, 1, 1, 1], + padding="VALID") + + # Note the returned median is casted back to the original type + # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 3.5 + # It turns out to be int(6.5) = 6 if the original type is int + top = tf.nn.top_k(patches, k=ceil).values + median = (top[:, :, :, floor - 1] + top[:, :, :, ceil - 1]) / 2 + return tf.dtypes.cast(median, x.dtype) + + output = tf.map_fn( + _median_filter2d_single_channel, + elems=tf.transpose(image, [3, 0, 1, 2]), + dtype=image.dtype) + output = tf.transpose(output, [1, 2, 3, 0]) + + # Squeeze out the first axis to make sure + # output has the same dimension with image. + if rank == 3: + output = tf.squeeze(output, axis=0) + + return output diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index 8e8e29011e..2121b89b1e 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -23,8 +23,16 @@ from tensorflow_addons.utils import test_utils -@test_utils.run_all_in_graph_and_eager_modes -class MeanFilter2dTest(tf.test.TestCase): +class _Filter2dTest(tf.test.TestCase): + def setUp(self): + self._dtypes_to_test = [ + tf.dtypes.uint8, tf.dtypes.int32, tf.dtypes.float16, + tf.dtypes.float32, tf.dtypes.float64 + ] + self._image_shapes_to_test = [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), + (1, 3, 3, 3), (2, 3, 3, 1), (2, 3, 3, 3)] + super(_Filter2dTest, self).setUp() + def _tile_image(self, plane, image_shape): """Tile a 2-D image `plane` into 3-D or 4-D as per `image_shape`.""" assert 3 <= len(image_shape) <= 4 @@ -49,7 +57,7 @@ def _setup_values(self, image_shape, filter_shape, padding, dtype=dtype) image = self._tile_image(plane, image_shape=image_shape) - result = mean_filter2d( + result = self._filter2d_fn( image, filter_shape=filter_shape, padding=padding, @@ -59,17 +67,20 @@ def _setup_values(self, image_shape, filter_shape, padding, def _verify_values(self, image_shape, filter_shape, padding, constant_values, expected_plane): - expected_output = self._tile_image(expected_plane, image_shape) - dtypes = tf.dtypes - for dtype in [ - dtypes.uint8, dtypes.float16, dtypes.float32, dtypes.float64 - ]: + for dtype in self._dtypes_to_test: result = self._setup_values(image_shape, filter_shape, padding, constant_values, dtype) self.assertAllCloseAccordingToType( result, tf.dtypes.cast(expected_output, dtype)) + +@test_utils.run_all_in_graph_and_eager_modes +class MeanFilter2dTest(_Filter2dTest): + def setUp(self): + self._filter2d_fn = mean_filter2d + super(MeanFilter2dTest, self).setUp() + def test_invalid_image(self): msg = "image should be either 3 or 4-dimensional." @@ -113,8 +124,7 @@ def test_reflect_padding_with_3x3_filter(self): [42. / 9., 45. / 9., 48. / 9.], [51. / 9., 54. / 9., 57. / 9.]]) - for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), - (2, 3, 3, 1), (2, 3, 3, 3)]: + for image_shape in self._image_shapes_to_test: self._verify_values( image_shape=image_shape, filter_shape=(3, 3), @@ -127,8 +137,7 @@ def test_reflect_padding_with_4x4_filter(self): [80. / 16., 80. / 16., 80. / 16.], [80. / 16., 80. / 16., 80. / 16.]]) - for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), - (2, 3, 3, 1), (2, 3, 3, 3)]: + for image_shape in self._image_shapes_to_test: self._verify_values( image_shape=image_shape, filter_shape=(4, 4), @@ -141,8 +150,7 @@ def test_constant_padding_with_3x3_filter(self): [27. / 9., 45. / 9., 33. / 9.], [24. / 9., 39. / 9., 28. / 9.]]) - for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), - (2, 3, 3, 1), (2, 3, 3, 3)]: + for image_shape in self._image_shapes_to_test: self._verify_values( image_shape=image_shape, filter_shape=(3, 3), @@ -154,8 +162,7 @@ def test_constant_padding_with_3x3_filter(self): [30. / 9., 45. / 9., 36. / 9.], [29. / 9., 42. / 9., 33. / 9.]]) - for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), - (2, 3, 3, 1), (2, 3, 3, 3)]: + for image_shape in self._image_shapes_to_test: self._verify_values( image_shape=image_shape, filter_shape=(3, 3), @@ -168,8 +175,7 @@ def test_symmetric_padding_with_3x3_filter(self): [39. / 9., 45. / 9., 51. / 9.], [57. / 9., 63. / 9., 69. / 9.]]) - for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), - (2, 3, 3, 1), (2, 3, 3, 3)]: + for image_shape in self._image_shapes_to_test: self._verify_values( image_shape=image_shape, filter_shape=(3, 3), @@ -178,73 +184,103 @@ def test_symmetric_padding_with_3x3_filter(self): expected_plane=expected_plane) -class MedianFilter2dTest(tf.test.TestCase): - def _validate_median_filter2d(self, - inputs, - expected_values, - filter_shape=(3, 3)): - output = median_filter2d(inputs, filter_shape) - self.assertAllClose(output, expected_values) +@test_utils.run_all_in_graph_and_eager_modes +class MedianFilter2dTest(_Filter2dTest): + def setUp(self): + self._filter2d_fn = median_filter2d + super(MedianFilter2dTest, self).setUp() + + def test_invalid_image(self): + msg = "image should be either 3 or 4-dimensional." + + for image_shape in [(28, 28), (16, 28, 28, 1, 1)]: + with self.subTest(dim=len(image_shape)): + with self.assertRaisesRegexp(ValueError, msg): + median_filter2d(tf.ones(shape=image_shape)) + + def test_invalid_filter_shape(self): + msg = ("The `filter_shape` argument must be a tuple of 2 integers.") + image = tf.ones(shape=(1, 28, 28, 1)) - @test_utils.run_in_graph_and_eager_modes - def test_filter_tuple(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) + for filter_shape in [(3, 3, 3), (3, None, 3), None]: + with self.subTest(filter_shape=filter_shape): + with self.assertRaisesRegexp(ValueError, msg): + median_filter2d(image, filter_shape=filter_shape) - for filter_shape in [3, 3.5, 'dt', None]: - with self.assertRaisesRegexp(TypeError, - 'Filter shape must be a tuple'): - median_filter2d(tf_img, filter_shape) + def test_invalid_padding(self): + msg = ("padding should be one of \"REFLECT\", \"CONSTANT\", " + "or \"SYMMETRIC\".") + image = tf.ones(shape=(1, 28, 28, 1)) - filter_shape = (3, 3, 3) - msg = ('Filter shape must be a tuple of 2 integers. ' - 'Got %s values in tuple' % len(filter_shape)) with self.assertRaisesRegexp(ValueError, msg): - median_filter2d(tf_img, filter_shape) - - msg = 'Size of the filter must be Integers' - for filter_shape in [(3.5, 3), (None, 3)]: - with self.assertRaisesRegexp(TypeError, msg): - median_filter2d(tf_img, filter_shape) - - @test_utils.run_in_graph_and_eager_modes - def test_filter_value(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) - - with self.assertRaises(ValueError): - median_filter2d(tf_img, (4, 3)) - - @test_utils.run_deprecated_v1 - def test_dimension(self): - for image_shape in [(3, 4, None), (3, None, 4), (None, 3, 4)]: - with self.assertRaises(TypeError): - tf_img = tf.compat.v1.placeholder(tf.int32, shape=image_shape) - median_filter2d(tf_img) - - @test_utils.run_in_graph_and_eager_modes - def test_image_vs_filter(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) - filter_shape = (3, 5) - with self.assertRaises(ValueError): - median_filter2d(tf_img, filter_shape) - - @test_utils.run_in_graph_and_eager_modes - def test_three_channels(self): - tf_img = [[[0.32801723, 0.08863795, 0.79119259], - [0.35526001, 0.79388736, 0.55435993], - [0.11607035, 0.55673079, 0.99473371]], - [[0.53240645, 0.74684819, 0.33700031], - [0.01760473, 0.28181609, 0.9751476], - [0.01605137, 0.8292904, 0.56405609]], - [[0.57215374, 0.10155051, 0.64836128], - [0.36533048, 0.91401874, 0.02524159], - [0.56379134, 0.9028874, 0.19505117]]] - - tf_img = tf.convert_to_tensor(value=tf_img) - expt = [[[0, 0, 0], [4, 71, 141], [0, 0, 0]], - [[83, 25, 85], [90, 190, 143], [4, 141, 49]], - [[0, 0, 0], [4, 71, 49], [0, 0, 0]]] - expt = tf.convert_to_tensor(value=expt) - self._validate_median_filter2d(tf_img, expt) + median_filter2d(image, padding="TEST") + + def test_none_channels(self): + # 3-D image + fn = median_filter2d.get_concrete_function( + tf.TensorSpec(dtype=tf.dtypes.float32, shape=(3, 3, None))) + fn(tf.ones(shape=(3, 3, 1))) + fn(tf.ones(shape=(3, 3, 3))) + + # 4-D image + fn = median_filter2d.get_concrete_function( + tf.TensorSpec(dtype=tf.dtypes.float32, shape=(1, 3, 3, None))) + fn(tf.ones(shape=(1, 3, 3, 1))) + fn(tf.ones(shape=(1, 3, 3, 3))) + + def test_reflect_padding_with_3x3_filter(self): + expected_plane = tf.constant([[4, 4, 5], [5, 5, 5], [5, 6, 6]]) + + for image_shape in self._image_shapes_to_test: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="REFLECT", + constant_values=0, + expected_plane=expected_plane) + + def test_reflect_padding_with_4x4_filter(self): + expected_plane = tf.constant([[5, 5, 5], [5, 5, 5], [5, 5, 5]]) + + for image_shape in self._image_shapes_to_test: + self._verify_values( + image_shape=image_shape, + filter_shape=(4, 4), + padding="REFLECT", + constant_values=0, + expected_plane=expected_plane) + + def test_constant_padding_with_3x3_filter(self): + expected_plane = tf.constant([[0, 2, 0], [2, 5, 3], [0, 5, 0]]) + + for image_shape in self._image_shapes_to_test: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="CONSTANT", + constant_values=0, + expected_plane=expected_plane) + + expected_plane = tf.constant([[1, 2, 1], [2, 5, 3], [1, 5, 1]]) + + for image_shape in self._image_shapes_to_test: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="CONSTANT", + constant_values=1, + expected_plane=expected_plane) + + def test_symmetric_padding_with_3x3_filter(self): + expected_plane = tf.constant([[2, 3, 3], [4, 5, 6], [7, 7, 8]]) + + for image_shape in self._image_shapes_to_test: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="SYMMETRIC", + constant_values=0, + expected_plane=expected_plane) if __name__ == "__main__":