diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 2166c34358fa62..56d24423c428d3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -891,44 +891,51 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - auto invalid_argument = - InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + if (!IsArray(shape)) { return Status::OK(); } - int64 shape_size; - if (LayoutUtil::IsSparseArray(shape)) { - shape_size = LayoutUtil::MaxSparseElements(shape.layout()); - if (shape_size < 0) { - return invalid_argument; - } - shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); - if (shape_size < 0) { - return invalid_argument; + + int64 shape_size = [&shape]() { + int64 shape_size; + if (LayoutUtil::IsSparseArray(shape)) { + shape_size = LayoutUtil::MaxSparseElements(shape.layout()); + if (shape_size < 0) { + return shape_size; + } + shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); + if (shape_size < 0) { + return shape_size; + } + shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); + if (shape_size < 0) { + return shape_size; + } } - shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); - if (shape_size < 0) { - return invalid_argument; + + shape_size = 1; + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. + if (shape.dimensions().empty()) { + return shape_size; } - } - // This is intentionally unconditional: even if the shape is sparse, we want - // to verify the densified version has a reasonable size. - if (shape.dimensions().empty()) { - return Status::OK(); - } - shape_size = 1; - for (int64 dim : shape.dimensions()) { - shape_size = MultiplyWithoutOverflow(shape_size, dim); - if (shape_size < 0) { - return invalid_argument; + for (int64 dim : shape.dimensions()) { + shape_size = MultiplyWithoutOverflow(shape_size, dim); + if (shape_size < 0) { + return shape_size; + } } - } - shape_size = MultiplyWithoutOverflow( - shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + shape_size = MultiplyWithoutOverflow( + shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + + return shape_size; + }(); + if (shape_size < 0) { - return invalid_argument; + return InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); } VLOG(3) << "Shape size is valid: " << shape_size;