diff --git a/tensorflow/core/kernels/cwise_op_clip.cc b/tensorflow/core/kernels/cwise_op_clip.cc index 14d889e8e3b7c0..49b90e855be649 100644 --- a/tensorflow/core/kernels/cwise_op_clip.cc +++ b/tensorflow/core/kernels/cwise_op_clip.cc @@ -33,52 +33,41 @@ class ClipOp : public OpKernel { const Tensor& in0 = ctx->input(0); const Tensor& in1 = ctx->input(1); const Tensor& in2 = ctx->input(2); + OP_REQUIRES(ctx, (in0.shape() == in1.shape() || + TensorShapeUtils::IsScalar(in1.shape())) && + (in0.shape() == in2.shape() || + TensorShapeUtils::IsScalar(in2.shape())), + errors::InvalidArgument( + "clip_value_min and clip_value_max must be either of " + "the same shape as input, or a scalar. ", + "input shape: ", in0.shape().DebugString(), + "clip_value_min shape: ", in1.shape().DebugString(), + "clip_value_max shape: ", in2.shape().DebugString())); + + Tensor* out = nullptr; + OP_REQUIRES_OK( + ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out)); + if (out->NumElements() == 0) return; // Nothing to do for empty output auto in0_flat = in0.flat(); auto in1_flat = in1.flat(); auto in2_flat = in2.flat(); + auto out_flat = out->flat(); const Device& d = ctx->eigen_device(); - Tensor* out = nullptr; - OP_REQUIRES_OK( - ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out)); - auto out_flat = out->flat(); if (in1.shape() == in2.shape()) { if (in0.shape() == in1.shape()) { functor::TernaryClipOp()(d, in0_flat, in1_flat, in2_flat, out_flat); } else { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in1.shape()), - errors::InvalidArgument( - "clip_value_min and clip_value_max must be either of " - "the same shape as input, or a scalar. ", - "input shape: ", in0.shape().DebugString(), - "clip_value_min shape: ", in1.shape().DebugString(), - "clip_value_max shape: ", in2.shape().DebugString())); functor::UnaryClipOp()(d, in0_flat, in1_flat, in2_flat, out_flat); } } else { if (in0.shape() == in1.shape()) { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in2.shape()), - errors::InvalidArgument( - "clip_value_min and clip_value_max must be either of " - "the same shape as input, or a scalar. ", - "input shape: ", in0.shape().DebugString(), - "clip_value_min shape: ", in1.shape().DebugString(), - "clip_value_max shape: ", in2.shape().DebugString())); functor::BinaryLeftClipOp()(d, in0_flat, in1_flat, in2_flat, out_flat); } else { - OP_REQUIRES(ctx, - (in0.shape() == in2.shape() && - TensorShapeUtils::IsScalar(in1.shape())), - errors::InvalidArgument( - "clip_value_min and clip_value_max must be either of " - "the same shape as input, or a scalar. ", - "input shape: ", in0.shape().DebugString(), - "clip_value_min shape: ", in1.shape().DebugString(), - "clip_value_max shape: ", in2.shape().DebugString())); functor::BinaryRightClipOp()(d, in0_flat, in1_flat, in2_flat, out_flat); }