Skip to content

Commit

Permalink
[CALCITE-5655] Intermediate table alias should be different to avoid …
Browse files Browse the repository at this point in the history
…wrong field reference lookup in subquery remove phase

The tests below are not affected by this bug; we added them to improve test coverage:
RelOptRulesTest#testExpandProjectInWithTwoCorrelatedSubQueries
RelOptRulesTest#testExpandProjectInWithTwoSubQueries

Close #3159
  • Loading branch information
herunkang2018 authored and libenchao committed Apr 20, 2023
1 parent b2917b3 commit e2028ad
Show file tree
Hide file tree
Showing 5 changed files with 494 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected SubQueryRemoveRule(Config config) {

protected RexNode apply(RexSubQuery e, Set<CorrelationId> variablesSet,
RelOptUtil.Logic logic,
RelBuilder builder, int inputCount, int offset) {
RelBuilder builder, int inputCount, int offset, int subQueryIndex) {
switch (e.getKind()) {
case SCALAR_QUERY:
return rewriteScalarQuery(e, variablesSet, builder, inputCount, offset);
Expand All @@ -98,9 +98,9 @@ protected RexNode apply(RexSubQuery e, Set<CorrelationId> variablesSet,
return rewriteCollection(e, variablesSet, builder,
inputCount, offset);
case SOME:
return rewriteSome(e, variablesSet, builder);
return rewriteSome(e, variablesSet, builder, subQueryIndex);
case IN:
return rewriteIn(e, variablesSet, logic, builder, offset);
return rewriteIn(e, variablesSet, logic, builder, offset, subQueryIndex);
case EXISTS:
return rewriteExists(e, variablesSet, logic, builder);
case UNIQUE:
Expand Down Expand Up @@ -161,13 +161,14 @@ private static RexNode rewriteCollection(RexSubQuery e,
/**
* Rewrites a SOME sub-query into a {@link Join}.
*
* @param e SOME sub-query to rewrite
* @param builder Builder
* @param e SOME sub-query to rewrite
* @param builder Builder
* @param subQueryIndex sub-query index in multiple sub-queries
*
* @return Expression that may be used to replace the RexSubQuery
*/
private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSet,
RelBuilder builder) {
RelBuilder builder, int subQueryIndex) {
// Most general case, where the left and right keys might have nulls, and
// caller requires 3-valued logic return.
//
Expand Down Expand Up @@ -213,6 +214,11 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe
? SqlStdOperatorTable.MIN
: SqlStdOperatorTable.MAX;

String qAlias = "q";
if (subQueryIndex != 0) {
qAlias = "q" + subQueryIndex;
}

if (variablesSet.isEmpty()) {
switch (op.comparisonKind) {
case GREATER_THAN_OR_EQUAL:
Expand Down Expand Up @@ -241,21 +247,21 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe
builder.aggregateCall(minMax, builder.field(0)).as("m"),
builder.count(false, "c"),
builder.count(false, "d", builder.field(0)))
.as("q")
.as(qAlias)
.join(JoinRelType.INNER);
caseRexNode =
builder.call(SqlStdOperatorTable.CASE,
builder.equals(builder.field("q", "c"), builder.literal(0)),
builder.equals(builder.field(qAlias, "c"), builder.literal(0)),
literalFalse,
builder.call(SqlStdOperatorTable.IS_TRUE,
builder.call(RexUtil.op(op.comparisonKind),
e.operands.get(0), builder.field("q", "m"))),
e.operands.get(0), builder.field(qAlias, "m"))),
literalTrue,
builder.greaterThan(builder.field("q", "c"),
builder.field("q", "d")),
builder.greaterThan(builder.field(qAlias, "c"),
builder.field(qAlias, "d")),
literalUnknown,
builder.call(RexUtil.op(op.comparisonKind),
e.operands.get(0), builder.field("q", "m")));
e.operands.get(0), builder.field(qAlias, "m")));
break;

case NOT_EQUALS:
Expand Down Expand Up @@ -284,7 +290,7 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe
builder.count(false, "c"),
builder.count(false, "d", builder.field(0)),
builder.max(builder.field(0)).as("m"))
.as("q")
.as(qAlias)
.join(JoinRelType.INNER);
caseRexNode =
builder.call(SqlStdOperatorTable.CASE,
Expand All @@ -297,10 +303,10 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe
builder.lessThanOrEqual(builder.field("d"),
builder.literal(1))),
builder.or(
builder.notEquals(e.operands.get(0), builder.field("q", "m")),
builder.notEquals(e.operands.get(0), builder.field(qAlias, "m")),
literalUnknown),
builder.equals(builder.field("d"), builder.literal(1)),
builder.notEquals(e.operands.get(0), builder.field("q", "m")),
builder.notEquals(e.operands.get(0), builder.field(qAlias, "m")),
literalTrue);
break;

