BEGIN_PUBLIC
Allow a different output shape from the input in tf.contrib.image.transform (tensorflow#17011).
END_PUBLIC

RELNOTES: Allow a different output shape from the input in tf.contrib.image.transform.

Thanks qyu@ for making the original change and fixing a few other prior issues!

Automated rollback of commit 07fdb69

PiperOrigin-RevId: 208248183
ringw authored and tensorflower-gardener committed Aug 10, 2018
1 parent d2bec06 commit f7bf204
Showing 5 changed files with 156 additions and 32 deletions.
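A minimal usage sketch of the feature this commit adds (editor's illustration, assuming a TF 1.x build with tf.contrib available; the image and transform values are arbitrary):

import tensorflow as tf

image = tf.zeros([1, 16, 16, 3])                                     # NHWC input batch
angle = 0.5                                                          # radians
t = tf.contrib.image.angles_to_projective_transforms(angle, 16, 16)
# Previously the output always matched the input height/width; with this
# change a different [height, width] can be requested explicitly.
warped = tf.contrib.image.transform(
    image, t, interpolation="BILINEAR", output_shape=[8, 24])
with tf.Session() as sess:
  print(sess.run(warped).shape)  # expected: (1, 8, 24, 3)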
33 changes: 26 additions & 7 deletions tensorflow/contrib/image/kernels/image_ops.cc
@@ -71,6 +71,7 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -81,11 +82,28 @@ class ImageProjectiveTransform : public OpKernel {
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();
OP_REQUIRES(ctx, shape_t.dims() == 1,
errors::InvalidArgument("output shape must be 1-dimensional",
shape_t.shape().DebugString()));
OP_REQUIRES(ctx, shape_t.NumElements() == 2,
errors::InvalidArgument("output shape must have two elements",
shape_t.shape().DebugString()));
auto shape_vec = shape_t.vec<int32>();
int32 out_height = shape_vec(0);
int32 out_width = shape_vec(1);
OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
errors::InvalidArgument("output dimensions must be positive"));

Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
OP_REQUIRES_OK(ctx, ctx->allocate_output(
0,
TensorShape({images_t.dim_size(0), out_height,
out_width, images_t.dim_size(3)}),
&output_t));
auto output = output_t->tensor<T, 4>();
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();

(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
}
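Taken together with the Python wrapper further down, these checks constrain `output_shape` to a rank-1, two-element, strictly positive `[height, width]`. A hedged sketch of what is accepted and rejected (editor's illustration, TF 1.x):

import tensorflow as tf

image = tf.zeros([1, 4, 4, 1])
flat = [1., 0., 0., 0., 1., 0., 0., 0.]  # 8-parameter identity transform

ok = tf.contrib.image.transform(image, flat, output_shape=[2, 6])   # accepted
# tf.contrib.image.transform(image, flat, output_shape=[[2, 6]])    # not 1-D: rejected
# tf.contrib.image.transform(image, flat, output_shape=[2, 6, 1])   # not two elements: rejected
# tf.contrib.image.transform(image, flat, output_shape=[0, 6])      # non-positive: fails when the kernel runs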
@@ -129,10 +147,11 @@ TF_CALL_double(DECLARE_FUNCTOR);

} // end namespace functor

#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("dtype"), \
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("dtype") \
.HostMemory("output_shape"), \
ImageProjectiveTransform<GPUDevice, TYPE>)

TF_CALL_uint8(REGISTER);
2 changes: 1 addition & 1 deletion tensorflow/contrib/image/kernels/image_ops.h
@@ -167,7 +167,7 @@ struct FillProjectiveTransform {
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
output->device(device) = images.generate(
output->device(device) = output->generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
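This one-line change is what lets the output extent differ from the input: the Eigen generator is now evaluated over the output tensor's coordinates instead of the input's. A rough NumPy analogue of that evaluation, assuming nearest-neighbor sampling (editor's sketch, not the actual kernel code):

import numpy as np

def projective_warp_nearest(image, t, out_height, out_width):
  """Nearest-neighbor warp over a 2-D image; loops over *output* pixels,
  so the output extent need not match the input's."""
  out = np.zeros((out_height, out_width), dtype=image.dtype)
  a0, a1, a2, b0, b1, b2, c0, c1 = t
  for y in range(out_height):
    for x in range(out_width):
      k = c0 * x + c1 * y + 1.0
      ix = int(round((a0 * x + a1 * y + a2) / k))  # input column
      iy = int(round((b0 * x + b1 * y + b2) / k))  # input row
      if 0 <= iy < image.shape[0] and 0 <= ix < image.shape[1]:
        out[y, x] = image[iy, ix]  # points outside the input stay 0
  return out

img = np.arange(144, dtype=np.float32).reshape(12, 12)
print(projective_warp_nearest(img, [1, 0, 0, 0, 1, 0, 0, 0], 8, 24).shape)  # (8, 24)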
57 changes: 50 additions & 7 deletions tensorflow/contrib/image/ops/image_ops.cc
@@ -19,23 +19,66 @@ limitations under the License.

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
// height and width come from the size_tensor.
Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
int size_input_idx, DimensionHandle channel_dim) {
// Verify shape of size input.
ShapeHandle size;
TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));

// Get size values from the size tensor.
const Tensor* size_tensor = c->input_tensor(size_input_idx);
DimensionHandle width;
DimensionHandle height;
if (size_tensor == nullptr) {
width = c->UnknownDim();
height = c->UnknownDim();
} else {
// TODO(petewarden) - Remove once we have constant evaluation in C++ only.
if (size_tensor->dtype() != DT_INT32) {
return errors::InvalidArgument(
"Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
"but got ",
DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
" in ", c->DebugString());
}
auto vec = size_tensor->vec<int32>();
height = c->MakeDim(vec(0));
width = c->MakeDim(vec(1));
}
c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
return Status::OK();
}

// TODO(qyu): Move this to core/framework/common_shape_fns.h
Status ResizeShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
c->Dim(input, 3));
}

} // namespace
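At graph-construction time this shape function yields static output dimensions whenever `output_shape` is a constant and unknown dimensions otherwise; a brief sketch (editor's illustration, TF 1.x, mirroring the static-shape test added below):

import tensorflow as tf

image = tf.zeros([2, 2])
t = tf.random_uniform([8], -1, 1)

static = tf.contrib.image.transform(image, t, output_shape=tf.constant([3, 5]))
print(static.get_shape())   # (3, 5): inferred from the constant output_shape

dynamic = tf.contrib.image.transform(
    image, t, output_shape=tf.placeholder(tf.int32, [2]))
print(dynamic.get_shape())  # (?, ?): falls back to unknown dimensions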

// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
.Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.SetShapeFn(ResizeShapeFn)
.Doc(R"doc(
Applies the given transform to each of the images.
@@ -49,7 +92,7 @@ If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
the *output* point `(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
image, the output pixel is set to 0. The output is the same size as the input,
image, the output pixel is set to 0.
images: 4D `Tensor`, input image(s) in NHWC format.
transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).
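A quick numeric instance of the mapping documented above (editor's note; the numbers are arbitrary): for the row [1, 0, 2, 0, 1, 3, 0, 0], k = 1 everywhere, so output pixel (x, y) samples the input at (x + 2, y + 3).

a0, a1, a2, b0, b1, b2, c0, c1 = 1., 0., 2., 0., 1., 3., 0., 0.
x, y = 4., 5.                      # an output pixel
k = c0 * x + c1 * y + 1.0          # = 1.0 for this affine row
x_in = (a0 * x + a1 * y + a2) / k  # 6.0
y_in = (b0 * x + b1 * y + b2) / k  # 8.0
print((x_in, y_in))                # input point sampled for output (4.0, 5.0)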
44 changes: 44 additions & 0 deletions tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -27,6 +27,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest

_DTYPES = set(
@@ -194,6 +195,19 @@ def test_bilinear_uint8(self):
[0.0, 149, 233, 149, 0.0],
[0.0, 0.0, 87., 0.0, 0.0]])

def test_rotate_static_shape(self):
image = array_ops.diag([1., 2., 3.])
result = image_ops.rotate(
image, random_ops.random_uniform((), -1, 1), interpolation="BILINEAR")
self.assertEqual(image.get_shape(), result.get_shape())

def test_transform_static_output_shape(self):
image = constant_op.constant([[1., 2.], [3., 4.]])
result = image_ops.transform(
image, random_ops.random_uniform([8], -1, 1),
output_shape=constant_op.constant([3, 5]))
self.assertAllEqual([3, 5], result.get_shape())

def _test_grad(self, shape_to_test):
with self.test_session():
test_image_shape = shape_to_test
@@ -213,10 +227,40 @@ def _test_grad(self, shape_to_test):
x_init_value=test_image)
self.assertLess(left_err, 1e-10)

def _test_grad_different_shape(self, input_shape, output_shape):
with self.test_session():
test_image_shape = input_shape
test_image = np.random.randn(*test_image_shape)
test_image_tensor = constant_op.constant(
test_image, shape=test_image_shape)
test_transform = image_ops.angles_to_projective_transforms(
np.pi / 2, 4, 4)

if len(output_shape) == 2:
resize_shape = output_shape
elif len(output_shape) == 3:
resize_shape = output_shape[0:2]
elif len(output_shape) == 4:
resize_shape = output_shape[1:3]
output = image_ops.transform(
images=test_image_tensor,
transforms=test_transform,
output_shape=resize_shape)
left_err = gradient_checker.compute_gradient_error(
test_image_tensor,
test_image_shape,
output,
output_shape,
x_init_value=test_image)
self.assertLess(left_err, 1e-10)

def test_grad(self):
self._test_grad([16, 16])
self._test_grad([4, 12, 12])
self._test_grad([3, 4, 12, 12])
self._test_grad_different_shape([16, 16], [8, 8])
self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])


class BipartiteMatchTest(test_util.TensorFlowTestCase):
52 changes: 35 additions & 17 deletions tensorflow/contrib/image/python/ops/image_ops.py
@@ -23,6 +23,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
@@ -40,6 +41,9 @@
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)


# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
# used by PIL, maybe more readable) mode, which determines the correct
# output_shape and translation for the transform.
def rotate(images, angles, interpolation="NEAREST", name=None):
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.
@@ -213,7 +217,11 @@ def translations_to_projective_transforms(translations, name=None):
axis=1)


def transform(images, transforms, interpolation="NEAREST", name=None):
def transform(images,
transforms,
interpolation="NEAREST",
output_shape=None,
name=None):
"""Applies the given transform(s) to the image(s).
Args:
@@ -230,6 +238,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
output_shape: Output dimensions after the transform, [height, width].
If None, the output is the same size as the input image.
name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -238,6 +250,7 @@
Raises:
TypeError: If `image` is an invalid type.
ValueError: If `output_shape` is not a 1-D int32 Tensor.
"""
with ops.name_scope(name, "transform"):
image_or_images = ops.convert_to_tensor(images, name="images")
@@ -256,6 +269,17 @@
else:
raise TypeError("Images should have rank between 2 and 4.")

if output_shape is None:
output_shape = tensor_util.constant_value(
array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3]

output_shape = ops.convert_to_tensor(
output_shape, dtypes.int32, name="output_shape")

if not output_shape.get_shape().is_compatible_with([2]):
raise ValueError("output_shape must be a 1-D Tensor of 2 elements: "
"new_height, new_width")

if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
@@ -265,8 +289,12 @@
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")

output = gen_image_ops.image_projective_transform(
images, transforms, interpolation=interpolation.upper())
images,
output_shape=output_shape,
transforms=transforms,
interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -376,14 +404,6 @@ def _image_projective_transform_grad(op, grad):

if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
if len(image_or_images.get_shape()) == 2:
images = image_or_images[None, :, :, None]
elif len(image_or_images.get_shape()) == 3:
images = image_or_images[None, :, :, :]
elif len(image_or_images.get_shape()) == 4:
images = image_or_images
else:
raise TypeError("Images should have rank between 2 and 4")
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
@@ -396,13 +416,11 @@ def _image_projective_transform_grad(op, grad):
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return [output[0, :, :, 0], None]
elif len(image_or_images.get_shape()) == 3:
return [output[0, :, :, :], None]
else:
return [output, None]
images=grad,
transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3],
interpolation=interpolation)
return [output, None, None]


def bipartite_match(distance_mat,
