Skip to content

Commit

Permalink
[FLINK-28491][table-planner] Introduce APPROX_COUNT_DISTINCT aggregate function for batch sql
Browse files Browse the repository at this point in the history

This closes apache#20243
  • Loading branch information
godfreyhe committed Jul 22, 2022
1 parent c1e66fc commit 9db9e19
Show file tree
Hide file tree
Showing 12 changed files with 993 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,8 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
// Aggregate functions aliased from the SqlStdOperatorTable constants of the same name.
public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP;
public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP;
public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE;
public static final SqlAggFunction APPROX_COUNT_DISTINCT =
SqlStdOperatorTable.APPROX_COUNT_DISTINCT;

// ARRAY OPERATORS
public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import org.apache.flink.table.planner.functions.sql.{SqlFirstLastValueAggFunction, SqlListAggFunction}
import org.apache.flink.table.planner.functions.utils.AggSqlFunction
import org.apache.flink.table.runtime.functions.aggregate.{BuiltInAggregateFunction, CollectAggFunction, FirstValueAggFunction, FirstValueWithRetractAggFunction, JsonArrayAggFunction, JsonObjectAggFunction, LagAggFunction, LastValueAggFunction, LastValueWithRetractAggFunction, ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction, MaxWithRetractAggFunction, MinWithRetractAggFunction}
import org.apache.flink.table.runtime.functions.aggregate.BatchApproxCountDistinctAggFunctions._
import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.logical.LogicalTypeRoot._

Expand Down Expand Up @@ -80,7 +81,9 @@ class AggFunctionFactory(
case _: SqlCountAggFunction if call.getArgList.size() > 1 =>
throw new TableException("We now only support the count of one field.")

// TODO supports ApproximateCountDistinctAggFunction and CountDistinctAggFunction
// TODO supports CountDistinctAggFunction
case _: SqlCountAggFunction if call.isDistinct && call.isApproximate =>
createApproxCountDistinctAggFunction(argTypes, index)

case _: SqlCountAggFunction if call.getArgList.isEmpty => createCount1AggFunction(argTypes)

Expand Down Expand Up @@ -405,6 +408,49 @@ class AggFunctionFactory(
}
}

/**
 * Creates the APPROX_COUNT_DISTINCT aggregate function implementation that matches the type
 * root of the first aggregate argument.
 *
 * @param argTypes
 *   logical types of the aggregate call arguments; only the first element is inspected
 * @param index
 *   index of the aggregate call (not read by this method)
 * @return
 *   the type-specific approximate count distinct function
 * @throws TableException
 *   if `isBounded` is false, or if the type root of the first argument has no matching
 *   implementation
 */
private def createApproxCountDistinctAggFunction(
    argTypes: Array[LogicalType],
    index: Int): UserDefinedFunction = {
  // Guard clause: reject the call up front when the input is not bounded.
  if (!isBounded) {
    throw new TableException(
      s"APPROX_COUNT_DISTINCT aggregate function does not support yet for streaming.")
  }

  val inputType = argTypes(0)
  inputType.getTypeRoot match {
    // Types whose implementation needs no type information at construction.
    case TINYINT => new ByteApproxCountDistinctAggFunction
    case SMALLINT => new ShortApproxCountDistinctAggFunction
    case INTEGER => new IntApproxCountDistinctAggFunction
    case BIGINT => new LongApproxCountDistinctAggFunction
    case FLOAT => new FloatApproxCountDistinctAggFunction
    case DOUBLE => new DoubleApproxCountDistinctAggFunction
    case DATE => new DateApproxCountDistinctAggFunction
    case TIME_WITHOUT_TIME_ZONE => new TimeApproxCountDistinctAggFunction

    // Types whose implementation is constructed from the concrete logical type.
    case TIMESTAMP_WITHOUT_TIME_ZONE =>
      new TimestampApproxCountDistinctAggFunction(inputType.asInstanceOf[TimestampType])
    case TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
      new TimestampLtzApproxCountDistinctAggFunction(
        inputType.asInstanceOf[LocalZonedTimestampType])
    case DECIMAL =>
      new DecimalApproxCountDistinctAggFunction(inputType.asInstanceOf[DecimalType])

    case CHAR | VARCHAR => new StringApproxCountDistinctAggFunction()

    // Any other type root is reported back to the user.
    case unsupported =>
      throw new TableException(
        s"APPROX_COUNT_DISTINCT aggregate function does not support type: ''$unsupported''.\n" +
          s"Please re-check the data type.")
  }
}

/**
 * Creates a new [[Count1AggFunction]].
 *
 * @param argTypes
 *   logical types of the aggregate call arguments (not read by this method)
 * @return
 *   a fresh `Count1AggFunction` instance
 */
private def createCount1AggFunction(argTypes: Array[LogicalType]): UserDefinedFunction =
  new Count1AggFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ object RelExplainUtil {
var offset = fullGrouping.length
val aggStrings = aggCallToAggFunction.zipWithIndex.map {
case ((aggCall, udf), index) =>
val approximate = if (aggCall.isApproximate) {
"APPROXIMATE "
} else {
""
}

val distinct = if (aggCall.isDistinct) {
if (aggCall.getArgList.size() == 0) {
"DISTINCT"
Expand Down Expand Up @@ -208,10 +214,10 @@ object RelExplainUtil {
}

if (aggCall.filterArg >= 0 && aggCall.filterArg < inputFieldNames.size) {
s"${aggCall.getAggregation}($distinct$argListNames) FILTER " +
s"${aggCall.getAggregation}($approximate$distinct$argListNames) FILTER " +
s"${inputFieldNames(aggCall.filterArg)}"
} else {
s"${aggCall.getAggregation}($distinct$argListNames)"
s"${aggCall.getAggregation}($approximate$distinct$argListNames)"
}
}

Expand Down
Loading

0 comments on commit 9db9e19

Please sign in to comment.