Expand Down Expand Up @@ -344,23 +350,23 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe

parentQueryFields.addAll(builder.fields());
parentQueryFields.add(builder.alias(literalTrue, indicator));
builder.project(parentQueryFields).as("q");
builder.project(parentQueryFields).as(qAlias);
builder.join(JoinRelType.LEFT, literalTrue, variablesSet);
caseRexNode =
builder.call(SqlStdOperatorTable.CASE,
builder.isNull(builder.field("q", indicator)),
builder.isNull(builder.field(qAlias, indicator)),
literalFalse,
builder.equals(builder.field("q", "c"), builder.literal(0)),
builder.equals(builder.field(qAlias, "c"), builder.literal(0)),
literalFalse,
builder.call(SqlStdOperatorTable.IS_TRUE,
builder.call(RexUtil.op(op.comparisonKind),
e.operands.get(0), builder.field("q", "m"))),
e.operands.get(0), builder.field(qAlias, "m"))),
literalTrue,
builder.greaterThan(builder.field("q", "c"),
builder.field("q", "d")),
builder.greaterThan(builder.field(qAlias, "c"),
builder.field(qAlias, "d")),
literalUnknown,
builder.call(RexUtil.op(op.comparisonKind),
e.operands.get(0), builder.field("q", "m")));
e.operands.get(0), builder.field(qAlias, "m")));
break;

case NOT_EQUALS:
Expand Down Expand Up @@ -397,11 +403,11 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe

parentQueryFields.addAll(builder.fields());
parentQueryFields.add(builder.alias(literalTrue, indicator));
builder.project(parentQueryFields).as("q"); // TODO use projectPlus
builder.project(parentQueryFields).as(qAlias); // TODO use projectPlus
builder.join(JoinRelType.LEFT, literalTrue, variablesSet);
caseRexNode =
builder.call(SqlStdOperatorTable.CASE,
builder.isNull(builder.field("q", indicator)),
builder.isNull(builder.field(qAlias, indicator)),
literalFalse,
builder.equals(builder.field("c"), builder.literal(0)),
literalFalse,
Expand All @@ -412,10 +418,10 @@ private static RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSe
builder.lessThanOrEqual(builder.field("d"),
builder.literal(1))),
builder.or(
builder.notEquals(e.operands.get(0), builder.field("q", "m")),
builder.notEquals(e.operands.get(0), builder.field(qAlias, "m")),
literalUnknown),
builder.equals(builder.field("d"), builder.literal(1)),
builder.notEquals(e.operands.get(0), builder.field("q", "m")),
builder.notEquals(e.operands.get(0), builder.field(qAlias, "m")),
literalTrue);
break;

Expand Down Expand Up @@ -536,17 +542,18 @@ private static RexNode rewriteUnique(RexSubQuery e, RelBuilder builder) {
/**
* Rewrites an IN RexSubQuery into a {@link Join}.
*
* @param e IN sub-query to rewrite
* @param variablesSet A set of variables used by a relational
* expression of the specified RexSubQuery
* @param logic Logic for evaluating
* @param builder Builder
* @param offset Offset to shift {@link RexInputRef}
* @param e IN sub-query to rewrite
* @param variablesSet A set of variables used by a relational
* expression of the specified RexSubQuery
* @param logic Logic for evaluating
* @param builder Builder
* @param offset Offset to shift {@link RexInputRef}
* @param subQueryIndex sub-query index in multiple sub-queries
*
* @return Expression that may be used to replace the RexSubQuery
*/
private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
RelOptUtil.Logic logic, RelBuilder builder, int offset) {
RelOptUtil.Logic logic, RelBuilder builder, int offset, int subQueryIndex) {
// Most general case, where the left and right keys might have nulls, and
// caller requires 3-valued logic return.
//
Expand Down Expand Up @@ -628,6 +635,11 @@ private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
// order by cs desc limit 1) as dt
//

String ctAlias = "ct";
if (subQueryIndex != 0) {
ctAlias = "ct" + subQueryIndex;
}

boolean allLiterals = RexUtil.allLiterals(e.getOperands());
final List<RexNode> expressionOperands = new ArrayList<>(e.getOperands());

Expand Down Expand Up @@ -698,7 +710,7 @@ private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
builder.aggregate(builder.groupKey(),
builder.count(false, "c"),
builder.count(builder.fields()).as("ck"));
builder.as("ct");
builder.as(ctAlias);
if (!variablesSet.isEmpty()) {
builder.join(JoinRelType.LEFT, trueLiteral, variablesSet);
} else {
Expand All @@ -714,7 +726,11 @@ private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
}
}

