Skip to content

Commit

Permalink
[FLINK-23054][table] Join unique/pk optimization should based on upse…
Browse files Browse the repository at this point in the history
…rt key
  • Loading branch information
JingsongLi committed Jun 29, 2021
1 parent 4642c1f commit 992340e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
package org.apache.flink.table.planner.plan.nodes.physical.stream

import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecJoin
import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalJoin
import org.apache.flink.table.planner.plan.utils.JoinUtil

import org.apache.calcite.plan._
import org.apache.calcite.rel.core.{Join, JoinRelType}
import org.apache.calcite.rel.core.{Exchange, Join, JoinRelType}
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rex.RexNode
Expand Down Expand Up @@ -67,11 +68,11 @@ class StreamPhysicalJoin(
*/
def inputUniqueKeyContainsJoinKey(inputOrdinal: Int): Boolean = {
val input = getInput(inputOrdinal)
val inputUniqueKeys = getCluster.getMetadataQuery.getUniqueKeys(input)
val joinKeys = if (inputOrdinal == 0) joinSpec.getLeftKeys else joinSpec.getRightKeys
val inputUniqueKeys = getUniqueKeys(input, joinKeys)
if (inputUniqueKeys != null) {
val joinKeys = if (inputOrdinal == 0) joinSpec.getLeftKeys else joinSpec.getRightKeys
inputUniqueKeys.exists {
uniqueKey => joinKeys.forall(uniqueKey.toArray.contains(_))
uniqueKey => joinKeys.forall(uniqueKey.contains(_))
}
} else {
false
Expand All @@ -98,21 +99,22 @@ class StreamPhysicalJoin(
JoinUtil.analyzeJoinInput(
InternalTypeInfo.of(FlinkTypeFactory.toLogicalRowType(left.getRowType)),
joinSpec.getLeftKeys,
getUniqueKeys(left)))
getUniqueKeys(left, joinSpec.getLeftKeys)))
.item(
"rightInputSpec",
JoinUtil.analyzeJoinInput(
InternalTypeInfo.of(FlinkTypeFactory.toLogicalRowType(right.getRowType)),
joinSpec.getRightKeys,
getUniqueKeys(right)))
getUniqueKeys(right, joinSpec.getRightKeys)))
}

