Skip to content

Commit

Permalink
Support batch input for median_filter2d (tensorflow#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
WindQAQ authored and Squadrick committed Jun 16, 2019
1 parent 320ad67 commit 894a074
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 167 deletions.
162 changes: 76 additions & 86 deletions tensorflow_addons/image/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,6 @@
from tensorflow_addons.utils import keras_utils


@tf.function
def _normalize(li, ma):
    """Scale `li` down by 255 when `ma` exceeds 1.0; otherwise return it as is.

    Args:
      li: The tensor to (conditionally) rescale.
      ma: A scalar that is compared against 1.0 to decide whether `li`
        needs rescaling (e.g. the maximum value found in `li`).
    Returns:
      `li / 255.0` if `ma > 1.0`, else `li` unchanged.
    """
    unit_threshold = tf.convert_to_tensor(1.0)
    byte_scale = tf.convert_to_tensor(255.0)

    exceeds_unit_range = tf.math.greater(ma, unit_threshold)

    # tf.cond takes the "true" branch first: divide only when ma > 1.0.
    return tf.cond(
        exceeds_unit_range,
        lambda: tf.math.truediv(li, byte_scale),
        lambda: li)


def _pad(image, filter_shape, mode="CONSTANT", constant_values=0):
"""Explicitly pad a 4-D image.
Expand Down Expand Up @@ -142,80 +128,84 @@ def mean_filter2d(image,


@tf.function
def median_filter2d(image,
                    filter_shape=(3, 3),
                    padding="REFLECT",
                    constant_values=0,
                    name=None):
    """Perform median filtering on image(s).

    Args:
      image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
      filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
        the height and width of the 2-D median filter. Can be a single integer
        to specify the same value for all spatial dimensions.
      padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC".
        The type of padding algorithm to use, which is compatible with
        `mode` argument in `tf.pad`. For more details, please refer to
        https://www.tensorflow.org/api_docs/python/tf/pad.
      constant_values: A `scalar`, the pad value to use in "CONSTANT"
        padding mode.
      name: A name for this operation (optional).
    Returns:
      3-D or 4-D `Tensor` of the same dtype as input.
    Raises:
      ValueError: If `image` is not 3 or 4-dimensional,
        if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
        or if `filter_shape` is invalid.
    """
    with tf.name_scope(name or "median_filter2d"):
        image = tf.convert_to_tensor(image, name="image")

        rank = image.shape.rank
        if rank != 3 and rank != 4:
            raise ValueError("image should be either 3 or 4-dimensional.")

        if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
            raise ValueError(
                "padding should be one of \"REFLECT\", \"CONSTANT\", or "
                "\"SYMMETRIC\".")

        filter_shape = keras_utils.conv_utils.normalize_tuple(
            filter_shape, 2, "filter_shape")

        # Expand to a 4-D tensor
        if rank == 3:
            image = tf.expand_dims(image, axis=0)

        # Explicitly pad the image; patches are then extracted with "VALID"
        # padding below.
        image = _pad(
            image, filter_shape, mode=padding, constant_values=constant_values)

        # 1-based ranks (in descending order) of the middle element(s) of a
        # window. They coincide when the window area is odd and are adjacent
        # when it is even.
        area = filter_shape[0] * filter_shape[1]
        floor = (area + 1) // 2
        ceil = area // 2 + 1

        def _median_filter2d_single_channel(x):
            x = tf.expand_dims(x, axis=-1)
            patches = tf.image.extract_patches(
                x,
                sizes=[1, filter_shape[0], filter_shape[1], 1],
                strides=[1, 1, 1, 1],
                rates=[1, 1, 1, 1],
                padding="VALID")

            # Only the `ceil` largest values are needed to reach the middle.
            top = tf.nn.top_k(patches, k=ceil).values

            if floor == ceil:
                # Odd window area: the median is a single element. Return it
                # directly instead of computing (m + m) / 2, which wraps
                # around for narrow integer dtypes (e.g. uint8 values > 127).
                return top[:, :, :, floor - 1]

            upper = top[:, :, :, floor - 1]
            lower = top[:, :, :, ceil - 1]
            if x.dtype.is_integer:
                # Widen before summing so the addition cannot overflow the
                # original integer dtype.
                upper = tf.cast(upper, tf.float64)
                lower = tf.cast(lower, tf.float64)

            # Note the returned median is cast back to the original type.
            # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 6.5
            # It turns out to be int(6.5) = 6 if the original type is int
            median = (upper + lower) / 2
            return tf.dtypes.cast(median, x.dtype)

        # Filter every channel independently: move channels to the leading
        # axis for map_fn, then move them back to the trailing axis.
        output = tf.map_fn(
            _median_filter2d_single_channel,
            elems=tf.transpose(image, [3, 0, 1, 2]),
            dtype=image.dtype)
        output = tf.transpose(output, [1, 2, 3, 0])

        # Squeeze out the first axis to make sure
        # output has the same dimension with image.
        if rank == 3:
            output = tf.squeeze(output, axis=0)

        return output
Loading

0 comments on commit 894a074

Please sign in to comment.