builder.as("dt");
String dtAlias = "dt";
if (subQueryIndex != 0) {
dtAlias = "dt" + subQueryIndex;
}
builder.as(dtAlias);
int refOffset = offset;
final List<RexNode> conditions =
Pair.zip(expressionOperands, builder.fields()).stream()
Expand Down Expand Up @@ -750,7 +766,7 @@ private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
b);
} else {
operands.add(
builder.equals(builder.field("ct", "c"), builder.literal(0)),
builder.equals(builder.field(ctAlias, "c"), builder.literal(0)),
falseLiteral);
}
break;
Expand All @@ -775,8 +791,8 @@ private static RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
case TRUE_FALSE_UNKNOWN:
case UNKNOWN_AS_TRUE:
operands.add(
builder.lessThan(builder.field("ct", "ck"),
builder.field("ct", "c")),
builder.lessThan(builder.field(ctAlias, "ck"),
builder.field(ctAlias, "c")),
b);
break;
default:
Expand Down Expand Up @@ -825,7 +841,7 @@ private static void matchProject(SubQueryRemoveRule rule,
final Set<CorrelationId> variablesSet =
RelOptUtil.getVariablesUsed(e.rel);
final RexNode target =
rule.apply(e, variablesSet, logic, builder, 1, fieldCount);
rule.apply(e, variablesSet, logic, builder, 1, fieldCount, 0);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
builder.project(shuttle.apply(project.getProjects()),
project.getRowType().getFieldNames());
Expand All @@ -852,7 +868,7 @@ private static void matchFilter(SubQueryRemoveRule rule,
RelOptUtil.getVariablesUsed(e.rel);
final RexNode target =
rule.apply(e, variablesSet, logic,
builder, 1, builder.peek().getRowType().getFieldCount());
builder, 1, builder.peek().getRowType().getFieldCount(), count);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
c = c.accept(shuttle);
}
Expand All @@ -876,7 +892,7 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
final Set<CorrelationId> variablesSet =
RelOptUtil.getVariablesUsed(e.rel);
final RexNode target =
rule.apply(e, variablesSet, logic, builder, 2, fieldCount);
rule.apply(e, variablesSet, logic, builder, 2, fieldCount, 0);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
builder.join(join.getJoinType(), shuttle.apply(join.getCondition()));
builder.project(fields(builder, join.getRowType().getFieldCount()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class SqlAdvisorTest extends SqlValidatorTestCase {
"TABLE(CATALOG.SALES.EMPTY_PRODUCTS)",
"TABLE(CATALOG.SALES.EMP_ADDRESS)",
"TABLE(CATALOG.SALES.DEPT)",
"TABLE(CATALOG.SALES.DEPTNULLABLES)",
"TABLE(CATALOG.SALES.DEPT_SINGLE)",
"TABLE(CATALOG.SALES.DEPT_NESTED)",
"TABLE(CATALOG.SALES.DEPT_NESTED_EXPANDED)",
Expand Down
94 changes: 94 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6219,6 +6219,40 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
sql(sql).withSubQueryRules().withLateDecorrelate(true).check();
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-5655">[CALCITE-5655]
 * Wrong field reference lookup due to same intermediate table alias
 * of multiple sub-queries in subquery remove phase</a>.
 *
 * <p>The filter combines two correlated {@code SOME} sub-queries with
 * {@code OR}; both sub-queries correlate on {@code e.ename}. */
@Test void testSomeWithTwoCorrelatedSubQueries() {
  final String sql = "select empno from sales.empnullables as e\n"
      + "where deptno > some(\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno > 10)\n"
      + "or deptno < some(\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno < 20)";
  // Apply the sub-query rules with RelBuilder simplification off and
  // trimming on, then compare against the recorded plan.
  sql(sql)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-5655">[CALCITE-5655]
 * Wrong field reference lookup due to same intermediate table alias
 * of multiple sub-queries in subquery remove phase</a>.
 *
 * <p>The filter combines two uncorrelated {@code SOME} sub-queries with
 * {@code OR}. */
@Test void testSomeWithTwoSubQueries() {
  final String sql = "select empno from sales.empnullables\n"
      + "where deptno > some(\n"
      + "  select deptno from sales.deptnullables where name = 'dept1')\n"
      + "or deptno < some(\n"
      + "  select deptno from sales.deptnullables where name = 'dept2')";
  // Apply the sub-query rules with RelBuilder simplification off and
  // trimming on, then compare against the recorded plan.
  sql(sql)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-1546">[CALCITE-1546]
* Sub-queries connected by OR</a>. */
Expand Down Expand Up @@ -6252,6 +6286,32 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
.check();
}

/** Tests a project expression that combines two correlated {@code IN}
 * sub-queries with {@code OR}; both sub-queries correlate on
 * {@code e.ename}. */
@Test void testExpandProjectInWithTwoCorrelatedSubQueries() {
  final String query = "select empno, deptno in (\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno > 10)\n"
      + "or deptno in (\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno < 20)\n"
      + "from sales.empnullables as e";
  // Sub-query rules on, RelBuilder simplification off, trimming on.
  sql(query)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

/** Tests a project expression that combines two uncorrelated {@code IN}
 * sub-queries with {@code OR}. */
@Test void testExpandProjectInWithTwoSubQueries() {
  final String query = "select empno, deptno in (\n"
      + "  select deptno from sales.deptnullables where name = 'dept1')\n"
      + "or deptno in (\n"
      + "  select deptno from sales.deptnullables where name = 'dept2')\n"
      + "from sales.empnullables";
  // Sub-query rules on, RelBuilder simplification off, trimming on.
  sql(query)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

@Test void testExpandProjectInComposite() {
final String sql = "select empno, (empno, deptno) in (\n"
+ " select empno, deptno from sales.emp where empno < 20) as d\n"
Expand Down Expand Up @@ -6298,6 +6358,40 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
sql(sql).withSubQueryRules().check();
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-5655">[CALCITE-5655]
 * Wrong field reference lookup due to same intermediate table alias
 * of multiple sub-queries in subquery remove phase</a>.
 *
 * <p>The filter holds two correlated {@code IN} sub-queries joined by
 * {@code OR}; both correlate on {@code e.ename}. */
@Test void testExpandFilterInCorrelatedWithTwoSubQueries() {
  final String query = "select empno from sales.empnullables as e\n"
      + "where deptno in (\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno > 10)\n"
      + "or deptno in (\n"
      + "  select deptno from sales.deptnullables where e.ename = name and deptno < 20)";
  // Sub-query rules on, RelBuilder simplification off, trimming on.
  sql(query)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-5655">[CALCITE-5655]
 * Wrong field reference lookup due to same intermediate table alias
 * of multiple sub-queries in subquery remove phase</a>.
 *
 * <p>The filter holds two uncorrelated {@code IN} sub-queries joined by
 * {@code OR}. */
@Test void testExpandFilterInWithTwoSubQueries() {
  final String query = "select empno from sales.empnullables\n"
      + "where deptno in (\n"
      + "  select deptno from sales.deptnullables where name = 'dept1')\n"
      + "or deptno in (\n"
      + "  select deptno from sales.deptnullables where name = 'dept2')";
  // Sub-query rules on, RelBuilder simplification off, trimming on.
  sql(query)
      .withSubQueryRules()
      .withRelBuilderSimplify(false)
      .withTrim(true)
      .check();
}

/** An IN filter that requires full 3-value logic (true, false, unknown). */
@Test void testExpandFilterIn3Value() {
final String sql = "select empno\n"
Expand Down
Loading

0 comments on commit e2028ad

Please sign in to comment.