Skip to content

Commit

Permalink
[SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1…
Browse files Browse the repository at this point in the history
… and update toString of Exchange

https://issues.apache.org/jira/browse/SPARK-9830

This is the follow-up pr for apache#9556 to address davies' comments.

Author: Yin Huai <[email protected]>

Closes apache#9607 from yhuai/removeAgg1-followup.
  • Loading branch information
yhuai authored and rxin committed Nov 11, 2015
1 parent e281b87 commit 3121e78
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ class Analyzer(
case min: Min if isDistinct =>
AggregateExpression(min, Complete, isDistinct = false)
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
// This function is not an aggregate function, just return the resolved one.
case other => other
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,21 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case aggExpr: AggregateExpression =>
// TODO: Is it possible that the child of a agg function is another
// agg function?
aggExpr.aggregateFunction.children.foreach {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
// a Project node.
case child if !child.deterministic =>
aggExpr.aggregateFunction.children.foreach { child =>
child.foreach {
case agg: AggregateExpression =>
failAnalysis(
s"It is not allowed to use an aggregate function in the argument of " +
s"another aggregate function. Please use the inner aggregate function " +
s"in a sub-query.")
case other => // OK
}

if (!child.deterministic) {
failAnalysis(
s"nondeterministic expression ${expr.prettyString} should not " +
s"appear in the arguments of an aggregate function.")
case child => // OK
}
}
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
Expand All @@ -133,19 +137,33 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}

def checkSupportedGroupingDataType(
expressionString: String,
dataType: DataType): Unit = dataType match {
case BinaryType =>
failAnalysis(s"expression $expressionString cannot be used in " +
s"grouping expression because it is in binary type or its inner field is " +
s"in binary type")
case a: ArrayType =>
failAnalysis(s"expression $expressionString cannot be used in " +
s"grouping expression because it is in array type or its inner field is " +
s"in array type")
case m: MapType =>
failAnalysis(s"expression $expressionString cannot be used in " +
s"grouping expression because it is in map type or its inner field is " +
s"in map type")
case s: StructType =>
s.fields.foreach { f =>
checkSupportedGroupingDataType(expressionString, f.dataType)
}
case udt: UserDefinedType[_] =>
checkSupportedGroupingDataType(expressionString, udt.sqlType)
case _ => // OK
}

def checkValidGroupingExprs(expr: Expression): Unit = {
expr.dataType match {
case BinaryType =>
failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
"in grouping expression")
case a: ArrayType =>
failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
"in grouping expression")
case m: MapType =>
failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
"in grouping expression")
case _ => // OK
}
checkSupportedGroupingDataType(expr.prettyString, expr.dataType)

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w

override def dataType: DataType = DoubleType

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {

override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
Seq(TypeCollection(LongType, DoubleType, DecimalType))

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._

import scala.beans.{BeanProperty, BeanInfo}

@BeanInfo
private[sql] case class GroupableData(@BeanProperty data: Int)

private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {

override def sqlType: DataType = IntegerType

override def serialize(obj: Any): Int = {
obj match {
case groupableData: GroupableData => groupableData.data
}
}

override def deserialize(datum: Any): GroupableData = {
datum match {
case data: Int => GroupableData(data)
}
}

override def userClass: Class[GroupableData] = classOf[GroupableData]

private[spark] override def asNullable: GroupableUDT = this
}

@BeanInfo
private[sql] case class UngroupableData(@BeanProperty data: Array[Int])

private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {

override def sqlType: DataType = ArrayType(IntegerType)

override def serialize(obj: Any): ArrayData = {
obj match {
case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
}
}

override def deserialize(datum: Any): UngroupableData = {
datum match {
case data: Array[Int] => UngroupableData(data)
}
}

override def userClass: Class[UngroupableData] = classOf[UngroupableData]

private[spark] override def asNullable: UngroupableUDT = this
}

case class TestFunction(
children: Seq[Expression],
inputTypes: Seq[AbstractDataType])
Expand Down Expand Up @@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest {
assert(error.message.contains("Conflicting attributes"))
}

test("aggregation can't work on binary and map types") {
val plan =
Aggregate(
AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
test("check grouping expression data types") {
def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = {
val plan =
Aggregate(
AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil,
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", dataType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))

shouldSuccess match {
case true =>
assertAnalysisSuccess(plan, true)
case false =>
assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
}

assertAnalysisError(plan,
"binary type expression a cannot be used in grouping expression" :: Nil)
}

val plan2 =
Aggregate(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))
val supportedDataTypes = Seq(
StringType,
NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", StringType, nullable = true),
new GroupableUDT())
supportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = true)
}

assertAnalysisError(plan2,
"map type expression a cannot be used in grouping expression" :: Nil)
val unsupportedDataTypes = Seq(
BinaryType,
ArrayType(IntegerType),
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
new UngroupableUDT())
unsupportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = false)
}
}

val plan3 =
test("we should fail analysis when we find nested aggregate functions") {
val plan =
Aggregate(
AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil,
Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil,
LocalRelation(
AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
AttributeReference("a", IntegerType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))

assertAnalysisError(plan3,
"array type expression a cannot be used in grouping expression" :: Nil)
assertAnalysisError(
plan,
"It is not allowed to use an aggregate function in the argument of " +
"another aggregate function." :: Nil)
}

test("Join can't work on binary and map types") {
Expand Down
1 change: 1 addition & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ private[spark] object SQLConf {
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ case class Exchange(
override def nodeName: String = {
val extraInfo = coordinator match {
case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated =>
"Shuffle"
s"(coordinator id: ${System.identityHashCode(coordinator)})"
case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated =>
"May shuffle"
case None => "Shuffle without coordinator"
s"(coordinator id: ${System.identityHashCode(coordinator)})"
case None => ""
}

val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
s"$simpleNodeName($extraInfo)"
s"${simpleNodeName}${extraInfo}"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
}
(keyValueOutput, runFunc)

case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) =>
val runFunc = (sqlContext: SQLContext) => {
logWarning(
s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " +
s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " +
s"continue to be true.")
Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true"))
}
(keyValueOutput, runFunc)

// Configures a single property.
case Some((key, Some(value))) =>
val runFunc = (sqlContext: SQLContext) => {
Expand Down

0 comments on commit 3121e78

Please sign in to comment.