Skip to content

Commit

Permalink
[FLINK-20738][table-planner-blink] Introduce BatchPhysicalSortAggregate, and make BatchExecSortAggregate only extend from ExecNode
Browse files Browse the repository at this point in the history

This closes apache#14562
  • Loading branch information
godfreyhe committed Jan 7, 2021
1 parent 9f8f5cd commit 29c81fe
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 105 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.nodes.exec.batch;

import org.apache.flink.api.dag.Transformation;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.batch.AggWithoutKeysCodeGenerator;
import org.apache.flink.table.planner.codegen.agg.batch.SortAggCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.generated.GeneratedOperator;
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;

import org.apache.calcite.rel.core.AggregateCall;

import java.util.Arrays;
import java.util.Collections;

/**
 * Batch {@link ExecNode} for (global) sort-based aggregate operator.
 *
 * <p>The concrete aggregate operator is code-generated: a keyless variant when no grouping
 * columns are present, otherwise a sort-aggregate over the grouping keys.
 */
public class BatchExecSortAggregate extends ExecNodeBase<RowData>
        implements BatchExecNode<RowData> {

    // Indices of the grouping columns in the input row.
    private final int[] grouping;
    // Indices of auxiliary grouping columns (functionally determined by {@code grouping}).
    private final int[] auxGrouping;
    // The aggregate calls to evaluate.
    private final AggregateCall[] aggCalls;
    // Row type of the aggregate's original input, used to resolve aggregate argument types.
    private final RowType aggInputRowType;
    // Whether this node merges partial (local) aggregate results.
    private final boolean isMerge;
    // Whether this node produces the final aggregate result.
    private final boolean isFinal;

    public BatchExecSortAggregate(
            int[] grouping,
            int[] auxGrouping,
            AggregateCall[] aggCalls,
            RowType aggInputRowType,
            boolean isMerge,
            boolean isFinal,
            ExecEdge inputEdge,
            RowType outputType,
            String description) {
        super(Collections.singletonList(inputEdge), outputType, description);
        this.grouping = grouping;
        this.auxGrouping = auxGrouping;
        this.aggCalls = aggCalls;
        this.aggInputRowType = aggInputRowType;
        this.isMerge = isMerge;
        this.isFinal = isFinal;
    }

    @SuppressWarnings("unchecked")
    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        final ExecNode<RowData> input = (ExecNode<RowData>) getInputNodes().get(0);
        final Transformation<RowData> inputTransform = input.translateToPlan(planner);

        final RowType inputRowType = (RowType) input.getOutputType();
        final RowType outputRowType = (RowType) getOutputType();

        final GeneratedOperator<OneInputStreamOperator<RowData, RowData>> generatedOperator =
                createSortAggOperator(planner, inputRowType, outputRowType);

        return new OneInputTransformation<>(
                inputTransform,
                getDesc(),
                new CodeGenOperatorFactory<>(generatedOperator),
                InternalTypeInfo.of(outputRowType),
                inputTransform.getParallelism());
    }

    /** Code-generates the sort-aggregate operator, with or without grouping keys. */
    private GeneratedOperator<OneInputStreamOperator<RowData, RowData>> createSortAggOperator(
            PlannerBase planner, RowType inputRowType, RowType outputRowType) {
        final CodeGeneratorContext ctx = new CodeGeneratorContext(planner.getTableConfig());
        final AggregateInfoList aggInfoList =
                AggregateUtil.transformToBatchAggregateInfoList(
                        aggInputRowType,
                        JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                        null,
                        null);
        if (grouping.length == 0) {
            // No grouping keys: generate the keyless aggregate operator.
            return AggWithoutKeysCodeGenerator.genWithoutKeys(
                    ctx,
                    planner.getRelBuilder(),
                    aggInfoList,
                    inputRowType,
                    outputRowType,
                    isMerge,
                    isFinal,
                    "NoGrouping");
        } else {
            return SortAggCodeGenerator.genWithKeys(
                    ctx,
                    planner.getRelBuilder(),
                    aggInfoList,
                    inputRowType,
                    outputRowType,
                    grouping,
                    auxGrouping,
                    isMerge,
                    isFinal);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.flink.table.types.logical.RowType
*/
object SortAggCodeGenerator {

private[flink] def genWithKeys(
def genWithKeys(
ctx: CodeGeneratorContext,
builder: RelBuilder,
aggInfoList: AggregateInfoList,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
}
case agg: BatchExecLocalSortAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
case agg: BatchExecSortAggregate if agg.isMerge =>
case agg: BatchPhysicalSortAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
if (aggCallIndexInLocalAgg != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@

package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, SortAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode}
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil}
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo

import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelDistribution.Type
Expand All @@ -48,18 +58,17 @@ class BatchExecLocalSortAggregate(
grouping: Array[Int],
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)])
extends BatchExecSortAggregateBase(
extends BatchPhysicalSortAggregateBase(
cluster,
traitSet,
inputRel,
outputRowType,
inputRowType,
inputRowType,
grouping,
auxGrouping,
aggCallToAggFunction,
isMerge = false,
isFinal = false) {
isFinal = false)
with LegacyBatchExecNode[RowData] {

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecLocalSortAggregate(
Expand Down Expand Up @@ -129,6 +138,42 @@ class BatchExecLocalSortAggregate(

//~ ExecNode methods -----------------------------------------------------------

/**
 * Translates this local sort-aggregate into a Flink [[Transformation]] by code-generating
 * the aggregate operator and wrapping it into a one-input transformation.
 */
override protected def translateToPlanInternal(
    planner: BatchPlanner): Transformation[RowData] = {
  val inputTransform = getInputNodes.get(0).translateToPlan(planner)
    .asInstanceOf[Transformation[RowData]]
  val codeGenCtx = CodeGeneratorContext(planner.getTableConfig)
  val outRowType = FlinkTypeFactory.toLogicalRowType(getRowType)
  val inRowType = FlinkTypeFactory.toLogicalRowType(inputRowType)

  val aggInfoList = transformToBatchAggregateInfoList(inRowType, getAggCallList)

  // Choose the code generator depending on whether grouping keys exist.
  val genOperator =
    if (grouping.isEmpty) {
      AggWithoutKeysCodeGenerator.genWithoutKeys(
        codeGenCtx,
        planner.getRelBuilder,
        aggInfoList,
        inRowType,
        outRowType,
        isMerge,
        isFinal,
        "NoGrouping")
    } else {
      SortAggCodeGenerator.genWithKeys(
        codeGenCtx,
        planner.getRelBuilder,
        aggInfoList,
        inRowType,
        outRowType,
        grouping,
        auxGrouping,
        isMerge,
        isFinal)
    }

  ExecNodeUtil.createOneInputTransformation(
    inputTransform,
    getRelDetailedDescription,
    new CodeGenOperatorFactory[RowData](genOperator),
    InternalTypeInfo.of(outRowType),
    inputTransform.getParallelism,
    0)
}

override def getInputEdges: util.List[ExecEdge] = {
if (grouping.length == 0) {
List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef}
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecSortAggregate
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase
import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil}

Expand All @@ -40,7 +42,7 @@ import scala.collection.JavaConversions._
*
* @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/
class BatchExecSortAggregate(
class BatchPhysicalSortAggregate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
Expand All @@ -51,21 +53,19 @@ class BatchExecSortAggregate(
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
isMerge: Boolean)
extends BatchExecSortAggregateBase(
extends BatchPhysicalSortAggregateBase(
cluster,
traitSet,
inputRel,
outputRowType,
inputRowType,
aggInputRowType,
grouping,
auxGrouping,
aggCallToAggFunction,
isMerge = isMerge,
isFinal = true) {

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchExecSortAggregate(
new BatchPhysicalSortAggregate(
cluster,
traitSet,
inputs.get(0),
Expand Down Expand Up @@ -153,16 +153,25 @@ class BatchExecSortAggregate(
Some(copy(newProvidedTraitSet, Seq(newInput)))
}

//~ ExecNode methods -----------------------------------------------------------
/**
 * Converts this physical RelNode into its [[ExecNode]] counterpart
 * (`BatchExecSortAggregate`). This node is the final (global) aggregate,
 * so `isFinal` is hard-coded to `true`.
 */
override def translateToExecNode(): ExecNode[_] = {
new BatchExecSortAggregate(
grouping,
auxGrouping,
getAggCallList.toArray,
FlinkTypeFactory.toLogicalRowType(aggInputRowType),
isMerge,
true, // isFinal is always true
getInputEdge,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription
)
}

/**
 * Builds the [[ExecEdge]] for this node's single input.
 *
 * <p>When there are no grouping keys the edge carries END_INPUT dam behavior;
 * otherwise the default edge is used.
 *
 * <p>NOTE(review): this span in the source was garbled diff residue that
 * interleaved the removed `getInputEdges` lines with the added ones; this is
 * the reconstructed added-side method.
 */
private def getInputEdge: ExecEdge = {
  if (grouping.length == 0) {
    ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build()
  } else {
    ExecEdge.DEFAULT
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,8 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.data.RowData
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGeneratorContext
import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, SortAggCodeGenerator}
import org.apache.flink.table.planner.delegation.BatchPlanner
import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory}
import org.apache.flink.table.planner.plan.nodes.exec.LegacyBatchExecNode
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil
import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo

import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.RelNode
Expand All @@ -42,13 +31,11 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
*
* @see [[BatchPhysicalGroupAggregateBase]] for more info.
*/
abstract class BatchExecSortAggregateBase(
abstract class BatchPhysicalSortAggregateBase(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
outputRowType: RelDataType,
inputRowType: RelDataType,
aggInputRowType: RelDataType,
grouping: Array[Int],
auxGrouping: Array[Int],
aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)],
Expand All @@ -63,8 +50,7 @@ abstract class BatchExecSortAggregateBase(
auxGrouping,
aggCallToAggFunction,
isMerge,
isFinal)
with LegacyBatchExecNode[RowData]{
isFinal) {

override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = {
val inputRows = mq.getRowCount(getInput())
Expand All @@ -79,42 +65,4 @@ abstract class BatchExecSortAggregateBase(
val costFactory = planner.getCostFactory.asInstanceOf[FlinkCostFactory]
costFactory.makeCost(rowCount, cpuCost, 0, 0, memCost)
}

//~ ExecNode methods -----------------------------------------------------------

/**
 * Translates this sort-aggregate into a Flink [[Transformation]]: code-generates
 * the aggregate operator (keyless when `grouping` is empty, keyed otherwise) and
 * wraps it into a one-input transformation with the input's parallelism.
 */
override protected def translateToPlanInternal(
planner: BatchPlanner): Transformation[RowData] = {
val input = getInputNodes.get(0).translateToPlan(planner)
.asInstanceOf[Transformation[RowData]]
val ctx = CodeGeneratorContext(planner.getTableConfig)
val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)

// Aggregate argument types are resolved against the original aggregate input
// row type, which may differ from this node's direct input row type.
val aggInfos = transformToBatchAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList)

// Choose the code generator depending on whether grouping keys exist.
val generatedOperator = if (grouping.isEmpty) {
AggWithoutKeysCodeGenerator.genWithoutKeys(
ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping")
} else {
SortAggCodeGenerator.genWithKeys(
ctx,
planner.getRelBuilder,
aggInfos,
inputType,
outputType,
grouping,
auxGrouping,
isMerge,
isFinal)
}
val operator = new CodeGenOperatorFactory[RowData](generatedOperator)
// The trailing 0 is the managed-memory weight argument — NOTE(review): confirm
// its semantics against ExecNodeUtil.createOneInputTransformation.
ExecNodeUtil.createOneInputTransformation(
input,
getRelDetailedDescription,
operator,
InternalTypeInfo.of(outputType),
input.getParallelism,
0)
}
}
Loading

0 comments on commit 29c81fe

Please sign in to comment.