Commit cacfef2

[minor][table-planner-blink] Minor code cleanup for some exec nodes and physical nodes

This closes apache#14733
godfreyhe committed Jan 26, 2021
1 parent 40c3068 commit cacfef2
Showing 13 changed files with 45 additions and 65 deletions.

The cleanup follows a few recurring patterns: raw generic types are parameterized (Class<?>, Constructor<?>, and the diamond operator for OneInputTransformation), Scaladoc-style [[...]] references in Java doc comments become Javadoc {@link} tags, the groupingSet field is renamed auxGrouping, once-assigned fields are declared final, @SuppressWarnings annotations are added or removed to match the warnings actually produced, and unused imports and dead code are dropped.
@@ -41,7 +41,7 @@
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;

/** Batch [[ExecNode]] for Python unbounded group aggregate. */
/** Batch {@link ExecNode} for Python unbounded group aggregate. */
public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
implements BatchExecNode<RowData> {

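The doc-comment fix in the hunk above swaps Scaladoc's [[...]] link markup, a leftover from the planner's original Scala code, for Javadoc's {@link} tag; javadoc renders [[ExecNode]] as literal brackets rather than a cross-reference. A minimal compilable sketch of the difference (class name and comment text are illustrative, not from the commit):

/**
 * Javadoc style: {@link java.util.List} renders as a cross-reference link.
 * Scaladoc style would be [[java.util.List]], which the javadoc tool leaves
 * as plain bracketed text.
 */
public class JavadocLinkExample {}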
@@ -50,19 +50,19 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
+ "BatchArrowPythonGroupAggregateFunctionOperator";

private final int[] grouping;
private final int[] groupingSet;
private final int[] auxGrouping;
private final AggregateCall[] aggCalls;

public BatchExecPythonGroupAggregate(
int[] grouping,
int[] groupingSet,
int[] auxGrouping,
AggregateCall[] aggCalls,
ExecEdge inputEdge,
RowType outputType,
String description) {
super(Collections.singletonList(inputEdge), outputType, description);
this.grouping = grouping;
this.groupingSet = groupingSet;
this.auxGrouping = auxGrouping;
this.aggCalls = aggCalls;
}

@@ -85,7 +85,6 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
return transform;
}

@SuppressWarnings("unchecked")
private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
Transformation<RowData> inputTransform,
RowType inputRowType,
@@ -102,7 +101,7 @@ private OneInputTransformation<RowData, RowData> createPythonOneInputTransformat
outputRowType,
pythonUdafInputOffsets,
pythonFunctionInfos);
return new OneInputTransformation(
return new OneInputTransformation<>(
inputTransform,
getDescription(),
pythonOperator,
@@ -117,10 +116,10 @@ private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOpera
RowType outputRowType,
int[] udafInputOffsets,
PythonFunctionInfo[] pythonFunctionInfos) {
final Class clazz =
final Class<?> clazz =
CommonPythonUtil.loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
@@ -136,7 +135,7 @@ private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOpera
inputRowType,
outputRowType,
grouping,
groupingSet,
auxGrouping,
udafInputOffsets);
} catch (NoSuchMethodException
| IllegalAccessException
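This file's other cleanup, repeated across several files below, replaces raw Class and Constructor references in the reflective operator loading with the wildcard forms Class<?> and Constructor<?>. A hedged, self-contained sketch of the pattern (hypothetical class and method names, not Flink API): the wildcard forms compile without raw-type warnings, and Constructor.newInstance returns Object in both cases, so nothing is lost.

import java.lang.reflect.Constructor;

public final class ReflectiveInstantiation {

    // Raw types (Class clazz; Constructor ctor) would compile with raw-type
    // warnings; Class<?> and Constructor<?> say "some unknown class" precisely.
    static Object instantiate(String className, String arg) throws Exception {
        Class<?> clazz = Class.forName(className);
        Constructor<?> ctor = clazz.getConstructor(String.class);
        return ctor.newInstance(arg);
    }

    public static void main(String[] args) throws Exception {
        // StringBuilder has a (String) constructor, so this prints "hello".
        System.out.println(instantiate("java.lang.StringBuilder", "hello"));
    }
}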
@@ -51,7 +51,7 @@
import java.util.Arrays;
import java.util.Collections;

/** Batch [[ExecNode]] for group window aggregate (Python user defined aggregate function). */
/** Batch {@link ExecNode} for group window aggregate (Python user defined aggregate function). */
public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
implements BatchExecNode<RowData> {

@@ -60,15 +60,15 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
+ "BatchArrowPythonGroupWindowAggregateFunctionOperator";

private final int[] grouping;
private final int[] groupingSet;
private final int[] auxGrouping;
private final AggregateCall[] aggCalls;
private final LogicalWindow window;
private final int inputTimeFieldIndex;
private final FlinkRelBuilder.PlannerNamedWindowProperty[] namedWindowProperties;

public BatchExecPythonGroupWindowAggregate(
int[] grouping,
int[] groupingSet,
int[] auxGrouping,
AggregateCall[] aggCalls,
LogicalWindow window,
int inputTimeFieldIndex,
@@ -78,13 +78,14 @@ public BatchExecPythonGroupWindowAggregate(
String description) {
super(Collections.singletonList(inputEdge), outputType, description);
this.grouping = grouping;
this.groupingSet = groupingSet;
this.auxGrouping = auxGrouping;
this.aggCalls = aggCalls;
this.window = window;
this.inputTimeFieldIndex = inputTimeFieldIndex;
this.namedWindowProperties = namedWindowProperties;
}

@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0);
@@ -203,7 +204,7 @@ private OneInputStreamOperator<RowData, RowData> getPythonGroupWindowAggregateFu
slideSize,
namePropertyTypeArray,
grouping,
groupingSet,
auxGrouping,
udafInputOffsets);
} catch (NoSuchMethodException
| InstantiationException
@@ -46,18 +46,19 @@
import java.util.List;

/**
* Batch [[ExecNode]] for sort-based over window aggregate (Python user defined aggregate function).
* Batch {@link ExecNode} for sort-based over window aggregate (Python user defined aggregate
* function).
*/
public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {

private static final String ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch."
+ "BatchArrowPythonOverWindowAggregateFunctionOperator";

private List<Long> lowerBoundary;
private List<Long> upperBoundary;
private List<AggregateCall> aggCalls;
private List<Integer> aggWindowIndex;
private final List<Long> lowerBoundary;
private final List<Long> upperBoundary;
private final List<AggregateCall> aggCalls;
private final List<Integer> aggWindowIndex;

public BatchExecPythonOverAggregate(
OverSpec over, ExecEdge inputEdge, RowType outputType, String description) {
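The field declarations above change from mutable to final, a standard immutability cleanup: a field assigned exactly once in the constructor can be declared final, documenting the invariant and letting the compiler enforce it. A simplified sketch of the pattern (not the actual node class; field names borrowed from the diff):

import java.util.List;

public class OverBoundaries {

    // final: each field is assigned exactly once, in the constructor,
    // and cannot be reassigned afterwards.
    private final List<Long> lowerBoundary;
    private final List<Long> upperBoundary;

    public OverBoundaries(List<Long> lowerBoundary, List<Long> upperBoundary) {
        this.lowerBoundary = lowerBoundary;
        this.upperBoundary = upperBoundary;
    }

    public List<Long> lowerBoundary() {
        return lowerBoundary;
    }
}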
@@ -179,11 +180,11 @@ private OneInputStreamOperator<RowData, RowData> getPythonOverWindowAggregateFun
boolean[] isRangeWindows,
int[] udafInputOffsets,
PythonFunctionInfo[] pythonFunctionInfos) {
Class clazz =
Class<?> clazz =
CommonPythonUtil.loadClass(
ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
@@ -188,15 +188,15 @@ private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator
PythonFunctionInfo[] pythonFunctionInfos,
int[] forwardedFields,
boolean isArrow) {
Class clazz;
Class<?> clazz;
if (isArrow) {
clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
} else {
clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
}

try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
@@ -125,8 +125,8 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
inputCountIndex,
countStarInserted);
// partitioned aggregation
OneInputTransformation transform =
new OneInputTransformation(
OneInputTransformation<RowData, RowData> transform =
new OneInputTransformation<>(
inputTransform,
getDescription(),
operator,
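The hunk above shows the diamond-operator half of the raw-type cleanup: new OneInputTransformation(...) creates a raw instance and triggers an unchecked-assignment warning, while new OneInputTransformation<>(...) lets the compiler infer the type arguments from the declaration. A generic sketch using JDK types rather than the Flink class:

import java.util.ArrayList;
import java.util.List;

public class DiamondExample {
    public static void main(String[] args) {
        // Raw construction compiles but warns:  List<String> xs = new ArrayList();
        // The diamond infers <String> from the declared type, warning-free:
        List<String> xs = new ArrayList<>();
        xs.add("no unchecked warning here");
        System.out.println(xs);
    }
}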
@@ -163,7 +163,7 @@ private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOpera
boolean countStarInserted) {
Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
RowType.class,
@@ -51,7 +51,7 @@
import java.util.Arrays;
import java.util.Collections;

/** Stream [[ExecNode]] for unbounded python group table aggregate. */
/** Stream {@link ExecNode} for unbounded python group table aggregate. */
public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
implements StreamExecNode<RowData> {
private static final Logger LOG =
@@ -162,9 +162,9 @@ private OneInputStreamOperator<RowData, RowData> getPythonTableAggregateFunction
long maxIdleStateRetentionTime,
boolean generateUpdateBefore,
int indexOfCountStar) {
Class clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME);
Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
RowType.class,
@@ -77,7 +77,7 @@
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration;
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong;

/** Stream [[ExecNode]] for group window aggregate (Python user defined aggregate function). */
/** Stream {@link ExecNode} for group window aggregate (Python user defined aggregate function). */
public class StreamExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
implements StreamExecNode<RowData> {
private static final Logger LOGGER =
@@ -110,6 +110,7 @@ public StreamExecPythonGroupWindowAggregate(
this.emitStrategy = emitStrategy;
}

@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final boolean isCountWindow;
@@ -289,7 +290,7 @@ private Tuple2<WindowAssigner<?>, Trigger<?>> generateWindowAssignerAndTrigger()
inputTransform.getParallelism());
}

@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "rawtypes"})
private OneInputStreamOperator<RowData, RowData>
getPythonStreamGroupWindowAggregateFunctionOperator(
Configuration config,
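Rather than deleting suppressions outright, this file narrows their scope: translateToPlanInternal gains @SuppressWarnings("unchecked") for its cast of the input node, and the reflective helper below it widens to {"unchecked", "rawtypes"}, presumably because it both handles raw generic types and casts the reflective result. A sketch of keeping suppression at the smallest scope that compiles cleanly (hypothetical method, not from the commit):

public final class SuppressionScope {

    // Suppress only on the method that performs the unchecked cast, so new
    // warnings elsewhere in the class remain visible to the compiler.
    @SuppressWarnings("unchecked")
    static <T> T uncheckedCast(Object value) {
        return (T) value; // the single unchecked operation being suppressed
    }

    public static void main(String[] args) {
        String s = uncheckedCast("scoped suppression");
        System.out.println(s);
    }
}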
@@ -52,7 +52,7 @@
import java.math.BigDecimal;
import java.util.Collections;

/** Stream [[ExecNode]] for python time-based over operator. */
/** Stream {@link ExecNode} for python time-based over operator. */
public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
implements StreamExecNode<RowData> {
private static final Logger LOG = LoggerFactory.getLogger(StreamExecPythonOverAggregate.class);
@@ -82,6 +82,7 @@ public StreamExecPythonOverAggregate(
this.overSpec = overSpec;
}

@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
if (overSpec.getGroups().size() > 1) {
@@ -230,9 +231,9 @@ private OneInputStreamOperator<RowData, RowData> getPythonOverWindowAggregateFun
className =
ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
}
Class clazz = CommonPythonUtil.loadClass(className);
Class<?> clazz = CommonPythonUtil.loadClass(className);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
long.class,
Expand Down Expand Up @@ -272,9 +273,9 @@ private OneInputStreamOperator<RowData, RowData> getPythonOverWindowAggregateFun
className =
ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
}
Class clazz = CommonPythonUtil.loadClass(className);
Class<?> clazz = CommonPythonUtil.loadClass(className);
try {
Constructor ctor =
Constructor<?> ctor =
clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
@@ -20,21 +20,17 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory}
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecNestedLoopJoin
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil
import org.apache.flink.table.runtime.typeutils.{BinaryRowDataSerializer}
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer

import org.apache.calcite.plan._
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rex.RexNode

import java.util

import scala.collection.JavaConversions._

/**
* Batch physical RelNode for nested-loop [[Join]].
*/
@@ -115,8 +111,6 @@ class BatchPhysicalNestedLoopJoin(
satisfyTraitsOnBroadcastJoin(requiredTraitSet, leftIsBuild)
}

//~ ExecNode methods -----------------------------------------------------------

override def translateToExecNode(): ExecNode[_] = {
val (leftEdge, rightEdge) = getInputEdges
new BatchExecNestedLoopJoin(
@@ -20,7 +20,6 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonCalc
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
@@ -21,8 +21,8 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistributionTraitDef
import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory}
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecSortMergeJoin
import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, FlinkRelOptUtil, JoinTypeUtil, JoinUtil}
import org.apache.flink.table.runtime.operators.join.FlinkJoinType

@@ -32,8 +32,6 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelCollationTraitDef, RelNode, RelWriter}
import org.apache.calcite.rex.RexNode

import java.util

import scala.collection.JavaConversions._

/**
@@ -168,19 +166,7 @@ class BatchPhysicalSortMergeJoin(
Some(copy(newProvidedTraitSet, Seq(newLeft, newRight)))
}

//~ ExecNode methods -----------------------------------------------------------

// this method must be in sync with the behavior of SortMergeJoinOperator.
def getInputEdges: util.List[ExecEdge] = List(
ExecEdge.builder()
.damBehavior(ExecEdge.DamBehavior.END_INPUT)
.build(),
ExecEdge.builder()
.damBehavior(ExecEdge.DamBehavior.END_INPUT)
.build())

override def translateToExecNode(): ExecNode[_] = {

JoinUtil.validateJoinSpec(
joinSpec,
FlinkTypeFactory.toLogicalRowType(left.getRowType),
@@ -39,10 +39,10 @@ import scala.collection.JavaConverters._
abstract class CommonPhysicalExchange(
cluster: RelOptCluster,
traitSet: RelTraitSet,
relNode: RelNode,
inputRel: RelNode,
relDistribution: RelDistribution)
extends Exchange(cluster, traitSet, relNode, relDistribution)
with FlinkPhysicalRel {
extends Exchange(cluster, traitSet, inputRel, relDistribution)
with FlinkPhysicalRel {

override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = {
val inputRows = mq.getRowCount(input)
@@ -69,8 +69,6 @@ class StreamPhysicalSort(
.item("orderBy", RelExplainUtil.collationToString(sortCollation, getRowType))
}

//~ ExecNode methods -----------------------------------------------------------

override def translateToExecNode(): ExecNode[_] = {
new StreamExecSort(
SortUtil.getSortSpec(sortCollation.getFieldCollations),