Skip to content

Commit

Permalink
Merge pull request tensorflow#19338 from girving/clip
Browse files Browse the repository at this point in the history
Make tf.clip_by_value not crash on empty tensors
  • Loading branch information
zheng-xq authored May 18, 2018
2 parents 5cfa305 + ce11f0b commit 4efdf36
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions tensorflow/core/kernels/cwise_op_clip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,52 +33,41 @@ class ClipOp : public OpKernel {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
const Tensor& in2 = ctx->input(2);
OP_REQUIRES(ctx, (in0.shape() == in1.shape() ||
TensorShapeUtils::IsScalar(in1.shape())) &&
(in0.shape() == in2.shape() ||
TensorShapeUtils::IsScalar(in2.shape())),
errors::InvalidArgument(
"clip_value_min and clip_value_max must be either of "
"the same shape as input, or a scalar. ",
"input shape: ", in0.shape().DebugString(),
"clip_value_min shape: ", in1.shape().DebugString(),
"clip_value_max shape: ", in2.shape().DebugString()));

Tensor* out = nullptr;
OP_REQUIRES_OK(
ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
if (out->NumElements() == 0) return; // Nothing to do for empty output

auto in0_flat = in0.flat<T>();
auto in1_flat = in1.flat<T>();
auto in2_flat = in2.flat<T>();
auto out_flat = out->flat<T>();
const Device& d = ctx->eigen_device<Device>();

Tensor* out = nullptr;
OP_REQUIRES_OK(
ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
auto out_flat = out->flat<T>();
if (in1.shape() == in2.shape()) {
if (in0.shape() == in1.shape()) {
functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
} else {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in1.shape()),
errors::InvalidArgument(
"clip_value_min and clip_value_max must be either of "
"the same shape as input, or a scalar. ",
"input shape: ", in0.shape().DebugString(),
"clip_value_min shape: ", in1.shape().DebugString(),
"clip_value_max shape: ", in2.shape().DebugString()));
functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
}
} else {
if (in0.shape() == in1.shape()) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in2.shape()),
errors::InvalidArgument(
"clip_value_min and clip_value_max must be either of "
"the same shape as input, or a scalar. ",
"input shape: ", in0.shape().DebugString(),
"clip_value_min shape: ", in1.shape().DebugString(),
"clip_value_max shape: ", in2.shape().DebugString()));
functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
} else {
OP_REQUIRES(ctx,
(in0.shape() == in2.shape() &&
TensorShapeUtils::IsScalar(in1.shape())),
errors::InvalidArgument(
"clip_value_min and clip_value_max must be either of "
"the same shape as input, or a scalar. ",
"input shape: ", in0.shape().DebugString(),
"clip_value_min shape: ", in1.shape().DebugString(),
"clip_value_max shape: ", in2.shape().DebugString()));
functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
out_flat);
}
Expand Down

0 comments on commit 4efdf36

Please sign in to comment.