Skip to content

Commit

Permalink
[FLINK-18445][table] Add pre-filtering optimization for lookup join
Browse files Browse the repository at this point in the history
  • Loading branch information
lincoln-lil committed Aug 31, 2023
1 parent ee110aa commit 360b97a
Show file tree
Hide file tree
Showing 24 changed files with 1,917 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public class BatchExecLookupJoin extends CommonExecLookupJoin
public BatchExecLookupJoin(
ReadableConfig tableConfig,
FlinkJoinType joinType,
@Nullable RexNode joinCondition,
@Nullable RexNode preFilterCondition,
@Nullable RexNode remainingJoinCondition,
TemporalTableSourceSpec temporalTableSourceSpec,
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
@Nullable List<RexNode> projectionOnTemporalTable,
Expand All @@ -64,7 +65,8 @@ public BatchExecLookupJoin(
ExecNodeContext.newContext(BatchExecLookupJoin.class),
ExecNodeContext.newPersistedConfig(BatchExecLookupJoin.class, tableConfig),
joinType,
joinCondition,
preFilterCondition,
remainingJoinCondition,
temporalTableSourceSpec,
lookupKeys,
projectionOnTemporalTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.apache.flink.table.runtime.collector.ListenableCollector;
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFilterCondition;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
Expand Down Expand Up @@ -144,7 +145,8 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
public static final String LOOKUP_JOIN_MATERIALIZE_TRANSFORMATION = "lookup-join-materialize";

public static final String FIELD_NAME_JOIN_TYPE = "joinType";
public static final String FIELD_NAME_JOIN_CONDITION = "joinCondition";
public static final String FIELD_NAME_PRE_FILTER_CONDITION = "preFilterCondition";
public static final String FIELD_NAME_REMAINING_JOIN_CONDITION = "joinCondition";
public static final String FIELD_NAME_TEMPORAL_TABLE = "temporalTable";
public static final String FIELD_NAME_LOOKUP_KEYS = "lookupKeys";
public static final String FIELD_NAME_PROJECTION_ON_TEMPORAL_TABLE =
Expand Down Expand Up @@ -175,9 +177,14 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
@JsonProperty(FIELD_NAME_FILTER_ON_TEMPORAL_TABLE)
private final @Nullable RexNode filterOnTemporalTable;

/** join condition except equi-conditions extracted as lookup keys. */
@JsonProperty(FIELD_NAME_JOIN_CONDITION)
private final @Nullable RexNode joinCondition;
/** pre-filter condition on left input except lookup keys. */
@JsonProperty(FIELD_NAME_PRE_FILTER_CONDITION)
@JsonInclude(JsonInclude.Include.NON_NULL)
private final @Nullable RexNode preFilterCondition;

/** remaining join condition except pre-filter & equi-conditions except lookup keys. */
@JsonProperty(FIELD_NAME_REMAINING_JOIN_CONDITION)
private final @Nullable RexNode remainingJoinCondition;

@JsonProperty(FIELD_NAME_INPUT_CHANGELOG_MODE)
private final ChangelogMode inputChangelogMode;
Expand All @@ -195,7 +202,8 @@ protected CommonExecLookupJoin(
ExecNodeContext context,
ReadableConfig persistedConfig,
FlinkJoinType joinType,
@Nullable RexNode joinCondition,
@Nullable RexNode preFilterCondition,
@Nullable RexNode remainingJoinCondition,
// TODO: refactor this into TableSourceTable, once legacy TableSource is removed
TemporalTableSourceSpec temporalTableSourceSpec,
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
Expand All @@ -210,7 +218,8 @@ protected CommonExecLookupJoin(
super(id, context, persistedConfig, inputProperties, outputType, description);
checkArgument(inputProperties.size() == 1);
this.joinType = checkNotNull(joinType);
this.joinCondition = joinCondition;
this.preFilterCondition = preFilterCondition;
this.remainingJoinCondition = remainingJoinCondition;
this.lookupKeys = Collections.unmodifiableMap(checkNotNull(lookupKeys));
this.temporalTableSourceSpec = checkNotNull(temporalTableSourceSpec);
this.projectionOnTemporalTable = projectionOnTemporalTable;
Expand Down Expand Up @@ -410,7 +419,11 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
"TableFunctionResultFuture",
inputRowType,
rightRowType,
JavaScalaConversionUtil.toScala(Optional.ofNullable(joinCondition)));
JavaScalaConversionUtil.toScala(
Optional.ofNullable(remainingJoinCondition)));
GeneratedFilterCondition generatedPreFilterCondition =
LookupJoinCodeGenerator.generatePreFilterCondition(
config, classLoader, preFilterCondition, inputRowType);

DataStructureConverter<?, ?> fetcherConverter =
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
Expand All @@ -431,6 +444,7 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
(DataStructureConverter<RowData, Object>) fetcherConverter,
generatedCalc,
generatedResultFuture,
generatedPreFilterCondition,
InternalSerializers.create(rightRowType),
isLeftOuterJoin,
asyncLookupOptions.asyncBufferCapacity);
Expand All @@ -441,6 +455,7 @@ private StreamOperatorFactory<RowData> createAsyncLookupJoin(
generatedFuncWithType.tableFunc(),
(DataStructureConverter<RowData, Object>) fetcherConverter,
generatedResultFuture,
generatedPreFilterCondition,
InternalSerializers.create(rightRowType),
isLeftOuterJoin,
asyncLookupOptions.asyncBufferCapacity);
Expand Down Expand Up @@ -540,9 +555,14 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
inputRowType,
rightRowType,
resultRowType,
JavaScalaConversionUtil.toScala(Optional.ofNullable(joinCondition)),
JavaScalaConversionUtil.toScala(
Optional.ofNullable(remainingJoinCondition)),
JavaScalaConversionUtil.toScala(Optional.empty()),
true);

