[SPARK-41391][SQL] The output column name of groupBy.agg(count_distinct) is incorrect

### What changes were proposed in this pull request?

Correct the output column name of groupBy.agg(count_distinct) so that "*" is expanded into the actual column names and the output column name includes the DISTINCT keyword.

### Why are the changes needed?

The output column name for groupBy.agg(count_distinct) is incorrect, whereas the equivalent Spark SQL queries return the correct column name. For groupBy.agg queries on a DataFrame, "*" is not expanded in the output column name and the DISTINCT keyword is missing from it.

```
// initializing data
scala> val df = spark.range(1, 10).withColumn("value", lit(1))
df: org.apache.spark.sql.DataFrame = [id: bigint, value: int]
scala> df.createOrReplaceTempView("table")

// DataFrame aggregate queries with incorrect output column names
scala> df.groupBy("id").agg(count_distinct($"*"))
res3: org.apache.spark.sql.DataFrame = [id: bigint, count(unresolvedstar()): bigint]
scala> df.groupBy("id").agg(count_distinct($"value"))
res1: org.apache.spark.sql.DataFrame = [id: bigint, count(value): bigint]

// Spark SQL aggregate queries with correct output column names
scala> spark.sql(" SELECT id, COUNT(DISTINCT *) FROM table GROUP BY id ")
res4: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT id, value): bigint]
scala> spark.sql(" SELECT id, COUNT(DISTINCT value) FROM table GROUP BY id ")
res2: org.apache.spark.sql.DataFrame = [id: bigint, count(DISTINCT value): bigint]
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a unit test to DataFrameSuite.

Closes apache#40116 from ritikam2/master.

Authored-by: Ritika Maheshwari <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
ritikam2 authored and cloud-fan committed Mar 31, 2023
1 parent 35503a5 commit cb7d082
Showing 2 changed files with 13 additions and 0 deletions.
@@ -89,6 +89,7 @@ class RelationalGroupedDataset protected[sql](
     case expr: NamedExpression => expr
     case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
       UnresolvedAlias(a, Some(Column.generateAlias))
+    case u: UnresolvedFunction => UnresolvedAlias(expr, None)
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }
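
The effect of the added case can be illustrated with a toy model. The sketch below does not use Spark's Catalyst classes: `Expr`, `Star`, `Col`, `CountDistinct`, `eagerName`, `expandStar` and `deferredName` are all hypothetical names invented for illustration. It assumes, based on the behavior reported above, that naming an aggregate before "*" is expanded yields a name like `count(unresolvedstar())`, while deferring the naming until after expansion yields a name like the one Spark SQL produces.

```scala
// Toy model only: these types are hypothetical stand-ins, not Spark's Catalyst classes.
sealed trait Expr
case object Star extends Expr                                // stands in for an unresolved "*"
final case class Col(name: String) extends Expr              // a plain column reference
final case class CountDistinct(args: Seq[Expr]) extends Expr // stands in for count_distinct(...)

object AliasSketch {
  // Eager naming: print the expression before "*" has been expanded.
  // This mimics the reported behavior, where DISTINCT was also missing from the name.
  def eagerName(e: Expr): String = e match {
    case Star                => "unresolvedstar()"
    case Col(n)              => n
    case CountDistinct(args) => args.map(eagerName).mkString("count(", ", ", ")")
  }

  // Stand-in for analysis: replace "*" with the columns of the input schema.
  def expandStar(e: Expr, schema: Seq[String]): Expr = e match {
    case CountDistinct(args) =>
      CountDistinct(args.flatMap {
        case Star  => schema.map(n => Col(n): Expr)
        case other => Seq(other)
      })
    case other => other
  }

  // Deferred naming: expand first, then derive the name from the expanded form.
  def deferredName(e: Expr, schema: Seq[String]): String =
    expandStar(e, schema) match {
      case CountDistinct(args) => args.map(eagerName).mkString("count(DISTINCT ", ", ", ")")
      case other               => eagerName(other)
    }

  def main(args: Array[String]): Unit = {
    val schema = Seq("id", "value")
    val agg    = CountDistinct(Seq(Star))
    println(eagerName(agg))            // count(unresolvedstar())
    println(deferredName(agg, schema)) // count(DISTINCT id, value)
  }
}
```

In the real patch, the deferral is expressed by wrapping the expression in `UnresolvedAlias(expr, None)` instead of calling `toPrettySQL` on it immediately. The toy model only mirrors that ordering, not the analyzer itself.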

12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1137,6 +1137,18 @@ class DataFrameSuite extends QueryTest
     checkAnswer(approxSummaryDF, approxSummaryResult)
   }

+  test("SPARK-41391: Correct the output column name of groupBy.agg(count_distinct)") {
+    withTempView("person") {
+      person.createOrReplaceTempView("person")
+      val df1 = person.groupBy("id").agg(count_distinct(col("name")))
+      val df2 = spark.sql("SELECT id, COUNT(DISTINCT name) FROM person GROUP BY id")
+      assert(df1.columns === df2.columns)
+      val df3 = person.groupBy("id").agg(count_distinct(col("*")))
+      val df4 = spark.sql("SELECT id, COUNT(DISTINCT *) FROM person GROUP BY id")
+      assert(df3.columns === df4.columns)
+    }
+  }
+
   test("summary advanced") {
     val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
     val orderMatters = person2.summary(stats: _*)
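
The same comparison can be run outside the Spark test harness. The following is a minimal sketch, assuming the `spark-sql` artifact is on the classpath and that the Spark build includes this commit. The object name `ColumnNameCheck`, the view name `t` and the local master setting are illustrative choices, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count_distinct, lit}

// Hypothetical standalone check; mirrors the unit test's DataFrame-vs-SQL comparison.
object ColumnNameCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("SPARK-41391-check")
      .getOrCreate()
    try {
      // Same data as in the commit message: ids 1..9 with a constant "value" column.
      val df = spark.range(1, 10).withColumn("value", lit(1))
      df.createOrReplaceTempView("t")

      // Output column names from the DataFrame API and from the equivalent SQL query.
      val viaApi = df.groupBy("id").agg(count_distinct(col("*"))).columns
      val viaSql = spark.sql("SELECT id, COUNT(DISTINCT *) FROM t GROUP BY id").columns

      println(viaApi.mkString("[", ", ", "]"))
      println(viaSql.mkString("[", ", ", "]"))

      // On a build without the fix, this is expected to fail.
      require(viaApi.sameElements(viaSql), "DataFrame and SQL column names differ")
    } finally {
      spark.stop()
    }
  }
}
```

Comparing against the SQL result, as the unit test does, avoids hard-coding the expected name, so the check stays valid if the pretty-printed format changes.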