Skip to content

Commit

Permalink
[FLINK-24239] Event time temporal join should support values from arr…
Browse files Browse the repository at this point in the history
…ay, map, row, etc. as join key (apache#24253)
  • Loading branch information
dawidwys authored Feb 8, 2024
1 parent 1b95b19 commit 01cdc70
Show file tree
Hide file tree
Showing 7 changed files with 1,336 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ import org.apache.flink.table.planner.plan.schema.{LegacyTableSourceTable, Table
import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil
import org.apache.flink.table.sources.LookupableTableSource

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand, RelOptUtil}
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.hep.{HepPlanner, HepRelVertex}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.TableScan
import org.apache.calcite.rel.core.{CorrelationId, TableScan}
import org.apache.calcite.rel.logical._
import org.apache.calcite.rex._

Expand Down Expand Up @@ -141,6 +142,30 @@ abstract class LogicalCorrelateToJoinFromTemporalTableRule(
}
case _ => false
}

/**
 * Rewrites an expression written against a correlate (correlation variable on the left,
 * plain input references on the right) into an expression over the concatenated
 * left ++ right row of a join.
 *
 * @param rexNode
 *   the expression to rewrite
 * @param leftRowType
 *   row type of the left input; resolves correlated field accesses and supplies the offset
 *   applied to right-side references
 * @param correlationId
 *   the correlation id every visited correlation variable is required to carry
 * @return
 *   the rewritten expression
 */
protected def decorrelate(
    rexNode: RexNode,
    leftRowType: RelDataType,
    correlationId: CorrelationId): RexNode = {
  val toJoinInputs = new RexShuttle() {
    // A field access on the correlation variable becomes an input ref into the left side.
    // The position is found by looking the accessed field up in the left row type.
    override def visitFieldAccess(access: RexFieldAccess): RexNode =
      access.getReferenceExpr match {
        case correlVar: RexCorrelVariable =>
          require(correlationId.equals(correlVar.id))
          val leftIndex = leftRowType.getFieldList.indexOf(access.getField)
          RexInputRef.of(leftIndex, leftRowType)
        case _ =>
          // Any other access (e.g. an access nested on top of another one) is handled by
          // the default traversal.
          super.visitFieldAccess(access)
      }

    // Existing input refs point at the right side, so shift them past the left fields.
    override def visitInputRef(ref: RexInputRef): RexNode = {
      val shiftedIndex = leftRowType.getFieldCount + ref.getIndex
      new RexInputRef(shiftedIndex, ref.getType)
    }
  }
  rexNode.accept(toJoinInputs)
}
}

