Skip to content

Commit

Permalink
[FLINK-30542][table-planner] Introduce adaptive local hash aggregate …
Browse files Browse the repository at this point in the history
…to adaptively determine whether local hash aggregate is required at runtime

This closes apache#21586
  • Loading branch information
swuferhong authored and godfreyhe committed Jan 31, 2023
1 parent cf358d7 commit 122ba8f
Show file tree
Hide file tree
Showing 15 changed files with 733 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class BatchExecHashAggregate extends ExecNodeBase<RowData>
private final RowType aggInputRowType;
private final boolean isMerge;
private final boolean isFinal;
private final boolean supportAdaptiveLocalHashAgg;

public BatchExecHashAggregate(
ReadableConfig tableConfig,
Expand All @@ -67,6 +68,7 @@ public BatchExecHashAggregate(
RowType aggInputRowType,
boolean isMerge,
boolean isFinal,
boolean supportAdaptiveLocalHashAgg,
InputProperty inputProperty,
RowType outputType,
String description) {
Expand All @@ -83,6 +85,7 @@ public BatchExecHashAggregate(
this.aggInputRowType = aggInputRowType;
this.isMerge = isMerge;
this.isFinal = isFinal;
this.supportAdaptiveLocalHashAgg = supportAdaptiveLocalHashAgg;
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -126,17 +129,17 @@ protected Transformation<RowData> translateToPlanInternal(
config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)
.getBytes();
generatedOperator =
new HashAggCodeGenerator(
ctx,
planner.createRelBuilder(),
aggInfos,
inputRowType,
outputRowType,
grouping,
auxGrouping,
isMerge,
isFinal)
.genWithKeys();
HashAggCodeGenerator.genWithKeys(
ctx,
planner.createRelBuilder(),
aggInfos,
inputRowType,
outputRowType,
grouping,
auxGrouping,
isMerge,
isFinal,
supportAdaptiveLocalHashAgg);
}

return ExecNodeUtil.createOneInputTransformation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ package org.apache.flink.table.planner.codegen

import org.apache.flink.table.data.RowData
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.writer.BinaryRowWriter
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.GenerateUtils.generateRecordStatement
import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens
import org.apache.flink.table.planner.functions.aggfunctions._
import org.apache.flink.table.planner.plan.utils.AggregateInfo
import org.apache.flink.table.runtime.generated.{GeneratedProjection, Projection}
import org.apache.flink.table.types.logical.RowType
import org.apache.flink.table.types.logical.{BigIntType, LogicalType, RowType}
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes

import scala.collection.mutable.ArrayBuffer

/**
* CodeGenerator for projection: takes out some fields of a [[RowData]] to generate a new [[RowData]].
Expand Down Expand Up @@ -124,6 +131,163 @@ object ProjectionCodeGenerator {
new GeneratedProjection(className, code, ctx.references.toArray, ctx.tableConfig)
}

/**
 * Generates the code that projects the aggregate-buffer values of a single input row when
 * adaptive local hash aggregation takes effect, i.e. when local hash aggregation is suppressed.
 *
 * <p>In order to ensure that the data structure transmitted downstream without doing local hash
 * aggregation is consistent with the data format transmitted downstream with local hash
 * aggregation, every aggregate call is replaced by a projection of the input row:
 * <ul>
 *   <li>sum(col): the input field cast to the sum result type</li>
 *   <li>max(col) / min(col): the input field itself</li>
 *   <li>avg(col): two fields, the input field cast to the sum type and a count field</li>
 *   <li>count(col): a count field (1L / 0L)</li>
 *   <li>count(*) / count(1): the constant 1L</li>
 * </ul>
 *
 * <p>For example, for sql statement "select a, avg(b), count(c) from T group by a", if local hash
 * aggregation is suppressed and a row (1, 5, "a") comes to local hash aggregation, we will pass
 * (1, 5, 1, 1) to downstream. This method produces the value fields (5, 1, 1); the grouping key
 * is presumably projected by the caller.
 *
 * <p>As a side effect, two reusable members are registered on `ctx`: the output record
 * (`outRecordTerm`) and a [[BinaryRowWriter]] over it (`outRecordWriterTerm`).
 *
 * @param ctx
 *   code generator context that collects reusable members and local variables
 * @param inputType
 *   type of the input row the generated code reads from
 * @param outClass
 *   class of the output record member
 * @param inputTerm
 *   name of the input row variable in the generated code
 * @param aggInfos
 *   aggregate calls to project, in output order
 * @param outRecordTerm
 *   name of the output record member in the generated code
 * @param outRecordWriterTerm
 *   name of the row writer member in the generated code
 * @return
 *   the generated code that resets the writer, writes every projected field and completes the
 *   writer
 * @throws CodeGenException
 *   if an aggregate function is not one of the supported kinds listed above
 */
def genAdaptiveLocalHashAggValueProjectionCode(
    ctx: CodeGeneratorContext,
    inputType: RowType,
    outClass: Class[_ <: RowData] = classOf[BinaryRowData],
    inputTerm: String = DEFAULT_INPUT1_TERM,
    aggInfos: Array[AggregateInfo],
    outRecordTerm: String = DEFAULT_OUT_RECORD_TERM,
    outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM): String = {
  // One or two projected expressions per aggregate call, kept in the order of aggInfos.
  val fieldExprs: Seq[GeneratedExpression] = aggInfos.toSeq.flatMap {
    aggInfo =>
      // Deliberately a def: it is only evaluated by the branches that need an argument, so
      // count(*) / count(1) never touches the argument list.
      def argIndex: Int = aggInfo.agg.getArgList.get(0)

      aggInfo.function match {
        case sumAggFunction: SumAggFunction =>
          Seq(
            genValueProjectionForSumAggFunc(
              ctx,
              inputType,
              inputTerm,
              sumAggFunction.getResultType.getLogicalType,
              argIndex))
        case _: MaxAggFunction | _: MinAggFunction =>
          Seq(GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, argIndex))
        case avgAggFunction: AvgAggFunction =>
          // avg keeps two buffer fields: the running sum followed by the running count.
          Seq(
            genValueProjectionForSumAggFunc(
              ctx,
              inputType,
              inputTerm,
              avgAggFunction.getSumType.getLogicalType,
              argIndex),
            genValueProjectionForCountAggFunc(ctx, inputTerm, argIndex))
        case _: CountAggFunction =>
          Seq(genValueProjectionForCountAggFunc(ctx, inputTerm, argIndex))
        case _: Count1AggFunction =>
          Seq(genValueProjectionForCount1AggFunc(ctx))
        case unsupported =>
          // Fail with a descriptive error instead of an anonymous scala.MatchError.
          throw new CodeGenException(
            s"Unsupported aggregate function for adaptive local hash aggregation: $unsupported")
      }
  }

  val binaryRowWriter = CodeGenUtils.className[BinaryRowWriter]
  val typeTerm = outClass.getCanonicalName
  ctx.addReusableMember(s"private $typeTerm $outRecordTerm= new $typeTerm(${fieldExprs.size});")
  ctx.addReusableMember(
    s"private $binaryRowWriter $outRecordWriterTerm = new $binaryRowWriter($outRecordTerm);")

  // The i-th projected expression is written to the i-th position of the output row.
  // NOTE(review): the writer-based BinaryRowData path is always requested from rowSetField,
  // regardless of outClass; presumably only BinaryRowData is supported here - confirm.
  val setFieldsCode = fieldExprs.zipWithIndex
    .map {
      case (fieldExpr, pos) =>
        rowSetField(
          ctx,
          classOf[BinaryRowData],
          outRecordTerm,
          pos.toString,
          fieldExpr,
          Option(outRecordWriterTerm))
    }
    .mkString("\n")

  s"""
     |$outRecordWriterTerm.reset();
     |$setFieldsCode
     |$outRecordWriterTerm.complete();
     |""".stripMargin
}

/**
 * Do projection for grouping function 'sum(col)' if adaptive local hash aggregation takes effect.
 * For 'sum(col)', the input field is read when col is not null and replaced by the primitive
 * default value of its type when col is null; the resulting expression is then cast to the sum
 * agg function target type.
 *
 * @param ctx
 *   code generator context used to register the reusable local variables
 * @param inputType
 *   type of the input row
 * @param inputTerm
 *   name of the input row variable in the generated code
 * @param targetType
 *   type the projected value is cast to
 * @param index
 *   position of the summed column in the input row
 * @return
 *   the expression produced by casting the projected field to `targetType`
 */
def genValueProjectionForSumAggFunc(
    ctx: CodeGeneratorContext,
    inputType: LogicalType,
    inputTerm: String,
    targetType: LogicalType,
    index: Int): GeneratedExpression = {
  val columnType = getFieldTypes(inputType).get(index)
  val columnTypeTerm = primitiveTypeTermForType(columnType)
  val fallbackValue = primitiveDefaultValue(columnType)
  val columnAccess = rowFieldReadAccess(index.toString, inputTerm, columnType)

  val Seq(valueTerm, isNullTerm) =
    ctx.addReusableLocalVariables((columnTypeTerm, "field"), ("boolean", "isNull"))

  // Null-safe read of the column: the null flag is always set, the value falls back to the
  // primitive default when the column is null.
  val readColumnCode =
    s"""
       |$isNullTerm = $inputTerm.isNullAt($index);
       |$valueTerm = $fallbackValue;
       |if (!$isNullTerm) {
       | $valueTerm = $columnAccess;
       |}
         """.stripMargin.trim

  // Convert the projected value type to sum agg func target type.
  ScalarOperatorGens.generateCast(
    ctx,
    GeneratedExpression(valueTerm, isNullTerm, readColumnCode, columnType),
    targetType,
    true)
}

/**
 * Do projection for grouping function 'count(col)' if adaptive local hash aggregation takes
 * effect. 'count(col)' will be converted to 1L if col is not null and to 0L if col is null.
 *
 * <p>The projected count itself is never null, so the generated code explicitly sets the null
 * flag of the returned expression to false. Consumers of the expression read that flag, and it
 * must not be left as a declared-but-unassigned variable.
 *
 * @param ctx
 *   code generator context used to register the reusable local variables
 * @param inputTerm
 *   name of the input row variable in the generated code
 * @param index
 *   position of the counted column in the input row
 * @return
 *   a BIGINT expression holding 1L or 0L
 */
def genValueProjectionForCountAggFunc(
    ctx: CodeGeneratorContext,
    inputTerm: String,
    index: Int): GeneratedExpression = {
  val Seq(fieldTerm, nullTerm) =
    ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))

  val inputCode =
    s"""
       |$nullTerm = false;
       |$fieldTerm = 0L;
       |if (!$inputTerm.isNullAt($index)) {
       | $fieldTerm = 1L;
       |}
         """.stripMargin.trim

  GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
}

/**
 * Do projection for grouping function 'count(*)' or 'count(1)' if adaptive local hash agg takes
 * effect. 'count(*)' or 'count(1)' will be converted to 1L and transmitted to downstream.
 *
 * <p>The projected value is a constant and never null, so the generated code explicitly sets the
 * null flag of the returned expression to false instead of leaving it unassigned.
 *
 * @param ctx
 *   code generator context used to register the reusable local variables
 * @return
 *   a BIGINT expression holding the constant 1L
 */
def genValueProjectionForCount1AggFunc(ctx: CodeGeneratorContext): GeneratedExpression = {
  val Seq(fieldTerm, nullTerm) =
    ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
  val inputCode =
    s"""
       |$nullTerm = false;
       |$fieldTerm = 1L;
       |""".stripMargin.trim
  GeneratedExpression(fieldTerm, nullTerm, inputCode, new BigIntType())
}

/** For java invoke. */
def generateProjection(
ctx: CodeGeneratorContext,
Expand Down
Loading

0 comments on commit 122ba8f

Please sign in to comment.