Skip to content

Commit

Permalink
[FLINK-29280][table-planner] Fix join hints could not be propagated i…
Browse files Browse the repository at this point in the history
…n subquery

This closes apache#20823
  • Loading branch information
xuyangzhong authored and godfreyhe committed Sep 19, 2022
1 parent a02b2c2 commit 22cb554
Show file tree
Hide file tree
Showing 11 changed files with 1,215 additions and 277 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.calcite.sql2rel;

import org.apache.flink.table.planner.alias.ClearJoinHintWithInvalidPropagationShuttle;
import org.apache.flink.table.planner.hint.FlinkHints;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -207,6 +208,9 @@ public static RelNode decorrelateQuery(RelNode rootRel, RelBuilder relBuilder) {

// ----- FLINK MODIFICATION BEGIN -----

// replace all join hints with upper case
newRootRel = FlinkHints.capitalizeJoinHints(newRootRel);

// clear join hints which are propagated into wrong query block
// The hint QueryBlockAlias will be added when building a RelNode tree before. It is used to
// distinguish the query block in the SQL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
Expand Down Expand Up @@ -492,6 +493,7 @@ public Frame decorrelateRel(LogicalProject rel) {
}
}
RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
newProject = ((LogicalProject) newProject).withHints(rel.getHints());

final RexNode newCorCondition;
if (frame.c != null) {
Expand Down Expand Up @@ -544,11 +546,13 @@ public Frame decorrelateRel(LogicalFilter rel) {

// Using LogicalFilter.create instead of RelBuilder.filter to create Filter
// because RelBuilder.filter method does not have VariablesSet arg.
final LogicalFilter newFilter =
final RelNode newFilter =
LogicalFilter.create(
frame.r,
remainingCondition,
com.google.common.collect.ImmutableSet.copyOf(rel.getVariablesSet()));
frame.r,
remainingCondition,
com.google.common.collect.ImmutableSet.copyOf(
rel.getVariablesSet()))
.withHints(rel.getHints());

// Adds input's correlation condition
if (frame.c != null) {
Expand Down Expand Up @@ -705,7 +709,8 @@ public Frame decorrelateRel(LogicalAggregate rel) {
}

relBuilder.push(
LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
LogicalAggregate.create(
newProject, rel.getHints(), newGroupSet, null, newAggCalls));

if (!omittedConstants.isEmpty()) {
final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
Expand Down Expand Up @@ -876,7 +881,9 @@ public Frame decorrelateRel(Sort rel) {
RelCollation oldCollation = rel.getCollation();
RelCollation newCollation = RexUtil.apply(mapping, oldCollation);

final Sort newSort = LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch);
final RelNode newSort =
LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch)
.withHints(rel.getHints());

// Sort does not change input ordering
return new Frame(rel, newSort, frame.c, frame.oldToNewOutputs);
Expand Down Expand Up @@ -917,6 +924,9 @@ public Frame decorrelateRel(RelNode rel) {
if (!Util.equalShallow(oldInputs, newInputs)) {
newRel = rel.copy(rel.getTraitSet(), newInputs);
}
if (rel instanceof Hintable) {
newRel = ((Hintable) newRel).withHints(((Hintable) rel).getHints());
}
}
// the output position should not change since there are no corVars coming from below.
return new Frame(rel, newRel, null, identityMap(rel.getRowType().getFieldCount()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import javax.annotation.Nullable

import java.lang.{Boolean => JBoolean}
import java.util
import java.util.Locale
import java.util.function.{Function => JFunction}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -250,7 +251,7 @@ class FlinkPlannerImpl(
JavaScalaConversionUtil.toScala(hints).foreach {
case hint: SqlHint =>
val hintName = hint.getName
if (JoinStrategy.isJoinStrategy(hintName)) {
if (JoinStrategy.isJoinStrategy(hintName.toUpperCase(Locale.ROOT))) {
return true
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
*/
package org.apache.flink.table.planner.plan.rules.logical

import org.apache.flink.table.planner.alias.ClearJoinHintWithInvalidPropagationShuttle
import org.apache.flink.table.planner.calcite.{FlinkRelBuilder, FlinkRelFactories}
import org.apache.flink.table.planner.hint.FlinkHints

import com.google.common.collect.ImmutableList
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand, RelOptUtil}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.plan.RelOptUtil.Logic
import org.apache.calcite.rel.{RelNode, RelShuttleImpl}
Expand Down Expand Up @@ -93,7 +95,14 @@ class FlinkSubQueryRemoveRule(
relBuilder.filter(c)
}
relBuilder.project(fields(relBuilder, filter.getRowType.getFieldCount))
call.transformTo(relBuilder.build)
// the sub query has been replaced with a common node,
// so hints in it should also be resolved with the same logic in SqlToRelConverter
val newNode = relBuilder.build
val nodeWithHint = RelOptUtil.propagateRelHints(newNode, false)
val nodeWithCapitalizedJoinHints = FlinkHints.capitalizeJoinHints(nodeWithHint)
val finalNode =
nodeWithCapitalizedJoinHints.accept(new ClearJoinHintWithInvalidPropagationShuttle)
call.transformTo(finalNode)
case _ => // do nothing
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,46 @@ public void testJoinHintWithoutCaseSensitive() {
verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

@Test
public void testJoinHintWithJoinHintInSubQuery() {
String sql =
"select * from T1 WHERE a1 IN (select /*+ %s(T2) */ a2 from T2 join T3 on T2.a2 = T3.a3)";

verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

@Test
public void testJoinHintWithJoinHintInCorrelateAndWithFilter() {
String sql =
"select * from T1 WHERE a1 IN (select /*+ %s(T2) */ a2 from T2 join T3 on T2.a2 = T3.a3 where T1.a1 = T2.a2)";

verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

@Test
public void testJoinHintWithJoinHintInCorrelateAndWithProject() {
String sql =
"select * from T1 WHERE a1 IN (select /*+ %s(T2) */ a2 + T1.a1 from T2 join T3 on T2.a2 = T3.a3)";

verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

@Test
public void testJoinHintWithJoinHintInCorrelateAndWithAgg() {
String sql =
"select * from T1 WHERE a1 IN (select /*+ %s(T2) */ count(T2.a2) from T2 join T1 on T2.a2 = T1.a1 group by T1.a1)";

verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

@Test
public void testJoinHintWithJoinHintInCorrelateAndWithSortLimit() {
String sql =
"select * from T1 WHERE a1 IN (select /*+ %s(T2) */ T2.a2 from T2 join T1 on T2.a2 = T1.a1 order by T1.a1 limit 10)";

verifyRelPlanByCustom(String.format(sql, buildCaseSensitiveStr(getTestSingleJoinHint())));
}

protected String buildAstPlanWithQueryBlockAlias(List<RelNode> relNodes) {
StringBuilder astBuilder = new StringBuilder();
relNodes.forEach(
Expand Down
Loading

0 comments on commit 22cb554

Please sign in to comment.