Skip to content

Commit

Permalink
[CALCITE-4420] Some simple arithmetic operations can be simplified
Browse files Browse the repository at this point in the history
  • Loading branch information
liyafan82 committed Jul 12, 2021
1 parent 4066a34 commit c700e37
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 19 deletions.
90 changes: 90 additions & 0 deletions core/src/main/java/org/apache/calcite/rex/RexSimplify.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeCoercionRule;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.Bug;
Expand All @@ -49,6 +50,7 @@

import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
Expand Down Expand Up @@ -311,6 +313,11 @@ RexNode simplify(RexNode e, RexUnknownAs unknownAs) {
return simplifyUnaryMinus((RexCall) e, unknownAs);
case PLUS_PREFIX:
return simplifyUnaryPlus((RexCall) e, unknownAs);
case PLUS:
case MINUS:
case TIMES:
case DIVIDE:
return simplifyArithmetic((RexCall) e);
default:
if (e.getClass() == RexCall.class) {
return simplifyGenericNode((RexCall) e);
Expand Down Expand Up @@ -389,6 +396,89 @@ private RexNode simplifyGenericNode(RexCall e) {
return rexBuilder.makeCall(e.getType(), e.getOperator(), operands);
}

/**
 * Searches {@code operands} for a literal that is numerically equal to
 * {@code value}.
 *
 * <p>Only literals whose value is a {@link BigDecimal} are candidates. The
 * comparison uses {@link BigDecimal#compareTo}, so scale is ignored
 * (for example {@code 0} matches {@code 0.00}).
 *
 * @param operands nodes to scan, in order
 * @param value number to look for
 * @return index of the first matching operand, or -1 if there is none
 */
private int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
  int position = 0;
  for (RexNode operand : operands) {
    if (operand.isA(SqlKind.LITERAL)) {
      final Object literalValue = ((RexLiteral) operand).getValue();
      if (literalValue instanceof BigDecimal
          && value.compareTo((BigDecimal) literalValue) == 0) {
        return position;
      }
    }
    position++;
  }
  return -1;
}

/**
 * Simplifies a binary {@code +}, {@code -}, {@code *} or {@code /} call.
 *
 * <p>Identity rewrites are attempted only when the call's result type and
 * the type of every operand belong to the NUMERIC family; any other call
 * is handed to {@link #simplifyGenericNode}.
 *
 * @param e call whose kind is PLUS, MINUS, TIMES or DIVIDE
 * @return the simplified expression
 * @throws IllegalArgumentException if a numeric call has any other kind
 */
private RexNode simplifyArithmetic(RexCall e) {
  // Check each operand's own type (the lambda parameter), not just the
  // type of the call itself.
  if (e.getType().getSqlTypeName().getFamily() != SqlTypeFamily.NUMERIC
      || e.getOperands().stream().anyMatch(
          o -> o.getType().getSqlTypeName().getFamily() != SqlTypeFamily.NUMERIC)) {
    // we only support simplifying numeric types.
    return simplifyGenericNode(e);
  }

  // The per-operator helpers below assume exactly two operands.
  assert e.getOperands().size() == 2;

  switch (e.getKind()) {
  case PLUS:
    return simplifyPlus(e);
  case MINUS:
    return simplifyMinus(e);
  case TIMES:
    return simplifyMultiply(e);
  case DIVIDE:
    return simplifyDivide(e);
  default:
    throw new IllegalArgumentException("Unsupported arithmetic operation " + e.getKind());
  }
}

/**
 * Simplifies {@code x + 0} and {@code 0 + x} to {@code x}.
 *
 * <p>If the remaining operand's type differs from the call's type, the
 * operand is wrapped in a cast to the call's type. A call with no zero
 * literal is passed to {@code simplifyGenericNode}.
 */
private RexNode simplifyPlus(RexCall e) {
  final List<RexNode> operands = e.getOperands();
  final int zeroPos = findLiteralIndex(operands, BigDecimal.ZERO);
  if (zeroPos < 0) {
    return simplifyGenericNode(e);
  }
  // Addition is commutative: keep whichever operand is not the zero.
  final RexNode kept = operands.get((zeroPos + 1) % 2);
  if (kept.getType().equals(e.getType())) {
    return kept;
  }
  return rexBuilder.makeCast(e.getType(), kept);
}

/**
 * Simplifies {@code x - 0} to {@code x}.
 *
 * <p>The rewrite applies only when the first zero literal found is the
 * right-hand operand. If the left operand's type differs from the call's
 * type, it is wrapped in a cast to the call's type. Every other call is
 * passed to {@code simplifyGenericNode}.
 */
private RexNode simplifyMinus(RexCall e) {
  final List<RexNode> operands = e.getOperands();
  if (findLiteralIndex(operands, BigDecimal.ZERO) != 1) {
    return simplifyGenericNode(e);
  }
  final RexNode minuend = operands.get(0);
  if (minuend.getType().equals(e.getType())) {
    return minuend;
  }
  return rexBuilder.makeCast(e.getType(), minuend);
}

/**
 * Simplifies {@code x * 1} and {@code 1 * x} to {@code x}.
 *
 * <p>If the remaining operand's type differs from the call's type, the
 * operand is wrapped in a cast to the call's type. A call with no literal
 * equal to one is passed to {@code simplifyGenericNode}.
 */
private RexNode simplifyMultiply(RexCall e) {
  final List<RexNode> operands = e.getOperands();
  final int onePos = findLiteralIndex(operands, BigDecimal.ONE);
  if (onePos < 0) {
    return simplifyGenericNode(e);
  }
  // Multiplication is commutative: keep whichever operand is not the one.
  final RexNode kept = operands.get((onePos + 1) % 2);
  if (kept.getType().equals(e.getType())) {
    return kept;
  }
  return rexBuilder.makeCast(e.getType(), kept);
}

/**
 * Simplifies {@code x / 1} to {@code x}.
 *
 * <p>The rewrite applies only when the first literal equal to one is the
 * right-hand operand. If the left operand's type differs from the call's
 * type, it is wrapped in a cast to the call's type. Every other call is
 * passed to {@code simplifyGenericNode}.
 */
private RexNode simplifyDivide(RexCall e) {
  final List<RexNode> operands = e.getOperands();
  if (findLiteralIndex(operands, BigDecimal.ONE) != 1) {
    return simplifyGenericNode(e);
  }
  final RexNode dividend = operands.get(0);
  if (dividend.getType().equals(e.getType())) {
    return dividend;
  }
  return rexBuilder.makeCast(e.getType(), dividend);
}

private RexNode simplifyLike(RexCall e, RexUnknownAs unknownAs) {
if (e.operands.get(1) instanceof RexLiteral) {
final RexLiteral literal = (RexLiteral) e.operands.get(1);
Expand Down
31 changes: 30 additions & 1 deletion core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@ trueLiteral, literal(1),
RexNode caseNode = case_(
gt(div(vIntNotNull(), literal(1)), literal(1)), falseLiteral,
trueLiteral);
checkSimplify(caseNode, "<=(/(?0.notNullInt0, 1), 1)");
checkSimplify(caseNode, "<=(?0.notNullInt0, 1)");
}

@Test void testPushNotIntoCase() {
Expand Down Expand Up @@ -3224,4 +3224,33 @@ private SqlSpecialOperatorWithPolicy(String name, SqlKind kind, int prec, boolea
/**
 * Checks that the simplifier leaves unchanged a cast of {@code vInt()} to
 * {@code tVarchar(true, 100)} wrapped in a cast to {@code tVarbinary(true)}.
 */
@Test void testSimplifyVarbinary() {
  final RexNode asVarchar = cast(vInt(), tVarchar(true, 100));
  final RexNode asVarbinary = cast(asVarchar, tVarbinary(true));
  checkSimplifyUnchanged(asVarbinary);
}

/**
 * Checks the identity simplifications {@code x + 0}, {@code 0 + x},
 * {@code x - 0}, {@code x * 1}, {@code 1 * x} and {@code x / 1}, along
 * with cases involving a null operand and a non-identity literal.
 */
@Test void testSimplifySimpleArithmetic() {
  RexNode a = vIntNotNull(1);
  RexNode zero = literal(0);
  RexNode one = literal(1);

  RexNode b = vDecimalNotNull(2);
  // Build the decimal from a string rather than a double, so the literal's
  // value does not depend on binary floating-point representation.
  RexNode half = literal(new BigDecimal("0.5"), b.getType());

  // Addition: zero on either side is dropped; a null operand yields null.
  checkSimplify(add(a, zero), "?0.notNullInt1");
  checkSimplify(add(zero, a), "?0.notNullInt1");
  checkSimplify(add(a, nullInt), "null:INTEGER");
  checkSimplify(add(nullInt, a), "null:INTEGER");

  // Subtraction: only a right-hand zero is checked here.
  checkSimplify(sub(a, zero), "?0.notNullInt1");
  checkSimplify(sub(a, nullInt), "null:INTEGER");

  // Multiplication: one on either side is dropped; a null operand yields null.
  checkSimplify(mul(a, one), "?0.notNullInt1");
  checkSimplify(mul(one, a), "?0.notNullInt1");
  checkSimplify(mul(a, nullInt), "null:INTEGER");
  checkSimplify(mul(nullInt, a), "null:INTEGER");

  // Division: only a right-hand one is checked here.
  checkSimplify(div(a, one), "?0.notNullInt1");
  checkSimplify(div(a, nullInt), "null:INTEGER");

  // 0.5 is not an identity element, so the sum must be left alone.
  checkSimplifyUnchanged(add(b, half));

  // Zero plus an expression that simplifies to null is itself null.
  checkSimplify(add(zero, sub(nullInt, nullInt)), "null:INTEGER");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,9 @@ JOIN dept on emp.deptno + 0 = dept.deptno]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$10], NAME=[$11])
LogicalJoin(condition=[=($9, $10)], joinType=[inner])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+($7, 0)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$9], NAME=[$10])
LogicalJoin(condition=[=($7, $9)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
Expand All @@ -557,10 +556,9 @@ JOIN dept on dept.deptno = emp.deptno + 0]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$10], NAME=[$11])
LogicalJoin(condition=[=($10, $9)], joinType=[inner])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+($7, 0)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$9], NAME=[$10])
LogicalJoin(condition=[=($9, $7)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/resources/sql/misc.iq
Original file line number Diff line number Diff line change
Expand Up @@ -2334,13 +2334,13 @@ FROM (VALUES (0, 2, 4, 8),
(1, 2, 4, 8),
(CAST(null as int), CAST(null as int), CAST(null as int), CAST(null as int))) AS T(A,B,C,D);
V
14.0
13.0
9.5
1.75
0.0
1.875
null
0
14
!ok

# End misc.iq
Original file line number Diff line number Diff line change
Expand Up @@ -3103,12 +3103,12 @@ private void testCountWithApproxDistinct(boolean approx, String sql,
new DruidChecker(
"\"filter\":{\"type\":\"expression\",\"expression\":\"(((CAST(\\\"product_id\\\", ",
"LONG",
") + (1 * \\\"store_sales\\\")) / (\\\"store_cost\\\" - 5))",
") + \\\"store_sales\\\") / (\\\"store_cost\\\" - 5))",
" <= ((floor(\\\"store_sales\\\") * 25) + 2))\"}"))
.explainContains("PLAN=EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], "
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[<=(/(+(CAST($1):INTEGER, *(1, $90)), -($91, 5)), +(*(FLOOR($90), 25), 2))], "
+ "filter=[<=(/(+(CAST($1):INTEGER, $90), -($91, 5)), +(*(FLOOR($90), 25), 2))], "
+ "groups=[{}], aggs=[[COUNT()]])")
.returnsOrdered("EXPR$0=82129");
}
Expand All @@ -3135,7 +3135,7 @@ private void testCountWithApproxDistinct(boolean approx, String sql,
+ "AND EXTRACT(MONTH FROM \"timestamp\") / 4 + 1 = 1";
final String queryType = "{'queryType':'timeseries','dataSource':'foodmart'";
final String filterExp1 = "{'type':'expression','expression':'(((CAST(\\'product_id\\'";
final String filterExpPart2 = " (1 * \\'store_sales\\')) / (\\'store_cost\\' - 5)) "
final String filterExpPart2 = " \\'store_sales\\') / (\\'store_cost\\' - 5)) "
+ "<= ((floor(\\'store_sales\\') * 25) + 2))'}";
final String likeExpressionFilter = "{'type':'expression','expression':'like(\\'product_id\\'";
final String likeExpressionFilter2 = "1%";
Expand All @@ -3157,7 +3157,7 @@ private void testCountWithApproxDistinct(boolean approx, String sql,
final String quarterAsExpressionFilter3 = "/ 4) + 1) == 1)'}]}";
final String plan = "PLAN=EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/"
+ "2992-01-10T00:00:00.000Z]], filter=[AND(<=(/(+(CAST($1):INTEGER, *(1, $90)), "
+ "2992-01-10T00:00:00.000Z]], filter=[AND(<=(/(+(CAST($1):INTEGER, $90), "
+ "-($91, 5)), +(*(FLOOR($90), 25), 2)), >($90, 0), LIKE($1, '1%'), >($91, 1), "
+ "<($0, 1997-01-02 00:00:00), =(EXTRACT(FLAG(MONTH), $0), 1), "
+ "=(EXTRACT(FLAG(DAY), $0), 1), =(+(/(EXTRACT(FLAG(MONTH), $0), 4), 1), 1))], "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3733,12 +3733,12 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
new DruidChecker(
"\"filter\":{\"type\":\"expression\",\"expression\":\"(((CAST(\\\"product_id\\\", ",
"LONG",
") + (1 * \\\"store_sales\\\")) / (\\\"store_cost\\\" - 5))",
") + \\\"store_sales\\\") / (\\\"store_cost\\\" - 5))",
" <= ((floor(\\\"store_sales\\\") * 25) + 2))\"}"))
.explainContains("PLAN=EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], "
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[<=(/(+(CAST($1):INTEGER, *(1, $90)), -($91, 5)), +(*(FLOOR($90), 25), 2))], "
+ "filter=[<=(/(+(CAST($1):INTEGER, $90), -($91, 5)), +(*(FLOOR($90), 25), 2))], "
+ "groups=[{}], aggs=[[COUNT()]])")
.returnsOrdered("EXPR$0=82129");
}
Expand Down Expand Up @@ -3766,7 +3766,7 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
+ "AND EXTRACT(MONTH FROM \"timestamp\") / 4 + 1 = 1 ";
final String queryType = "{'queryType':'timeseries','dataSource':'foodmart'";
final String filterExp1 = "{'type':'expression','expression':'(((CAST(\\'product_id\\'";
final String filterExpPart2 = " (1 * \\'store_sales\\')) / (\\'store_cost\\' - 5)) "
final String filterExpPart2 = " \\'store_sales\\') / (\\'store_cost\\' - 5)) "
+ "<= ((floor(\\'store_sales\\') * 25) + 2))'}";
final String likeExpressionFilter = "{'type':'expression','expression':'like(\\'product_id\\'";
final String likeExpressionFilter2 = "1%";
Expand Down Expand Up @@ -3794,7 +3794,7 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
final String plan = "PLAN=EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], "
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], "
+ "filter=[AND(<=(/(+(CAST($1):INTEGER, *(1, $90)), -($91, 5)), +(*(FLOOR($90), 25), 2)), "
+ "filter=[AND(<=(/(+(CAST($1):INTEGER, $90), -($91, 5)), +(*(FLOOR($90), 25), 2)), "
+ ">($90, 0), LIKE($1, '1%'), >($91, 1), <($0, 1997-01-02 00:00:00), "
+ "=(EXTRACT(FLAG(MONTH), $0), 1), =(EXTRACT(FLAG(DAY), $0), 1), "
+ "=(+(/(EXTRACT(FLAG(MONTH), $0), 4), 1), 1))], groups=[{}], aggs=[[COUNT()]])";
Expand Down

0 comments on commit c700e37

Please sign in to comment.