[SPARK-20329][SQL] Make timezone aware expression without timezone unresolved

## What changes were proposed in this pull request?
A cast expression with a resolved time zone is not equal to the same cast expression without one. The `ResolveAggregateFunctions` rule assumed that these expressions were equal, and would therefore fail to resolve `HAVING` clauses that contain a `Cast` expression.

This is, in essence, caused by the fact that a `TimeZoneAwareExpression` could be considered resolved without a time zone set. This PR fixes that by making a `TimeZoneAwareExpression` unresolved as long as it has no time zone set.
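
For illustration, here is a minimal sketch of the mismatch against the catalyst expression API (the attribute and time zone are made up for the example):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types.{IntegerType, StringType}

val a = AttributeReference("a", IntegerType)()
// The aggregate holds the cast after the session time zone has been filled in...
val withZone = Cast(a, StringType, timeZoneId = Some("UTC"))
// ...while the HAVING condition still holds the cast without a time zone.
val withoutZone = Cast(a, StringType, timeZoneId = None)
// Case-class equality includes timeZoneId, so the two no longer match.
assert(withZone != withoutZone)
```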

## How was this patch tested?
Added a regression test to the `SQLQueryTestSuite.having` file.
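
For reference, a query of the shape that regressed looks like this (a hypothetical repro, not the literal contents of the test file; implicit type coercion wraps the aggregate in a `Cast` inside the `HAVING` predicate):

```scala
// Assumes a SparkSession named `spark`; table name and data are illustrative.
import spark.implicits._

Seq(("one", 1), ("two", 2)).toDF("k", "v").createOrReplaceTempView("hav")

// sum(v) is a bigint and 0.5 a decimal literal, so analysis inserts a Cast
// into the HAVING condition; before this patch that Cast failed to resolve.
spark.sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 0.5").show()
```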

Author: Herman van Hovell <[email protected]>

Closes apache#17641 from hvanhovell/SPARK-20329.
hvanhovell authored and cloud-fan committed Apr 21, 2017
1 parent 0368eb9 commit 760c8d0
Showing 19 changed files with 148 additions and 78 deletions.
@@ -150,6 +150,7 @@ class Analyzer(
       ResolveAggregateFunctions ::
       TimeWindowing ::
       ResolveInlineTables(conf) ::
+      ResolveTimeZone(conf) ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -161,8 +162,6 @@
       HandleNullInputsForUDF),
     Batch("FixNullability", Once,
       FixNullability),
-    Batch("ResolveTimeZone", Once,
-      ResolveTimeZone),
     Batch("Subquery", Once,
       UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
@@ -2368,23 +2367,6 @@ class Analyzer(
       }
     }
   }
-
-  /**
-   * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
-   * time zone.
-   */
-  object ResolveTimeZone extends Rule[LogicalPlan] {
-
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
-      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
-        e.withTimeZone(conf.sessionLocalTimeZone)
-      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
-      // the types between the value expression and list query expression of IN expression.
-      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
-      // information for time zone aware expressions.
-      case e: ListQuery => e.withNewPlan(apply(e.plan))
-    }
-  }
 }
 
 /**
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
  */
-case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case table: UnresolvedInlineTable if table.expressionsResolved =>
       validateInputDimension(table)
@@ -99,12 +98,9 @@
         val castedExpr = if (e.dataType.sameType(targetType)) {
           e
         } else {
-          Cast(e, targetType)
+          cast(e, targetType)
         }
-        castedExpr.transform {
-          case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
-            e.withTimeZone(conf.sessionLocalTimeZone)
-        }.eval()
+        castedExpr.eval()
       } catch {
         case NonFatal(ex) =>
           table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

/**
* Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
* time zone.
*/
case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
e.withTimeZone(conf.sessionLocalTimeZone)
// Casts could be added in the subquery plan through the rule TypeCoercion while coercing
// the types between the value expression and list query expression of IN expression.
// We need to subject the subquery plan through ResolveTimeZone again to setup timezone
// information for time zone aware expressions.
case e: ListQuery => e.withNewPlan(apply(e.plan))
}

override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveExpressions(transformTimeZoneExprs)

def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
}

/**
* Mix-in trait for constructing valid [[Cast]] expressions.
*/
trait CastSupport {
/**
* Configuration used to create a valid cast expression.
*/
def conf: SQLConf

/**
* Create a Cast expression with the session local time zone.
*/
def cast(child: Expression, dataType: DataType): Cast = {
Cast(child, dataType, Option(conf.sessionLocalTimeZone))
}
}
@@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf
  * This should be only done after the batch of Resolution, because the view attributes are not
  * completely resolved during the batch of Resolution.
  */
-case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
+case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case v @ View(desc, output, child) if child.resolved && output != child.output =>
       val resolver = conf.resolver
@@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
           throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
             s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n")
         } else {
-          Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
+          Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
             qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata))
         }
       case (_, originAttr) => originAttr
@@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone}
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
  * Common base class for time zone aware expressions.
  */
 trait TimeZoneAwareExpression extends Expression {
+  /** The expression is only resolved when the time zone has been set. */
+  override lazy val resolved: Boolean =
+    childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined
 
   /** the timezone ID to be used to evaluate value. */
   def timeZoneId: Option[String]
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
 
 /**
@@ -91,12 +92,13 @@
   test("convert TimeZoneAwareExpression") {
     val table = UnresolvedInlineTable(Seq("c1"),
       Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
-    val converted = ResolveInlineTables(conf).convert(table)
+    val withTimeZone = ResolveTimeZone(conf).apply(table)
+    val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
     val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
       .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
-    assert(converted.output.map(_.dataType) == Seq(TimestampType))
-    assert(converted.data.size == 1)
-    assert(converted.data(0).getLong(0) == correct)
+    assert(output.map(_.dataType) == Seq(TimestampType))
+    assert(data.size == 1)
+    assert(data.head.getLong(0) == correct)
   }
 
   test("nullability inference in convert") {
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -787,6 +788,12 @@
     }
   }
 
+  private val timeZoneResolver = ResolveTimeZone(new SQLConf)
+
+  private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
+    timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
+  }
+
   test("WidenSetOperationTypes for except and intersect") {
     val firstTable = LocalRelation(
       AttributeReference("i", IntegerType)(),
@@ -799,11 +806,10 @@
       AttributeReference("f", FloatType)(),
       AttributeReference("l", LongType)())
 
-    val wt = TypeCoercion.WidenSetOperationTypes
     val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
 
-    val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
-    val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
+    val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
+    val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
     checkOutput(r1.left, expectedTypes)
     checkOutput(r1.right, expectedTypes)
     checkOutput(r2.left, expectedTypes)
@@ -838,10 +844,9 @@
       AttributeReference("p", ByteType)(),
       AttributeReference("q", DoubleType)())
 
-    val wt = TypeCoercion.WidenSetOperationTypes
     val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
 
-    val unionRelation = wt(
+    val unionRelation = widenSetOperationTypes(
       Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
     assert(unionRelation.children.length == 4)
     checkOutput(unionRelation.children.head, expectedTypes)
@@ -862,17 +867,15 @@
       }
     }
 
-    val dp = TypeCoercion.WidenSetOperationTypes
-
     val left1 = LocalRelation(
       AttributeReference("l", DecimalType(10, 8))())
     val right1 = LocalRelation(
       AttributeReference("r", DecimalType(5, 5))())
     val expectedType1 = Seq(DecimalType(10, 8))
 
-    val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
-    val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
-    val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+    val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
+    val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
+    val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]
 
     checkOutput(r1.children.head, expectedType1)
     checkOutput(r1.children.last, expectedType1)
@@ -891,17 +894,17 @@
     val plan2 = LocalRelation(
       AttributeReference("r", rType)())
 
-    val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
-    val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
-    val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+    val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
+    val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
+    val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]
 
     checkOutput(r1.children.last, Seq(expectedType))
     checkOutput(r2.right, Seq(expectedType))
     checkOutput(r3.right, Seq(expectedType))
 
-    val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
-    val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
-    val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+    val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
+    val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
+    val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]
 
     checkOutput(r4.children.last, Seq(expectedType))
     checkOutput(r5.left, Seq(expectedType))
@@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
+  private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = {
     v match {
       case lit: Expression => Cast(lit, targetType, timeZoneId)
       case _ => Cast(Literal(v), targetType, timeZoneId)
@@ -47,7 +47,7 @@
   }
 
   private def checkNullCast(from: DataType, to: DataType): Unit = {
-    checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
+    checkEvaluation(cast(Literal.create(null, from), to), null)
   }
 
   test("null cast") {
@@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("Seconds") {
     assert(Second(Literal.create(null, DateType), gmtId).resolved === false)
-    assert(Second(Cast(Literal(d), TimestampType), None).resolved === true)
+    assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true)
     checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
     checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15)
     checkEvaluation(Second(Literal(ts), gmtId), 15)
@@ -220,7 +220,7 @@
 
   test("Hour") {
     assert(Hour(Literal.create(null, DateType), gmtId).resolved === false)
-    assert(Hour(Literal(ts), None).resolved === true)
+    assert(Hour(Literal(ts), gmtId).resolved === true)
     checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
     checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13)
     checkEvaluation(Hour(Literal(ts), gmtId), 13)
@@ -246,7 +246,7 @@
 
   test("Minute") {
     assert(Minute(Literal.create(null, DateType), gmtId).resolved === false)
-    assert(Minute(Literal(ts), None).resolved === true)
+    assert(Minute(Literal(ts), gmtId).resolved === true)
     checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
     checkEvaluation(
       Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10)
@@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
   protected def checkEvaluation(
       expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
     val serializer = new JavaSerializer(new SparkConf()).newInstance
-    val expr: Expression = serializer.deserialize(serializer.serialize(expression))
+    val resolver = ResolveTimeZone(new SQLConf)
+    val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
     val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
@@ -36,7 +36,7 @@ class SparkPlanner(
       experimentalMethods.extraStrategies ++
       extraPlanningStrategies ++ (
       FileSourceStrategy ::
-      DataSourceStrategy ::
+      DataSourceStrategy(conf) ::
       SpecialLimits ::
       Aggregation ::
       JoinSelection ::