From f7bf204cb361aaf238767c701ebb010b7f0ba986 Mon Sep 17 00:00:00 2001 From: Dan Ringwalt Date: Fri, 10 Aug 2018 12:12:03 -0700 Subject: [PATCH] BEGIN_PUBLIC Allow a different output shape from the input in tf.contrib.image.transform (#17011). END_PUBLIC RELNOTES: Allow a different output shape from the input in tf.contrib.image.transform. Thanks qyu@ for making the original change and fixing a few other prior issues! Automated rollback of commit 07fdb697d33478d7a72d09fc2371fa834e870b83 PiperOrigin-RevId: 208248183 --- tensorflow/contrib/image/kernels/image_ops.cc | 33 ++++++++--- tensorflow/contrib/image/kernels/image_ops.h | 2 +- tensorflow/contrib/image/ops/image_ops.cc | 57 ++++++++++++++++--- .../python/kernel_tests/image_ops_test.py | 44 ++++++++++++++ .../contrib/image/python/ops/image_ops.py | 52 +++++++++++------ 5 files changed, 156 insertions(+), 32 deletions(-) diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 022e17d13963a1..693724b45751b8 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -71,6 +71,7 @@ class ImageProjectiveTransform : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& images_t = ctx->input(0); const Tensor& transform_t = ctx->input(1); + const Tensor& shape_t = ctx->input(2); OP_REQUIRES(ctx, images_t.shape().dims() == 4, errors::InvalidArgument("Input images must have rank 4")); OP_REQUIRES(ctx, @@ -81,11 +82,28 @@ class ImageProjectiveTransform : public OpKernel { ProjectiveGenerator::kNumParameters), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); - auto images = images_t.tensor(); - auto transform = transform_t.matrix(); + OP_REQUIRES(ctx, shape_t.dims() == 1, + errors::InvalidArgument("output shape must be 1-dimensional", + shape_t.shape().DebugString())); + OP_REQUIRES(ctx, shape_t.NumElements() == 2, + errors::InvalidArgument("output shape must 
have two elements", + shape_t.shape().DebugString())); + auto shape_vec = shape_t.vec(); + int32 out_height = shape_vec(0); + int32 out_width = shape_vec(1); + OP_REQUIRES(ctx, out_height > 0 && out_width > 0, + errors::InvalidArgument("output dimensions must be positive")); + Tensor* output_t; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, + TensorShape({images_t.dim_size(0), out_height, + out_width, images_t.dim_size(3)}), + &output_t)); auto output = output_t->tensor(); + auto images = images_t.tensor(); + auto transform = transform_t.matrix(); + (FillProjectiveTransform(interpolation_))( ctx->eigen_device(), &output, images, transform); } @@ -129,10 +147,11 @@ TF_CALL_double(DECLARE_FUNCTOR); } // end namespace functor -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("dtype"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("output_shape"), \ ImageProjectiveTransform) TF_CALL_uint8(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index 209aa24548443b..6b63eed1303acc 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -167,7 +167,7 @@ struct FillProjectiveTransform { void operator()(const Device& device, OutputType* output, const InputType& images, const TransformsType& transform) const { - output->device(device) = images.generate( + output->device(device) = output->generate( ProjectiveGenerator(images, transform, interpolation_)); } }; diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index e59f1bf8443732..4969ac58f96c8c 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ 
-19,23 +19,66 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +namespace { + +// Sets output[0] to shape [batch_dim,height,width,channel_dim], where +// height and width come from the size_tensor. +Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, + int size_input_idx, DimensionHandle channel_dim) { + // Verify shape of size input. + ShapeHandle size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size)); + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused)); + + // Get size values from the size tensor. + const Tensor* size_tensor = c->input_tensor(size_input_idx); + DimensionHandle width; + DimensionHandle height; + if (size_tensor == nullptr) { + width = c->UnknownDim(); + height = c->UnknownDim(); + } else { + // TODO(petewarden) - Remove once we have constant evaluation in C++ only. + if (size_tensor->dtype() != DT_INT32) { + return errors::InvalidArgument( + "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 " + "but got ", + DataTypeString(size_tensor->dtype()), " for input #", size_input_idx, + " in ", c->DebugString()); + } + auto vec = size_tensor->vec(); + height = c->MakeDim(vec(0)); + width = c->MakeDim(vec(1)); + } + c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim})); + return Status::OK(); +} + +// TODO(qyu): Move this to core/framework/common_shape_fns.h +Status ResizeShapeFn(InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */, + c->Dim(input, 3)); +} + +} // namespace + // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc. // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). -// TODO(ringwalt): Add an "output_shape" argument. 
This is sufficient to -// implement "same" and "valid" modes in the Python function. REGISTER_OP("ImageProjectiveTransform") .Input("images: dtype") .Input("transforms: float32") + .Input("output_shape: int32") .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") .Output("transformed_images: dtype") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); - }) + .SetShapeFn(ResizeShapeFn) .Doc(R"doc( Applies the given transform to each of the images. @@ -49,7 +92,7 @@ If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the *output* point `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. If the transformed point lays outside of the input -image, the output pixel is set to 0. The output is the same size as the input, +image, the output pixel is set to 0. images: 4D `Tensor`, input image(s) in NHWC format. transforms: 2D `Tensor`, projective transform(s) to apply to the image(s). 
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 62a22dcf3411fb..f588eae923f403 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest _DTYPES = set( @@ -194,6 +195,19 @@ def test_bilinear_uint8(self): [0.0, 149, 233, 149, 0.0], [0.0, 0.0, 87., 0.0, 0.0]]) + def test_rotate_static_shape(self): + image = array_ops.diag([1., 2., 3.]) + result = image_ops.rotate( + image, random_ops.random_uniform((), -1, 1), interpolation="BILINEAR") + self.assertEqual(image.get_shape(), result.get_shape()) + + def test_transform_static_output_shape(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + result = image_ops.transform( + image, random_ops.random_uniform([8], -1, 1), + output_shape=constant_op.constant([3, 5])) + self.assertAllEqual([3, 5], result.get_shape()) + def _test_grad(self, shape_to_test): with self.test_session(): test_image_shape = shape_to_test @@ -213,10 +227,40 @@ def _test_grad(self, shape_to_test): x_init_value=test_image) self.assertLess(left_err, 1e-10) + def _test_grad_different_shape(self, input_shape, output_shape): + with self.test_session(): + test_image_shape = input_shape + test_image = np.random.randn(*test_image_shape) + test_image_tensor = constant_op.constant( + test_image, shape=test_image_shape) + test_transform = image_ops.angles_to_projective_transforms( + np.pi / 2, 4, 4) + + if len(output_shape) == 2: + resize_shape = output_shape + elif len(output_shape) == 3: + resize_shape = output_shape[0:2] + elif len(output_shape) == 4: + resize_shape = output_shape[1:3] + output = image_ops.transform( + 
images=test_image_tensor, + transforms=test_transform, + output_shape=resize_shape) + left_err = gradient_checker.compute_gradient_error( + test_image_tensor, + test_image_shape, + output, + output_shape, + x_init_value=test_image) + self.assertLess(left_err, 1e-10) + def test_grad(self): self._test_grad([16, 16]) self._test_grad([4, 12, 12]) self._test_grad([3, 4, 12, 12]) + self._test_grad_different_shape([16, 16], [8, 8]) + self._test_grad_different_shape([4, 12, 3], [8, 24, 3]) + self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3]) class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 86b0ffe9a0f223..e7a09041adb339 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops @@ -40,6 +41,9 @@ ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) +# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name +# used by PIL, maybe more readable) mode, which determines the correct +# output_shape and translation for the transform. def rotate(images, angles, interpolation="NEAREST", name=None): """Rotate image(s) counterclockwise by the passed angle(s) in radians. @@ -213,7 +217,11 @@ def translations_to_projective_transforms(translations, name=None): axis=1) -def transform(images, transforms, interpolation="NEAREST", name=None): +def transform(images, + transforms, + interpolation="NEAREST", + output_shape=None, + name=None): """Applies the given transform(s) to the image(s). 
Args: @@ -230,6 +238,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None): the transform mapping input points to output points. Note that gradients are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + output_shape: Output dimension after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. Returns: Image(s) with the same type and shape as `images`, with the given @@ -238,6 +250,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None): Raises: TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. """ with ops.name_scope(name, "transform"): image_or_images = ops.convert_to_tensor(images, name="images") @@ -256,6 +269,17 @@ def transform(images, transforms, interpolation="NEAREST", name=None): else: raise TypeError("Images should have rank between 2 and 4.") + if output_shape is None: + output_shape = tensor_util.constant_value( + array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + + output_shape = ops.convert_to_tensor( + output_shape, dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif transform_or_transforms.get_shape().ndims is None: @@ -265,8 +289,12 @@ def transform(images, transforms, interpolation="NEAREST", name=None): transforms = transform_or_transforms else: raise TypeError("Transforms should have rank 1 or 2.") + output = gen_image_ops.image_projective_transform( - images, transforms, interpolation=interpolation.upper()) + images, + output_shape=output_shape, + transforms=transforms, + interpolation=interpolation.upper()) if len(image_or_images.get_shape()) == 2: return output[0, :, :, 
0] elif len(image_or_images.get_shape()) == 3: @@ -376,14 +404,6 @@ def _image_projective_transform_grad(op, grad): if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4") if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif len(transform_or_transforms.get_shape()) == 2: @@ -396,13 +416,11 @@ def _image_projective_transform_grad(op, grad): inverse = linalg_ops.matrix_inverse(transforms) transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( - grad, transforms, interpolation=interpolation) - if len(image_or_images.get_shape()) == 2: - return [output[0, :, :, 0], None] - elif len(image_or_images.get_shape()) == 3: - return [output[0, :, :, :], None] - else: - return [output, None] + images=grad, + transforms=transforms, + output_shape=array_ops.shape(image_or_images)[1:3], + interpolation=interpolation) + return [output, None, None] def bipartite_match(distance_mat,