Skip to content

Commit

Permalink
[SPARK-29326][SQL] ANSI store assignment policy: throw exception on c…
Browse files Browse the repository at this point in the history
…asting failure

### What changes were proposed in this pull request?

1. With the ANSI store assignment policy, an exception is thrown on casting failure.
2. Introduce a new expression `AnsiCast` for the ANSI store assignment policy, so that the store assignment policy configuration won't affect the general `Cast`.

### Why are the changes needed?

As per the ANSI SQL standard, the ANSI store assignment policy should throw an exception on insertion failure, such as inserting an out-of-range value into a numeric field.

### Does this PR introduce any user-facing change?

With the ANSI store assignment policy, an exception is thrown on casting failure.

### How was this patch tested?

Unit test

Closes apache#25997 from gengliangwang/newCast.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
gengliangwang authored and cloud-fan committed Oct 4, 2019
1 parent 8b71e54 commit 91747bd
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 270 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
Expand Down Expand Up @@ -99,9 +99,16 @@ object TableOutputResolver {
// Renaming is needed for handling the following cases like
// 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
// 2) Target tables have column metadata
Some(Alias(
Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)),
tableAttr.name)(explicitMetadata = Option(tableAttr.metadata)))
storeAssignmentPolicy match {
case StoreAssignmentPolicy.ANSI =>
Some(Alias(
AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)),
tableAttr.name)(explicitMetadata = Option(tableAttr.metadata)))
case _ =>
Some(Alias(
Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)),
tableAttr.name)(explicitMetadata = Option(tableAttr.metadata)))
}
}

storeAssignmentPolicy match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,23 +243,11 @@ object Cast {
}
}

