Skip to content

Commit

Permalink
[SPARK-27425][SQL] Add count_if function
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Add `count_if` function which returns the number of records satisfying a given condition.

There is no aggregation function like this in Spark, so we need to write like
- `COUNT(CASE WHEN some_condition THEN 1 END)` or
- `SUM(CASE WHEN some_condition THEN 1 END)`, 
which looks painful.

This kind of function is already supported in Presto, BigQuery and even Excel.
- Presto: [`count_if`](https://prestodb.github.io/docs/current/functions/aggregate.html#count_if)
- BigQuery: [`countif`](https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate_functions?hl=en#countif)
- Excel: [`COUNTIF`](https://support.office.com/en-us/article/countif-function-e0de10c6-f885-4e71-abb4-1f464816df34?omkt=en-US&ui=en-US&rs=en-US&ad=US) (It is a little different from above twos)

## How was this patch tested?

This patch is tested by unit test.

Closes apache#24335 from cryeo/SPARK-27425.

Authored-by: Chaerim Yeo <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
  • Loading branch information
cryeo authored and HyukjinKwon committed Jun 10, 2019
1 parent 773cfde commit c1bb331
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ object FunctionRegistry {
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
expression[CountIf]("count_if"),
expression[CovPopulation]("covar_pop"),
expression[CovSample]("covar_samp"),
expression[First]("first"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, LongType}

@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the number of `TRUE` values for the expression.
""",
examples = """
Examples:
> SELECT _FUNC_(col % 2 = 0) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
2
> SELECT _FUNC_(col IS NULL) FROM VALUES (NULL), (0), (1), (2), (3) AS tab(col);
1
""",
since = "3.0.0")
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes {
override def prettyName: String = "count_if"

override def children: Seq[Expression] = Seq(predicate)

override def nullable: Boolean = false

override def dataType: DataType = LongType

override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)

override def checkInputDataTypes(): TypeCheckResult = predicate.dataType match {
case BooleanType =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@ import org.apache.spark.sql.types._
* Finds all the expressions that are unevaluable and replace/rewrite them with semantically
* equivalent expressions that can be evaluated. Currently we replace two kinds of expressions:
* 1) [[RuntimeReplaceable]] expressions
* 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any
* 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any, CountIf
* This is mainly used to provide compatibility with other databases.
* Few examples are:
* we use this to support "nvl" by replacing it with "coalesce".
* we use this to replace Every and Any with Min and Max respectively.
*
* TODO: In future, explore an option to replace aggregate functions similar to
* how RruntimeReplaceable does.
* how RuntimeReplaceable does.
*/
object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
case SomeAgg(arg) => Max(arg)
case AnyAgg(arg) => Max(arg)
case EveryAgg(arg) => Min(arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -894,4 +894,44 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
error.message.contains("function min_by does not support ordering on type map<int,string>"))
}
}

test("count_if") {
withTempView("tempView") {
Seq(("a", None), ("a", Some(1)), ("a", Some(2)), ("a", Some(3)),
("b", None), ("b", Some(4)), ("b", Some(5)), ("b", Some(6)))
.toDF("x", "y")
.createOrReplaceTempView("tempView")

checkAnswer(
sql("SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
"COUNT_IF(y IS NULL) FROM tempView"),
Row(0L, 3L, 3L, 2L))

checkAnswer(
sql("SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " +
"COUNT_IF(y IS NULL) FROM tempView GROUP BY x"),
Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil)

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 1"),
Row("a"))

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 2"),
Row("b"))

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y IS NULL) > 0"),
Row("a") :: Row("b") :: Nil)

checkAnswer(
sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"),
Nil)

val error = intercept[AnalysisException] {
sql("SELECT COUNT_IF(x) FROM tempView")
}
assert(error.message.contains("function count_if requires boolean type"))
}
}
}

0 comments on commit c1bb331

Please sign in to comment.