[SPARK-7562][SPARK-6444][SQL] Improve error reporting for expression data type mismatch

It is hard to find a single common pattern for checking types in `Expression`. Sometimes we know exactly which input types we need (for `And`, we need two booleans); sometimes we only have a rule (for `Add`, we need two numeric operands of the same type). So I defined a general interface, `checkInputDataTypes`, in `Expression` that returns a `TypeCheckResult`. A `TypeCheckResult` indicates whether the expression passes the type check and, if not, describes the mismatch.

This PR mainly applies input type checking to the arithmetic and predicate expressions.

TODO: apply the type checking interface to more expressions.
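
As a rough illustration of the idea, here is a minimal, self-contained Scala sketch of the two flavors of check described above: a fixed expected input type (as for `And`) versus a rule over the children's types (as for `Add`). The `TypeCheckResult` shape mirrors the one added by this commit; the toy `Expr`, `Literal`, and type objects are invented for the example and are not the Catalyst classes.

// Minimal stand-ins, for illustration only (not Catalyst code).
sealed trait DataType
case object BooleanType extends DataType
case object IntegerType extends DataType
case object DoubleType extends DataType

sealed trait TypeCheckResult { def isSuccess: Boolean; def isFailure: Boolean = !isSuccess }
case object TypeCheckSuccess extends TypeCheckResult { val isSuccess: Boolean = true }
case class TypeCheckFailure(message: String) extends TypeCheckResult { val isSuccess: Boolean = false }

trait Expr {
  def dataType: DataType
  // Default: no constraints; concrete expressions override this with a real check.
  def checkInputDataTypes(): TypeCheckResult = TypeCheckSuccess
}
case class Literal(value: Any, dataType: DataType) extends Expr

// "Known input types" flavor (like And): both children must be boolean.
case class And(left: Expr, right: Expr) extends Expr {
  val dataType: DataType = BooleanType
  override def checkInputDataTypes(): TypeCheckResult =
    if (left.dataType == BooleanType && right.dataType == BooleanType) TypeCheckSuccess
    else TypeCheckFailure(s"And expects two booleans, got ${left.dataType} and ${right.dataType}")
}

// "Rule over the children" flavor (like Add): two numeric children of the same type.
case class Add(left: Expr, right: Expr) extends Expr {
  private val numericTypes: Set[DataType] = Set(IntegerType, DoubleType)
  val dataType: DataType = left.dataType
  override def checkInputDataTypes(): TypeCheckResult =
    if (left.dataType != right.dataType)
      TypeCheckFailure(s"differing types ${left.dataType} and ${right.dataType}")
    else if (!numericTypes(left.dataType))
      TypeCheckFailure(s"Add accepts numeric types, not ${left.dataType}")
    else TypeCheckSuccess
}

// Example: Add(Literal(1, IntegerType), Literal(true, BooleanType)).checkInputDataTypes()
// returns TypeCheckFailure("differing types IntegerType and BooleanType").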

Author: Wenchen Fan <[email protected]>

Closes apache#6405 from cloud-fan/6444 and squashes the following commits:

b5ff31b [Wenchen Fan] address comments
b917275 [Wenchen Fan] rebase
39929d9 [Wenchen Fan] add todo
0808fd2 [Wenchen Fan] make constructor of TypeCheckResult private
3bee157 [Wenchen Fan] and decimal type coercion rule for binary comparison
8883025 [Wenchen Fan] apply type check interface to CaseWhen
cffb67c [Wenchen Fan] to have resolved call the data type check function
6eaadff [Wenchen Fan] add equal type constraint to EqualTo
3affbd8 [Wenchen Fan] more fixes
654d46a [Wenchen Fan] improve tests
e0a3628 [Wenchen Fan] improve error message
1524ff6 [Wenchen Fan] fix style
69ca3fe [Wenchen Fan] add error message and tests
c71d02c [Wenchen Fan] fix hive tests
6491721 [Wenchen Fan] use value class TypeCheckResult
7ae76b9 [Wenchen Fan] address comments
cb77e4f [Wenchen Fan] Improve error reporting for expression data type mismatch
cloud-fan authored and rxin committed Jun 3, 2015
1 parent ce320cb commit d38cf21
Showing 17 changed files with 583 additions and 421 deletions.
4 changes: 2 additions & 2 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
* Log the suite name and the test name before and after each test.
*
* Subclasses should never override this method. If they wish to run
* custom code before and after each test, they should should mix in
* the {{org.scalatest.BeforeAndAfter}} trait instead.
* custom code before and after each test, they should mix in the
* {{org.scalatest.BeforeAndAfter}} trait instead.
*/
final protected override def withFixture(test: NoArgTest): Outcome = {
val testName = test.text
@@ -62,15 +62,17 @@ trait CheckAnalysis {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
}

case c: Cast if !c.resolved =>
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}")

case WindowExpression(UnresolvedWindowFunction(name, _), _) =>
failAnalysis(
s"Could not resolve window function '$name'. " +
@@ -41,7 +41,7 @@ object HiveTypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
*/
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
@@ -57,6 +57,17 @@

case _ => None
}

/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
private def findTightestCommonType(types: Seq[DataType]) = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonTypeOfTwo(d, c)
})
}
}
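
As a quick standalone illustration of the fold above (plain Scala stand-ins, not the Catalyst `DataType`s): the accumulator starts at the most permissive type and collapses to `None` as soon as any adjacent pair has no common type.

sealed trait SimpleType
case object TNull extends SimpleType
case object TInt extends SimpleType
case object TDouble extends SimpleType
case object TString extends SimpleType

// Pairwise rule, loosely analogous to findTightestCommonTypeOfTwo (illustrative promotions only).
def commonTypeOfTwo(t1: SimpleType, t2: SimpleType): Option[SimpleType] = (t1, t2) match {
  case (a, b) if a == b                  => Some(a)
  case (TNull, t)                        => Some(t)
  case (t, TNull)                        => Some(t)
  case (TInt, TDouble) | (TDouble, TInt) => Some(TDouble)
  case _                                 => None
}

// The same left fold as in findTightestCommonType: short-circuits to None once any pair fails.
def commonType(types: Seq[SimpleType]): Option[SimpleType] =
  types.foldLeft[Option[SimpleType]](Some(TNull)) {
    case (None, _)      => None
    case (Some(acc), t) => commonTypeOfTwo(acc, t)
  }

// commonType(Seq(TInt, TNull, TDouble)) == Some(TDouble)
// commonType(Seq(TInt, TString))        == None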

/**
@@ -180,7 +191,7 @@ trait HiveTypeCoercion {

case (l, r) if l.dataType != r.dataType =>
logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
val newLeft =
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
val newRight =
@@ -217,7 +228,7 @@
case e if !e.childrenResolved => e

case b: BinaryExpression if b.left.dataType != b.right.dataType =>
findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
val newLeft =
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
val newRight =
@@ -441,21 +452,18 @@ trait HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
val resultType = DecimalType(max(p1, p2), max(s1, s2))
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2)
if e2.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2))
case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _))
if e1.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited)))
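
A tiny standalone sketch of the precision and scale arithmetic used by the new `BinaryComparison` rule above; the example precisions are assumed values, purely for illustration.

// Both sides of the comparison are cast to DecimalType(max(p1, p2), max(s1, s2)).
def widerDecimal(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
  (math.max(p1, p2), math.max(s1, s2))

// Comparing DECIMAL(6, 2) with DECIMAL(10, 4):
// widerDecimal(6, 2, 10, 4) == (10, 4), so both operands are cast to DecimalType(10, 4).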

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
@@ -570,7 +578,7 @@ trait HiveTypeCoercion {

case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonType(a, b).getOrElse(StringType))
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

@@ -599,14 +607,9 @@ trait HiveTypeCoercion {
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val dt: Option[DataType] = Some(NullType)
val types = es.map(_.dataType)
val rt = types.foldLeft(dt)((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
})
rt match {
case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
findTightestCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
}
@@ -619,17 +622,13 @@
*/
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// Skip nodes who has not been resolved yet,
// as this is an extra rule which should be applied at last.
case e if !e.resolved => e

// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
Divide(Cast(l, DecimalType.Unlimited), r)
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
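
In effect, once this rule has run, a `Divide` whose result type is neither double nor decimal has both operands cast to `DoubleType`, so an integer division such as `1 / 2` should evaluate to `0.5` rather than `0`.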
@@ -642,42 +641,33 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}

case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = ckw.branches.sliding(2, 2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case s => s
}.reduce(_ ++ _)
val transformedKey = if (ckw.key.dataType != commonType) {
Cast(ckw.key, commonType)
} else {
ckw.key
}
CaseKeyWhen(transformedKey, transformedBranches)
case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
val maybeCommonType = findTightestCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case other => other
}.reduce(_ ++ _)
c match {
case _: CaseWhen => CaseWhen(castedBranches)
case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
}
}.getOrElse(c)

case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case other => other
}.reduce(_ ++ _)
CaseKeyWhen(Cast(c.key, commonType), castedBranches)
}.getOrElse(c)
}
}
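
For context on the `grouped(2)` calls above: the branches of a `CaseWhenLike` are stored as a flat sequence of (condition, value) pairs, optionally followed by a single else value, so grouping by two recovers the pairs. A standalone sketch with plain strings standing in for expressions:

val branches = Seq("when1", "value1", "when2", "value2", "elseValue")
branches.grouped(2).toList
// => List(List("when1", "value1"), List("when2", "value2"), List("elseValue"))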

@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

/**
* Represents the result of `Expression.checkInputDataTypes`.
* We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true.
*/
trait TypeCheckResult {
def isFailure: Boolean = !isSuccess
def isSuccess: Boolean
}

object TypeCheckResult {

/**
* Represents the successful result of `Expression.checkInputDataTypes`.
*/
object TypeCheckSuccess extends TypeCheckResult {
def isSuccess: Boolean = true
}

/**
* Represents the failing result of `Expression.checkInputDataTypes`,
* with a error message to show the reason of failure.
*/
case class TypeCheckFailure(message: String) extends TypeCheckResult {
def isSuccess: Boolean = false
}
}
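
A minimal sketch of how a caller can consume this result, loosely following the `CheckAnalysis` change earlier in this commit; `reportTypeMismatch` is a hypothetical helper and `sys.error` stands in for `failAnalysis`.

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult

def reportTypeMismatch(prettyString: String, result: TypeCheckResult): Unit = result match {
  case TypeCheckResult.TypeCheckFailure(message) =>
    sys.error(s"cannot resolve '$prettyString' due to data type mismatch: $message")
  case TypeCheckResult.TypeCheckSuccess =>
    () // nothing to report
}

// reportTypeMismatch("(1 + true)", TypeCheckResult.TypeCheckFailure("differing types in Add"))
// fails with: cannot resolve '(1 + true)' due to data type mismatch: differing types in Add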
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
@@ -53,11 +53,12 @@ abstract class Expression extends TreeNode[Expression] {

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and `false` if it still contains any unresolved placeholders. Implementations of expressions
* should override this if the resolution of this type of expression involves more than just
* the resolution of its children.
* and input data types checking passed, and `false` if it still contains any unresolved
* placeholders or has data types mismatch.
* Implementations of expressions should override this if the resolution of this type of
* expression involves more than just the resolution of its children and type checking.
*/
lazy val resolved: Boolean = childrenResolved
lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/**
* Returns the [[DataType]] of the result of evaluating this expression. It is
@@ -94,12 +95,21 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
}
}

/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`
* TODO: we should remove the default implementation and implement it for all
* expressions with proper error message.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

def symbol: String
def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

override def foldable: Boolean = left.foldable && right.foldable

@@ -133,7 +143,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
* so that the proper type conversions can be performed in the analyzer.
*/
trait ExpectsInputTypes {
self: Expression =>

def expectedChildTypes: Seq[DataType]

override def checkInputDataTypes(): TypeCheckResult = {
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
// so type mismatch error won't be reported here, but for underling `Cast`s.
TypeCheckResult.TypeCheckSuccess
}
}