private def getUniqueKeys(input: RelNode): List[Array[Int]] = {
val uniqueKeys = cluster.getMetadataQuery.getUniqueKeys(input)
if (uniqueKeys == null || uniqueKeys.isEmpty) {
private def getUniqueKeys(input: RelNode, keys: Array[Int]): List[Array[Int]] = {
val upsertKeys = FlinkRelMetadataQuery.reuseOrCreate(cluster.getMetadataQuery)
.getUpsertKeysInKeyGroupRange(input, keys)
if (upsertKeys == null || upsertKeys.isEmpty) {
List.empty
} else {
uniqueKeys.map(_.asList.map(_.intValue).toArray).toList
upsertKeys.map(_.asList.map(_.intValue).toArray).toList
}

}
Expand All @@ -125,8 +127,8 @@ class StreamPhysicalJoin(
override def translateToExecNode(): ExecNode[_] = {
new StreamExecJoin(
joinSpec,
getUniqueKeys(left),
getUniqueKeys(right),
getUniqueKeys(left, joinSpec.getLeftKeys),
getUniqueKeys(right, joinSpec.getRightKeys),
InputProperty.DEFAULT,
InputProperty.DEFAULT,
FlinkTypeFactory.toLogicalRowType(getRowType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,49 @@ Calc(select=[a1])
: +- LegacyTableSourceScan(table=[[default_catalog, default_database, A, source: [TestTableSource(a1, a2, a3)]]], fields=[a1, a2, a3])
+- Exchange(distribution=[hash[pk1]])
+- TableSourceScan(table=[[default_catalog, default_database, tableWithCompositePk, project=[pk1]]], fields=[pk1])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinDisorderChangeLog">
<Resource name="sql">
<![CDATA[
SELECT T1.person, T1.sum_votes, T1.prize, T2.age FROM
(SELECT T.person, T.sum_votes, award.prize FROM
(SELECT person, SUM(votes) AS sum_votes FROM src GROUP BY person) T,
award
WHERE T.sum_votes = award.votes) T1, people T2
WHERE T1.person = T2.person
]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(person=[$0], sum_votes=[$1], prize=[$2], age=[$4])
+- LogicalFilter(condition=[=($0, $3)])
+- LogicalJoin(condition=[true], joinType=[inner])
:- LogicalProject(person=[$0], sum_votes=[$1], prize=[$3])
: +- LogicalFilter(condition=[=($1, $2)])
: +- LogicalJoin(condition=[true], joinType=[inner])
: :- LogicalAggregate(group=[{0}], sum_votes=[SUM($1)])
: : +- LogicalTableScan(table=[[default_catalog, default_database, src]])
: +- LogicalTableScan(table=[[default_catalog, default_database, award]])
+- LogicalTableScan(table=[[default_catalog, default_database, people]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[person, sum_votes, prize, age])
+- Join(joinType=[InnerJoin], where=[(person = person0)], select=[person, sum_votes, prize, person0, age], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
:- Exchange(distribution=[hash[person]])
: +- Calc(select=[person, sum_votes, prize])
: +- Join(joinType=[InnerJoin], where=[(sum_votes = votes)], select=[person, sum_votes, votes, prize], leftInputSpec=[HasUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
: :- Exchange(distribution=[hash[sum_votes]])
: : +- GroupAggregate(groupBy=[person], select=[person, SUM(votes) AS sum_votes])
: : +- Exchange(distribution=[hash[person]])
: : +- TableSourceScan(table=[[default_catalog, default_database, src]], fields=[person, votes])
: +- Exchange(distribution=[hash[votes]])
: +- TableSourceScan(table=[[default_catalog, default_database, award]], fields=[votes, prize])
+- Exchange(distribution=[hash[person]])
+- TableSourceScan(table=[[default_catalog, default_database, people]], fields=[person, age])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,38 @@ class JoinTest extends TableTestBase {
|""".stripMargin)
util.verifyExecPlan("SELECT A.a1 FROM A LEFT JOIN tableWithCompositePk T ON A.a1 = T.pk1")
}

@Test
/**
 * Plans a two-level join: the per-person vote totals from `src` are joined with `award` on the
 * summed votes, and that result is joined with `people` on `person`. The resulting exec plan
 * is checked against the recorded expectation.
 */
def testJoinDisorderChangeLog(): Unit = {
  // `src` declares no primary key; `award` and `people` each declare one (NOT ENFORCED).
  val srcDdl =
    """
      |CREATE TABLE src (person String, votes BIGINT) WITH(
      | 'connector' = 'values'
      |)
      |""".stripMargin

  val awardDdl =
    """
      |CREATE TABLE award (votes BIGINT, prize DOUBLE, PRIMARY KEY(votes) NOT ENFORCED) WITH(
      | 'connector' = 'values'
      |)
      |""".stripMargin

  val peopleDdl =
    """
      |CREATE TABLE people (person STRING, age INT, PRIMARY KEY(person) NOT ENFORCED) WITH(
      | 'connector' = 'values'
      |)
      |""".stripMargin

  // Register the tables in the same order as before: src, award, people.
  Seq(srcDdl, awardDdl, peopleDdl).foreach(ddl => util.tableEnv.executeSql(ddl))

  val query =
    """
      |SELECT T1.person, T1.sum_votes, T1.prize, T2.age FROM
      | (SELECT T.person, T.sum_votes, award.prize FROM
      | (SELECT person, SUM(votes) AS sum_votes FROM src GROUP BY person) T,
      | award
      | WHERE T.sum_votes = award.votes) T1, people T2
      | WHERE T1.person = T2.person
      |""".stripMargin

  util.verifyExecPlan(query)
}
}

0 comments on commit 992340e

Please sign in to comment.