/**
Expand All @@ -161,24 +186,7 @@ abstract class LogicalCorrelateToJoinFromLookupTemporalTableRule(
validateSnapshotInCorrelate(snapshot, correlate)

val leftRowType = leftInput.getRowType
val joinCondition = filterCondition.accept(new RexShuttle() {
// change correlate variable expression to normal RexInputRef (which is from left side)
override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
fieldAccess.getReferenceExpr match {
case corVar: RexCorrelVariable =>
require(correlate.getCorrelationId.equals(corVar.id))
val index = leftRowType.getFieldList.indexOf(fieldAccess.getField)
RexInputRef.of(index, leftRowType)
case _ => super.visitFieldAccess(fieldAccess)
}
}

// update the field index from right side
override def visitInputRef(inputRef: RexInputRef): RexNode = {
val rightIndex = leftRowType.getFieldCount + inputRef.getIndex
new RexInputRef(rightIndex, inputRef.getType)
}
})
val joinCondition = decorrelate(filterCondition, leftRowType, correlate.getCorrelationId)

val builder = call.builder()
builder.push(leftInput)
Expand All @@ -198,8 +206,8 @@ abstract class LogicalCorrelateToJoinFromGeneralTemporalTableRule(

protected def extractRightEventTimeInputRef(
leftInput: RelNode,
snapshot: LogicalSnapshot): Option[RexNode] = {
val rightFields = snapshot.getRowType.getFieldList.asScala
rightInput: RelNode): Option[RexNode] = {
val rightFields = rightInput.getRowType.getFieldList.asScala
val timeAttributeFields = rightFields.filter(
f =>
f.getType.isInstanceOf[TimeIndicatorRelDataType] &&
Expand All @@ -209,7 +217,7 @@ abstract class LogicalCorrelateToJoinFromGeneralTemporalTableRule(
val timeColIndex = leftInput.getRowType.getFieldCount +
rightFields.indexOf(timeAttributeFields.get(0))
val timeColDataType = timeAttributeFields.get(0).getType
val rexBuilder = snapshot.getCluster.getRexBuilder
val rexBuilder = rightInput.getCluster.getRexBuilder
Some(rexBuilder.makeInputRef(timeColDataType, timeColIndex))
} else {
None
Expand Down Expand Up @@ -237,57 +245,32 @@ abstract class LogicalCorrelateToJoinFromGeneralTemporalTableRule(
val snapshot = getLogicalSnapshot(call)

val leftRowType = leftInput.getRowType
val joinCondition = filterCondition.accept(new RexShuttle() {
// change correlate variable expression to normal RexInputRef (which is from left side)
override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
fieldAccess.getReferenceExpr match {
case corVar: RexCorrelVariable =>
require(correlate.getCorrelationId.equals(corVar.id))
val index = leftRowType.getFieldList.indexOf(fieldAccess.getField)
RexInputRef.of(index, leftRowType)
case _ => super.visitFieldAccess(fieldAccess)
}
}

// update the field index from right side
override def visitInputRef(inputRef: RexInputRef): RexNode = {
val rightIndex = leftRowType.getFieldCount + inputRef.getIndex
new RexInputRef(rightIndex, inputRef.getType)
}
})
val joinCondition = decorrelate(filterCondition, leftRowType, correlate.getCorrelationId)

validateSnapshotInCorrelate(snapshot, correlate)

val rexBuilder = correlate.getCluster.getRexBuilder
val (leftJoinKey, rightJoinKey) = {
val relBuilder = call.builder()
relBuilder.push(leftInput)
relBuilder.push(snapshot)
val rewriteJoin = relBuilder.join(correlate.getJoinType, joinCondition).build()
val joinInfo = rewriteJoin.asInstanceOf[LogicalJoin].analyzeCondition()
val leftJoinKey = joinInfo.leftKeys.map(i => rexBuilder.makeInputRef(leftInput, i))
val leftFieldCnt = leftInput.getRowType.getFieldCount
val rightJoinKey = joinInfo.rightKeys.map(
i => {
val leftKeyType = snapshot.getRowType.getFieldList.get(i).getType
rexBuilder.makeInputRef(leftKeyType, leftFieldCnt + i)
})
if (leftJoinKey.length == 0 || rightJoinKey.length == 0) {
throw new ValidationException(
"Currently the join key in Temporal Table Join " +
"can not be empty.")
}
(leftJoinKey, rightJoinKey)
val relBuilder = call.builder()
relBuilder.push(leftInput)
relBuilder.push(snapshot)
val nonPushedJoin =
relBuilder.join(correlate.getJoinType, joinCondition).build().asInstanceOf[LogicalJoin]
val rewriteJoin = RelOptUtil.pushDownJoinConditions(nonPushedJoin, relBuilder)
val actualJoin = rewriteJoin match {
case _: LogicalJoin => rewriteJoin.asInstanceOf[LogicalJoin]
case _ => rewriteJoin.getInput(0).asInstanceOf[LogicalJoin]
}

val snapshotTimeInputRef = extractSnapshotTimeInputRef(leftInput, snapshot)
val (leftJoinKey, rightJoinKey) = extractJoinKeys(actualJoin)

val snapshotTimeInputRef = extractSnapshotTimeInputRef(actualJoin.getLeft, snapshot)
.getOrElse(
throw new ValidationException(
"Temporal Table Join requires time " +
"attribute in the left table, but no time attribute found."))

val temporalCondition = if (isRowTimeTemporalTableJoin(snapshot)) {
val rightTimeInputRef = extractRightEventTimeInputRef(leftInput, snapshot)
val rightTimeInputRef = extractRightEventTimeInputRef(actualJoin.getLeft, actualJoin.getRight)
if (rightTimeInputRef.isEmpty || !isRowtimeIndicatorType(rightTimeInputRef.get.getType)) {
throw new ValidationException(
"Event-Time Temporal Table Join requires both" +
Expand Down Expand Up @@ -323,15 +306,47 @@ abstract class LogicalCorrelateToJoinFromGeneralTemporalTableRule(
}

val builder = call.builder()
val condition = builder.and(joinCondition, temporalCondition)

builder.push(leftInput)
builder.push(snapshot)
builder.join(correlate.getJoinType, condition)
val temporalJoin = builder.build()
val condition = builder.and(actualJoin.getCondition, temporalCondition)

val joinWithTemporalCondition = actualJoin.copy(
actualJoin.getTraitSet,
condition,
actualJoin.getLeft,
actualJoin.getRight,
actualJoin.getJoinType,
actualJoin.isSemiJoinDone)

val temporalJoin = if (actualJoin != rewriteJoin) {
rewriteJoin.replaceInput(0, joinWithTemporalCondition)
rewriteJoin
} else {
joinWithTemporalCondition
}
call.transformTo(temporalJoin)
}

/**
 * Extracts the equi-join keys of the given join as input references over the joined row.
 *
 * Left keys are referenced relative to the left input. Right keys are offset by the number
 * of left fields and typed after the corresponding right input field.
 *
 * @param actualJoin
 *   the join whose condition is analysed
 * @return
 *   a pair of (left key references, right key references)
 * @throws ValidationException
 *   if either side has no join key
 */