/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""")
case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with NullIntolerant {
abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression with NullIntolerant {

def child: Expression

def this(child: Expression, dataType: DataType) = this(child, dataType, None)
def dataType: DataType

override def toString: String = s"cast($child as ${dataType.simpleString})"

Expand All @@ -274,8 +262,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
protected def ansiEnabled: Boolean

// When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
// Otherwise behave like Expression.resolved.
Expand All @@ -289,7 +276,6 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private lazy val dateFormatter = DateFormatter(zoneId)
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
private val failOnIntegralTypeOverflow = SQLConf.get.ansiEnabled

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
Expand Down Expand Up @@ -493,7 +479,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Int](_, d => null)
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t))
case x: NumericType if failOnIntegralTypeOverflow =>
case x: NumericType if ansiEnabled =>
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
Expand All @@ -508,11 +494,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
case TimestampType if ansiEnabled =>
buildCast[Long](_, t => LongExactNumeric.toInt(timestampToLong(t)))
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toInt)
case x: NumericType if failOnIntegralTypeOverflow =>
case x: NumericType if ansiEnabled =>
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
Expand All @@ -531,7 +517,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
case TimestampType if ansiEnabled =>
buildCast[Long](_, t => {
val longValue = timestampToLong(t)
if (longValue == longValue.toShort) {
Expand All @@ -542,7 +528,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
})
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toShort)
case x: NumericType if failOnIntegralTypeOverflow =>
case x: NumericType if ansiEnabled =>
b =>
val intValue = try {
x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
Expand Down Expand Up @@ -572,7 +558,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
case TimestampType if ansiEnabled =>
buildCast[Long](_, t => {
val longValue = timestampToLong(t)
if (longValue == longValue.toByte) {
Expand All @@ -583,7 +569,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
})
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toByte)
case x: NumericType if failOnIntegralTypeOverflow =>
case x: NumericType if ansiEnabled =>
b =>
val intValue = try {
x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
Expand All @@ -600,8 +586,6 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

private val nullOnOverflow = !SQLConf.get.ansiEnabled

/**
* Change the precision / scale in a given decimal to those set in `decimalType` (if any),
* modifying `value` in-place and returning it if successful. If an overflow occurs, it
Expand All @@ -614,7 +598,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
if (value.changePrecision(decimalType.precision, decimalType.scale)) {
value
} else {
if (nullOnOverflow) {
if (!ansiEnabled) {
null
} else {
throw new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
Expand All @@ -630,7 +614,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
value.toPrecision(
decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled)


private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
Expand Down Expand Up @@ -1095,7 +1079,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$evPrim = $d;
""".stripMargin
} else {
val overflowCode = if (nullOnOverflow) {
val overflowCode = if (!ansiEnabled) {
s"$evNull = true;"
} else {
s"""
Expand Down Expand Up @@ -1274,7 +1258,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castTimestampToIntegralTypeCode(
ctx: CodegenContext,
integralType: String): CastFunction = {
if (failOnIntegralTypeOverflow) {
if (ansiEnabled) {
val longValue = ctx.freshName("longValue")
(c, evPrim, evNull) =>
code"""
Expand All @@ -1293,15 +1277,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castDecimalToIntegralTypeCode(
ctx: CodegenContext,
integralType: String): CastFunction = {
if (failOnIntegralTypeOverflow) {
if (ansiEnabled) {
(c, evPrim, evNull) => code"$evPrim = $c.roundTo${integralType.capitalize}();"
} else {
(c, evPrim, evNull) => code"$evPrim = $c.to${integralType.capitalize}();"
}
}

private[this] def castIntegralTypeToIntegralTypeExactCode(integralType: String): CastFunction = {
assert(failOnIntegralTypeOverflow)
assert(ansiEnabled)
(c, evPrim, evNull) =>
code"""
if ($c == ($integralType) $c) {
Expand Down Expand Up @@ -1329,7 +1313,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castFractionToIntegralTypeCode(
fractionType: String,
integralType: String): CastFunction = {
assert(failOnIntegralTypeOverflow)
assert(ansiEnabled)
val (min, max) = lowerAndUpperBound(fractionType, integralType)
val mathClass = classOf[Math].getName
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
Expand Down Expand Up @@ -1366,11 +1350,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte")
case _: ShortType | _: IntegerType | _: LongType if failOnIntegralTypeOverflow =>
case _: ShortType | _: IntegerType | _: LongType if ansiEnabled =>
castIntegralTypeToIntegralTypeExactCode("byte")
case _: FloatType if failOnIntegralTypeOverflow =>
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "byte")
case _: DoubleType if failOnIntegralTypeOverflow =>
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "byte")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (byte) $c;"
Expand All @@ -1397,11 +1381,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "short")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short")
case _: IntegerType | _: LongType if failOnIntegralTypeOverflow =>
case _: IntegerType | _: LongType if ansiEnabled =>
castIntegralTypeToIntegralTypeExactCode("short")
case _: FloatType if failOnIntegralTypeOverflow =>
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "short")
case _: DoubleType if failOnIntegralTypeOverflow =>
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "short")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
Expand All @@ -1426,10 +1410,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "int")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int")
case _: LongType if failOnIntegralTypeOverflow => castIntegralTypeToIntegralTypeExactCode("int")
case _: FloatType if failOnIntegralTypeOverflow =>
case _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int")
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "int")
case _: DoubleType if failOnIntegralTypeOverflow =>
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "int")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
Expand All @@ -1456,9 +1440,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};"
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long")
case _: FloatType if failOnIntegralTypeOverflow =>
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "long")
case _: DoubleType if failOnIntegralTypeOverflow =>
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "long")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
Expand Down Expand Up @@ -1647,6 +1631,43 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
}
}

/**
* Cast the child expression to the target data type.
*
* When cast from/to timezone related types, we need timeZoneId, which will be resolved with
* session local timezone by an analyzer [[ResolveTimeZone]].
*/
@ExpressionDescription(
usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
examples = """
Examples:
> SELECT _FUNC_('10' as int);
10
""")
case class Cast(
    child: Expression,
    dataType: DataType,
    timeZoneId: Option[String] = None)
  extends CastBase {

  // The general-purpose cast follows the ANSI flag of the active SQLConf. The flag is read
  // once, when this expression instance is constructed, and kept in a val afterwards.
  override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled

  // Returns a copy of this cast that carries the given time zone id; `child` and `dataType`
  // are left untouched.
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
    val zoneIdOpt = Option(timeZoneId)
    copy(timeZoneId = zoneIdOpt)
  }
}

/**
 * ANSI-flavoured variant of the cast expression: converts the child expression to the target
 * data type following the ANSI SQL standard. Instead of silently producing a wrapped or null
 * result, a casting failure (for example an out-of-range value converted to an integral type)
 * raises a runtime exception.
 *
 * Unlike the general cast, the ANSI behaviour is always on here and does not depend on the
 * session configuration.
 *
 * Casting from/to timezone related types needs a timeZoneId, which is resolved with the
 * session local timezone by the analyzer rule [[ResolveTimeZone]].
 */
case class AnsiCast(
    child: Expression,
    dataType: DataType,
    timeZoneId: Option[String] = None)
  extends CastBase {

  // ANSI semantics are unconditional for this expression.
  override protected val ansiEnabled: Boolean = true

  // Rebuilds this expression with the supplied time zone id attached.
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
    val resolvedZone: Option[String] = Option(timeZoneId)
    this.copy(timeZoneId = resolvedZone)
  }
}

/**
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> date.
Expand Down
Loading

0 comments on commit 91747bd

Please sign in to comment.