[SPARK-37802][SQL] Composite field name should work with Aggregate push down

### What changes were proposed in this pull request?
Currently, a composite field name such as `dept id` doesn't work with aggregate push down:

```scala
sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
```

```
org.apache.spark.sql.catalyst.parser.ParseException:
extraneous input 'id' expecting <EOF>(line 1, pos 5)

== SQL ==
dept id
-----^^^

	at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:271)
	at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:132)
	at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parseMultipartIdentifier(ParseDriver.scala:63)
	at org.apache.spark.sql.connector.expressions.LogicalExpressions$.parseReference(expressions.scala:39)
	at org.apache.spark.sql.connector.expressions.FieldReference$.apply(expressions.scala:365)
	at org.apache.spark.sql.execution.datasources.DataSourceStrategy$.translateAggregate(DataSourceStrategy.scala:717)
	at org.apache.spark.sql.execution.datasources.v2.PushDownUtils$.$anonfun$pushAggregates$1(PushDownUtils.scala:125)
	at scala.collection.immutable.List.flatMap(List.scala:366)
	at org.apache.spark.sql.execution.datasources.v2.PushDownUtils$.pushAggregates(PushDownUtils.scala:125)
```
This PR fixes the problem.
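
The root cause is visible in the stack trace: `FieldReference.apply(String)` feeds the name to the Catalyst parser via `parseMultipartIdentifier`, which treats an unquoted space as a token boundary. A minimal sketch of the two construction paths, with behavior as shown in the stack trace above and the diff below:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference

// apply(String) parses the argument as a multipart SQL identifier,
// so an unquoted space is a syntax error:
FieldReference("dept id")      // throws ParseException: extraneous input 'id'

// Building from pre-split parts skips the parser entirely; the whole
// string is kept as a single name part:
FieldReference(Seq("dept id")) // one column literally named "dept id"
```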

### Why are the changes needed?
Bug fix.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New tests.

Closes apache#35108 from huaxingao/composite_name.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
huaxingao authored and cloud-fan committed Jan 7, 2022
1 parent 842c0c3 commit cf193b9
Showing 4 changed files with 61 additions and 14 deletions.
**expressions.scala**

```diff
@@ -351,6 +351,10 @@ private[sql] object FieldReference {
   def apply(column: String): NamedReference = {
     LogicalExpressions.parseReference(column)
   }
+
+  def column(name: String): NamedReference = {
+    FieldReference(Seq(name))
+  }
 }

 private[sql] final case class SortValue(
```
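
The new `column` helper is the core of the fix: it wraps a raw name in a one-element `Seq` so nothing gets re-parsed. For illustration, a sketch of how the two entry points diverge on a dotted name (`fieldNames` is assumed from the `NamedReference` interface):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference

// apply(String) splits on the dot: a nested field "b" under column "a".
FieldReference("a.b").fieldNames()        // Array("a", "b")

// column(String) never splits: one top-level column named "a.b".
FieldReference.column("a.b").fieldNames() // Array("a.b")
```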
**DataSourceStrategy.scala**

```diff
@@ -706,41 +706,45 @@ object DataSourceStrategy
     if (agg.filter.isEmpty) {
       agg.aggregateFunction match {
         case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
-          Some(new Min(FieldReference(name)))
+          Some(new Min(FieldReference.column(name)))
         case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
-          Some(new Max(FieldReference(name)))
+          Some(new Max(FieldReference.column(name)))
         case count: aggregate.Count if count.children.length == 1 =>
           count.children.head match {
             // COUNT(any literal) is the same as COUNT(*)
             case Literal(_, _) => Some(new CountStar())
             case PushableColumnWithoutNestedColumn(name) =>
-              Some(new Count(FieldReference(name), agg.isDistinct))
+              Some(new Count(FieldReference.column(name), agg.isDistinct))
             case _ => None
           }
         case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new Sum(FieldReference(name), agg.isDistinct))
+          Some(new Sum(FieldReference.column(name), agg.isDistinct))
         case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "VAR_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "STDDEV_POP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
-          Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
+          Some(new GeneralAggregateFunc(
+            "STDDEV_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
         case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
             PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
             PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
             PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
-            Array(FieldReference(left), FieldReference(right))))
+            Array(FieldReference.column(left), FieldReference.column(right))))
         case _ => None
       }
     } else {
```
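
The rest of the change is mechanical: every translation path that used to re-parse the column name now wraps it verbatim. For example, the failing COUNT over `dept id` from the description would now translate roughly like this (constructor shapes are taken from the diff; the exact package of `Count` is an assumption):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
// Assumed location of the V2 aggregate classes imported in the diff:
import org.apache.spark.sql.connector.expressions.aggregate.Count

// Before the fix this would be new Count(FieldReference("dept id"), false),
// which dies in the parser. Now the name is carried through untouched:
val pushed = new Count(FieldReference.column("dept id"), /* isDistinct */ false)
```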
**PushDownUtils.scala**

```diff
@@ -118,7 +118,7 @@ object PushDownUtils extends PredicateHelper

     def columnAsString(e: Expression): Option[FieldReference] = e match {
       case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference(name).asInstanceOf[FieldReference])
+        Some(FieldReference.column(name).asInstanceOf[FieldReference])
       case _ => None
     }

```
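
Note that the `asInstanceOf[FieldReference]` cast survives the change: `FieldReference.column` is declared to return the `NamedReference` interface, while `columnAsString` promises the concrete `FieldReference`. A sketch of why the downcast stays safe, assuming the signatures shown in the first hunk:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}

// The helper's declared type is the interface...
val ref: NamedReference = FieldReference.column("dept id")

// ...but the value it builds is always the FieldReference case class,
// so narrowing back for Option[FieldReference] cannot fail.
val field: FieldReference = ref.asInstanceOf[FieldReference]
```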
**JDBCV2Suite.scala**

```diff
@@ -80,6 +80,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
       .executeUpdate()
     conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
       .executeUpdate()
+    conn.prepareStatement(
+      "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+
+    // scalastyle:off
+    conn.prepareStatement(
+      "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate()
+    // scalastyle:on
+    conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate()
   }
 }

@@ -305,7 +316,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
   test("show tables") {
     checkAnswer(sql("SHOW TABLES IN h2.test"),
       Seq(Row("test", "people", false), Row("test", "empty_table", false),
-        Row("test", "employee", false)))
+        Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false)))
   }

   test("SQL API: create table as select") {

@@ -831,4 +842,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     checkAnswer(df,
       Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
   }
+
+  test("column name with composite field") {
+    checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
+    val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [COUNT(`dept id`)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(2)))
+  }
+
+  test("column name with non-ascii") {
+    // scalastyle:off
+    checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2)))
+    val df = sql("SELECT COUNT(`名`) FROM h2.test.person")
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [COUNT(`名`)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(2)))
+    // scalastyle:on
+  }
 }
```