private def extractJoinKeys(actualJoin: LogicalJoin): (Seq[RexNode], Seq[RexNode]) = {
  val equiKeys = actualJoin.analyzeCondition()
  val left = actualJoin.getLeft
  val right = actualJoin.getRight
  val refBuilder = actualJoin.getCluster.getRexBuilder

  val leftRefs = equiKeys.leftKeys.map(idx => refBuilder.makeInputRef(left, idx))

  // Right-side keys live after all left fields in the joined row.
  val rightOffset = left.getRowType.getFieldCount
  val rightRefs = equiKeys.rightKeys.map(
    idx => {
      val keyType = right.getRowType.getFieldList.get(idx).getType
      refBuilder.makeInputRef(keyType, rightOffset + idx)
    })

  if (leftRefs.nonEmpty && rightRefs.nonEmpty) {
    (leftRefs, rightRefs)
  } else {
    throw new ValidationException(
      "Currently the join key in Temporal Table Join " +
        "can not be empty.")
  }
}

private def isRowTimeTemporalTableJoin(snapshot: LogicalSnapshot): Boolean =
snapshot.getPeriod.getType match {
case t: TimeIndicatorRelDataType if t.isEventTime => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public TemporalJoinRestoreTest() {
public List<TableTestProgram> programs() {
    // Programs covered by this restore test, in the order they are listed here.
    final List<TableTestProgram> restorePrograms =
            Arrays.asList(
                    TemporalJoinTestPrograms.TEMPORAL_JOIN_TABLE_JOIN,
                    TemporalJoinTestPrograms.TEMPORAL_JOIN_TABLE_JOIN_NESTED_KEY,
                    TemporalJoinTestPrograms.TEMPORAL_JOIN_TABLE_JOIN_KEY_FROM_MAP,
                    TemporalJoinTestPrograms.TEMPORAL_JOIN_TEMPORAL_FUNCTION);
    return restorePrograms;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.flink.table.test.program.TableTestProgram;
import org.apache.flink.types.Row;

import java.util.HashMap;
import java.util.Map;

import static org.apache.flink.table.api.Expressions.$;

/** {@link TableTestProgram} definitions for testing {@link StreamExecTemporalJoin}. */
Expand All @@ -45,6 +48,49 @@ public class TemporalJoinTestPrograms {
Row.of(1L, "USD", "2020-10-10 00:00:58"))
.build();

/**
 * Source step for the "OrdersNestedId" table. Each row carries the order currency twice: once
 * inside a nested ROW column ({@code nested_row.currency}) and once as a MAP value
 * ({@code nested_map['currency']}). The rows are split into a batch produced before the
 * restore and a batch produced after it.
 */
static final SourceTestStep ORDERS_WITH_NESTED_ID =
SourceTestStep.newBuilder("OrdersNestedId")
.addSchema(
"amount bigint",
"nested_row ROW<currency STRING>",
"nested_map MAP<STRING NOT NULL, STRING>",
"order_time STRING",
// rowtime is computed from the order_time string and also serves as the watermark
"rowtime as TO_TIMESTAMP(order_time) ",
"WATERMARK FOR rowtime AS rowtime")
.producedBeforeRestore(
Row.of(
2L,
Row.of("Euro"),
mapOf("currency", "Euro"),
"2020-10-10 00:00:42"),
// note: the nested row holds lower-case "usd" while the map holds "USD"
Row.of(
1L,
Row.of("usd"),
mapOf("currency", "USD"),
"2020-10-10 00:00:43"),
Row.of(
50L,
Row.of("Yen"),
mapOf("currency", "Yen"),
"2020-10-10 00:00:44"),
Row.of(
3L,
Row.of("Euro"),
mapOf("currency", "Euro"),
"2020-10-10 00:00:45"))
.producedAfterRestore(
Row.of(
1L,
Row.of("Euro"),
mapOf("currency", "Euro"),
"2020-10-10 00:00:58"),
// same lower-case / upper-case split as in the pre-restore batch
Row.of(
1L,
Row.of("usd"),
mapOf("currency", "USD"),
"2020-10-10 00:00:58"))
.build();

static final SourceTestStep RATES =
SourceTestStep.newBuilder("RatesHistory")
.addSchema(
Expand Down Expand Up @@ -84,6 +130,36 @@ public class TemporalJoinTestPrograms {
+ "ON o.currency = r.currency ")
.build();

/**
 * Temporal table join program whose left-side join key is not a plain column but an expression
 * over a nested row field: a CASE expression that upper-cases {@code o.nested_row.currency}
 * when it equals 'usd' and otherwise uses the field as is.
 */
static final TableTestProgram TEMPORAL_JOIN_TABLE_JOIN_NESTED_KEY =
TableTestProgram.of(
"temporal-join-table-join-nested-key",
"validates temporal join with a table when the join keys comes from a nested row")
.setupTableSource(ORDERS_WITH_NESTED_ID)
.setupTableSource(RATES)
.setupTableSink(AMOUNTS)
.runSql(
"INSERT INTO MySink "
+ "SELECT amount * r.rate "
+ "FROM OrdersNestedId AS o "
+ "JOIN RatesHistory FOR SYSTEM_TIME AS OF o.rowtime AS r "
// the join key is computed from the nested row field rather than read directly
+ "ON (case when o.nested_row.currency = 'usd' then upper(o.nested_row.currency) ELSE o.nested_row.currency END) = r.currency ")
.build();

/**
 * Temporal table join program whose left-side join key is a value looked up from a MAP column
 * ({@code o.nested_map['currency']}) instead of a plain column.
 */
static final TableTestProgram TEMPORAL_JOIN_TABLE_JOIN_KEY_FROM_MAP =
TableTestProgram.of(
"temporal-join-table-join-key-from-map",
"validates temporal join with a table when the join key comes from a map value")
.setupTableSource(ORDERS_WITH_NESTED_ID)
.setupTableSource(RATES)
.setupTableSink(AMOUNTS)
.runSql(
"INSERT INTO MySink "
+ "SELECT amount * r.rate "
+ "FROM OrdersNestedId AS o "
+ "JOIN RatesHistory FOR SYSTEM_TIME AS OF o.rowtime AS r "
// the join key is the map entry stored under 'currency'
+ "ON o.nested_map['currency'] = r.currency ")
.build();

static final TableTestProgram TEMPORAL_JOIN_TEMPORAL_FUNCTION =
TableTestProgram.of(
"temporal-join-temporal-function",
Expand All @@ -100,4 +176,10 @@ public class TemporalJoinTestPrograms {
+ "LATERAL TABLE (Rates(o.rowtime)) AS r "
+ "WHERE o.currency = r.currency ")
.build();

/** Creates a new mutable {@link HashMap} containing exactly the given key/value pair. */
private static Map<String, String> mapOf(String key, String value) {
    final Map<String, String> singleEntry = new HashMap<>();
    singleEntry.put(key, value);
    return singleEntry;
}
}
Loading

0 comments on commit 01cdc70

Please sign in to comment.