GeneratedFilterCondition generatedPreFilterCondition =
LookupJoinCodeGenerator.generatePreFilterCondition(
config, classLoader, preFilterCondition, inputRowType);
ProcessFunction<RowData, RowData> processFunc;
if (projectionOnTemporalTable != null) {
// a projection or filter after table source scan
Expand All @@ -560,6 +580,7 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
generatedFetcher,
generatedCalc,
generatedCollector,
generatedPreFilterCondition,
isLeftOuterJoin,
rightRowType.getFieldCount());
} else {
Expand All @@ -568,6 +589,7 @@ protected ProcessFunction<RowData, RowData> createSyncLookupJoinFunction(
new LookupJoinRunner(
generatedFetcher,
generatedCollector,
generatedPreFilterCondition,
isLeftOuterJoin,
rightRowType.getFieldCount());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public class StreamExecLookupJoin extends CommonExecLookupJoin
public StreamExecLookupJoin(
ReadableConfig tableConfig,
FlinkJoinType joinType,
@Nullable RexNode joinCondition,
@Nullable RexNode preFilterCondition,
@Nullable RexNode remainingJoinCondition,
TemporalTableSourceSpec temporalTableSourceSpec,
Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
@Nullable List<RexNode> projectionOnTemporalTable,
Expand All @@ -118,7 +119,8 @@ public StreamExecLookupJoin(
ExecNodeContext.newContext(StreamExecLookupJoin.class),
ExecNodeContext.newPersistedConfig(StreamExecLookupJoin.class, tableConfig),
joinType,
joinCondition,
preFilterCondition,
remainingJoinCondition,
temporalTableSourceSpec,
lookupKeys,
projectionOnTemporalTable,
Expand All @@ -143,7 +145,9 @@ public StreamExecLookupJoin(
@JsonProperty(FIELD_NAME_TYPE) ExecNodeContext context,
@JsonProperty(FIELD_NAME_CONFIGURATION) ReadableConfig persistedConfig,
@JsonProperty(FIELD_NAME_JOIN_TYPE) FlinkJoinType joinType,
@JsonProperty(FIELD_NAME_JOIN_CONDITION) @Nullable RexNode joinCondition,
@JsonProperty(FIELD_NAME_PRE_FILTER_CONDITION) @Nullable RexNode preFilterCondition,
@JsonProperty(FIELD_NAME_REMAINING_JOIN_CONDITION) @Nullable
RexNode remainingJoinCondition,
@JsonProperty(FIELD_NAME_TEMPORAL_TABLE)
TemporalTableSourceSpec temporalTableSourceSpec,
@JsonProperty(FIELD_NAME_LOOKUP_KEYS) Map<Integer, LookupJoinUtil.LookupKey> lookupKeys,
Expand All @@ -169,7 +173,8 @@ public StreamExecLookupJoin(
context,
persistedConfig,
joinType,
joinCondition,
preFilterCondition,
remainingJoinCondition,
temporalTableSourceSpec,
lookupKeys,
projectionOnTemporalTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,10 @@ public StreamPhysicalRel visit(

// required determinism cannot be satisfied even upsert materialize was enabled if:
// 1. remaining join condition contains non-deterministic call
JavaScalaConversionUtil.toJava(lookupJoin.remainingCondition())
.ifPresent(condi -> checkNonDeterministicCondition(condi, lookupJoin));
JavaScalaConversionUtil.toJava(lookupJoin.finalPreFilterCondition())
.ifPresent(cond -> checkNonDeterministicCondition(cond, lookupJoin));
JavaScalaConversionUtil.toJava(lookupJoin.finalRemainingCondition())
.ifPresent(cond -> checkNonDeterministicCondition(cond, lookupJoin));

// 2. inner calc in lookJoin contains either non-deterministic condition or calls
JavaScalaConversionUtil.toJava(lookupJoin.calcOnTemporalTable())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
val rightUniqueKeys = FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOfTemporalTable(join)

val remainingConditionNonDeterministic =
join.remainingCondition.exists(c => !RexUtil.isDeterministic(c))
join.finalPreFilterCondition.exists(c => !RexUtil.isDeterministic(c)) ||
join.finalRemainingCondition.exists(c => !RexUtil.isDeterministic(c))
lazy val calcOnTemporalTableNonDeterministic =
join.calcOnTemporalTable.exists(p => !FlinkRexUtil.isDeterministic(p))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class BatchPhysicalLookupJoin(
new BatchExecLookupJoin(
tableConfig,
JoinTypeUtil.getFlinkJoinType(joinType),
remainingCondition.orNull,
finalPreFilterCondition.orNull,
finalRemainingCondition.orNull,
new TemporalTableSourceSpec(temporalTable),
allLookupKeys.map(item => (Int.box(item._1), item._2)).asJava,
projectionOnTemporalTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel
import org.apache.flink.table.planner.plan.schema.{IntermediateRelTable, LegacyTableSourceTable, TableSourceTable}
import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, ExpressionFormat, JoinTypeUtil, LookupJoinUtil, RelExplainUtil}
import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, ExpressionFormat, InputRefVisitor, JoinTypeUtil, LookupJoinUtil, RelExplainUtil}
import org.apache.flink.table.planner.plan.utils.ExpressionFormat.ExpressionFormat
import org.apache.flink.table.planner.plan.utils.LookupJoinUtil._
import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
Expand Down Expand Up @@ -98,11 +98,11 @@ abstract class CommonPhysicalLookupJoin(
// all potential index keys, mapping from field index in table source to LookupKey
analyzeLookupKeys(cluster.getRexBuilder, joinKeyPairs, calcOnTemporalTable)
}
// remaining condition used to filter the joined records (left input record X lookup-ed records)
val remainingCondition: Option[RexNode] = getRemainingJoinCondition(
// split join condition(except the lookup keys) into pre-filter(used to filter the left input
// before lookup) and remaining parts(used to filter the joined records)
val (finalPreFilterCondition, finalRemainingCondition) = splitJoinCondition(
cluster.getRexBuilder,
inputRel.getRowType,
calcOnTemporalTable,
allLookupKeys.values.toList,
joinInfo)

Expand Down Expand Up @@ -195,12 +195,9 @@ abstract class CommonPhysicalLookupJoin(
.itemIf("where", whereString, whereString.nonEmpty)
.itemIf(
"joinCondition",
joinConditionToString(
resultFieldNames,
remainingCondition,
preferExpressionFormat(pw),
pw.getDetailLevel),
remainingCondition.isDefined)
joinConditionToString(resultFieldNames, preferExpressionFormat(pw), pw.getDetailLevel),
finalRemainingCondition.isDefined || finalPreFilterCondition.isDefined
)
.item("select", selection)
.itemIf("upsertMaterialize", "true", upsertMaterialize)
.itemIf("async", asyncOptions.getOrElse(""), asyncOptions.isDefined)
Expand All @@ -217,13 +214,15 @@ abstract class CommonPhysicalLookupJoin(
case _ => ChangelogMode.insertOnly()
}

/** Gets the remaining join condition which is used */
private def getRemainingJoinCondition(
/**
* Splits the remaining condition in joinInfo into the pre-filter(used to filter the left input
* before lookup) and remaining parts(used to filter the joined records).
*/
private def splitJoinCondition(
rexBuilder: RexBuilder,
leftRelDataType: RelDataType,
calcOnTemporalTable: Option[RexProgram],
leftKeys: List[LookupKey],
joinInfo: JoinInfo): Option[RexNode] = {
joinInfo: JoinInfo): (Option[RexNode], Option[RexNode]) = {
// indexes of left key fields
val leftKeyIndexes =
leftKeys
Expand All @@ -244,9 +243,31 @@ abstract class CommonPhysicalLookupJoin(
val rightInputRef = new RexInputRef(rightIndex, rightFieldType)
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftInputRef, rightInputRef)
}
val remainingAnds = remainingEquals ++ joinInfo.nonEquiConditions.asScala
// build a new condition
val condition = RexUtil.composeConjunction(rexBuilder, remainingAnds.toList.asJava)
if (joinType.generatesNullsOnRight) {
// only extract pre-filter for left & full outer joins(otherwise the pre-filter will always be pushed down)
val (leftLocal, remaining) =
joinInfo.nonEquiConditions.asScala.partition {
r =>
{
val inputRefs = new InputRefVisitor()
r.accept(inputRefs)
// if all input refs belong to left
inputRefs.getFields.forall(idx => idx < inputRel.getRowType.getFieldCount)
}
}
val remainingAnds = remainingEquals ++ remaining
// build final pre-filter and remaining conditions
(
composeCondition(rexBuilder, leftLocal.toList),
composeCondition(rexBuilder, remainingAnds.toList))
} else {
val remainingAnds = remainingEquals ++ joinInfo.nonEquiConditions.asScala
(None, composeCondition(rexBuilder, remainingAnds.toList))
}
}

private def composeCondition(rexBuilder: RexBuilder, rexNodes: List[RexNode]): Option[RexNode] = {
val condition = RexUtil.composeConjunction(rexBuilder, rexNodes.asJava)
if (condition.isAlwaysTrue) {
None
} else {
Expand Down Expand Up @@ -466,16 +487,30 @@ abstract class CommonPhysicalLookupJoin(

private def joinConditionToString(
resultFieldNames: Array[String],
joinCondition: Option[RexNode],
expressionFormat: ExpressionFormat = ExpressionFormat.Prefix,
sqlExplainLevel: SqlExplainLevel): String = joinCondition match {
case Some(condition) =>
getExpressionString(
condition,
resultFieldNames.toList,
None,
expressionFormat,
sqlExplainLevel)
case None => "N/A"
sqlExplainLevel: SqlExplainLevel): String = {

def appendCondition(sb: StringBuilder, cond: Option[RexNode]): Unit = {
cond match {
case Some(condition) =>
sb.append(
getExpressionString(
condition,
resultFieldNames.toList,
None,
expressionFormat,
sqlExplainLevel))
case _ =>
}
}

if (finalPreFilterCondition.isEmpty && finalRemainingCondition.isEmpty) {
"N/A"
} else {
val sb = new StringBuilder
appendCondition(sb, finalPreFilterCondition)
appendCondition(sb, finalRemainingCondition)
sb.toString()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ class StreamPhysicalLookupJoin(
new StreamExecLookupJoin(
tableConfig,
JoinTypeUtil.getFlinkJoinType(joinType),
remainingCondition.orNull,
finalPreFilterCondition.orNull,
finalRemainingCondition.orNull,
new TemporalTableSourceSpec(temporalTable),
allLookupKeys.map(item => (Int.box(item._1), item._2)).asJava,
projectionOnTemporalTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,30 @@ public void testJoinTemporalTableWithAsyncRetryHint2() {
+ "FROM MyTable AS T JOIN LookupTable "
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id");
}

@Test
public void testLeftJoinTemporalTableWithPreFilter() {
util.verifyJsonPlan(
"INSERT INTO MySink1 SELECT * "
+ "FROM MyTable AS T LEFT JOIN LookupTable "
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id AND T.b > 'abc'");
}

@Test
public void testLeftJoinTemporalTableWithPostFilter() {
util.verifyJsonPlan(
"INSERT INTO MySink1 SELECT * "
+ "FROM MyTable AS T LEFT JOIN LookupTable "
+ "FOR SYSTEM_TIME AS OF T.proctime AS D ON T.a = D.id "
+ "AND CHAR_LENGTH(D.name) > CHAR_LENGTH(T.b)");
}

@Test
public void testLeftJoinTemporalTableWithMultiJoinConditions() {
util.verifyJsonPlan(
"INSERT INTO MySink1 SELECT * "
+ "FROM MyTable AS T LEFT JOIN LookupTable "
+ "FOR SYSTEM_TIME AS OF T.proctime AS D "
+ "ON T.a = D.id AND T.b > 'abc' AND T.b <> D.name AND T.c = 100");
}
}
Loading

0 comments on commit 360b97a

Please sign in to comment.