Skip to content

Commit

Permalink
[FLINK-12249][table] Fix type equivalence check problems for Window A…
Browse files Browse the repository at this point in the history
…ggregates

This closes apache#9141
  • Loading branch information
hequn8128 authored and dawidwys committed Jul 30, 2019
1 parent 01f8c35 commit 305051c
Show file tree
Hide file tree
Showing 9 changed files with 448 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ public static void main(String[] args) throws Exception {
String tumbleQuery = String.format(
"SELECT " +
" key, " +
//TODO: The "WHEN -1 THEN NULL" part is a temporary workaround, to make the test pass, for
// https://issues.apache.org/jira/browse/FLINK-12249. We should remove it once the issue is fixed.
" CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 WHEN -1 THEN NULL ELSE 99 END AS correct, " +
" CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 ELSE 99 END AS correct, " +
" TUMBLE_START(rowtime, INTERVAL '%d' SECOND) AS wStart, " +
" TUMBLE_ROWTIME(rowtime, INTERVAL '%d' SECOND) AS rowtime " +
"FROM (%s) " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ import org.apache.calcite.plan._
import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Aggregate.Group
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject}
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeUtil
import org.apache.calcite.util.ImmutableBitSet

import _root_.java.math.BigDecimal
Expand Down Expand Up @@ -84,15 +86,31 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
.project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression))
.build()

// This rule removes the window from the GROUP BY clause, which may change the inferred
// types of the AggCalls and make the type equivalence checks fail.
// To avoid that, we switch the AggCalls to their inferred types in the Aggregate and then
// cast them back to the original types in the Project placed after the Aggregate.
val indexAndTypes = getIndexAndInferredTypesIfChanged(agg)
val finalCalls = adjustTypes(agg, indexAndTypes)

// we don't use the builder here because it uses RelMetadataQuery which affects the plan
val newAgg = LogicalAggregate.create(
newProject,
agg.indicator,
newGroupSet,
ImmutableList.of(newGroupSet),
agg.getAggCallList)
finalCalls)

val transformed = call.builder()
val windowAgg = LogicalWindowAggregate.create(
window,
Seq[PlannerNamedWindowProperty](),
newAgg)
transformed.push(windowAgg)

// create an additional project to conform with types
// The transformation adds an additional LogicalProject at the top to ensure
// that the types are equivalent.
// 1. ensure group key types, create an additional project to conform with types
val outAggGroupExpression0 = getOutAggregateGroupExpression(rexBuilder, windowExpr)
// fix up the nullability if it is changed.
val outAggGroupExpression = if (windowExpr.getType.isNullable !=
Expand All @@ -103,20 +121,80 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
} else {
outAggGroupExpression0
}
val transformed = call.builder()
val windowAgg = LogicalWindowAggregate.create(
window,
Seq[PlannerNamedWindowProperty](),
newAgg)
// The transformation adds an additional LogicalProject at the top to ensure
// that the types are equivalent.
transformed.push(windowAgg)
.project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0))
val projectsEnsureGroupKeyTypes =
transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0)
// 2. ensure aggCall types
val projectsEnsureAggCallTypes =
projectsEnsureGroupKeyTypes.zipWithIndex.map {
case (aggCall, index) =>
val aggCallIndex = index - agg.getGroupCount
if (indexAndTypes.containsKey(aggCallIndex)) {
rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true)
} else {
aggCall
}
}
transformed.project(projectsEnsureAggCallTypes)

val result = transformed.build()
call.transformTo(result)
}

/**
 * Changes the types of the [[AggregateCall]]s of `agg` to the corresponding inferred types.
 *
 * @param agg           the aggregate whose calls are rewritten; its group count and input are
 *                      handed to `AggregateCall.create` for every re-created call
 * @param indexAndTypes maps the position of a call in `agg.getAggCallList` to the type the
 *                      re-created call should carry
 * @return the aggregate calls in their original order; a call whose position is not a key of
 *         `indexAndTypes` is returned as the very same instance
 */
private def adjustTypes(
    agg: LogicalAggregate,
    indexAndTypes: Map[Int, RelDataType]) = {

  agg.getAggCallList.zipWithIndex.map {
    case (aggCall, index) =>
      // One lookup on the Scala map decides both whether the call is rewritten and which
      // type it receives (no Java-style `containsKey` followed by a second `apply`).
      indexAndTypes.get(index) match {
        case Some(inferredType) =>
          // Copy every property of the existing call; only the type is replaced.
          AggregateCall.create(
            aggCall.getAggregation,
            aggCall.isDistinct,
            aggCall.isApproximate,
            aggCall.ignoreNulls(),
            aggCall.getArgList,
            aggCall.filterArg,
            aggCall.collation,
            agg.getGroupCount,
            agg.getInput,
            inferredType,
            aggCall.name)
        case None =>
          aggCall
      }
  }
}

/**
 * Checks if there are any types of [[AggregateCall]] that need to be changed.
 *
 * For every call of `agg` the return type is re-inferred from the argument types found in the
 * input row type, using a binding with a group count of 0. A call is reported when its
 * declared type differs from the re-inferred one.
 *
 * NOTE(review): a group count of 0 presumably models the aggregate once the window, as the
 * only group key, has been removed from the GROUP BY -- confirm against the calling rule.
 *
 * @param agg the aggregate to inspect
 * @return the indexes (positions in `agg.getAggCallList`) of the calls whose type differs,
 *         together with the inferred types; always empty unless `agg` has exactly one
 *         group key
 */
private def getIndexAndInferredTypesIfChanged(
    agg: LogicalAggregate)
  : Map[Int, RelDataType] = {

  if (agg.getGroupCount != 1) {
    // The group-count condition does not depend on the individual call, so it is decided
    // once up front and the type inference is skipped when no call could be reported.
    Map.empty[Int, RelDataType]
  } else {
    // Loop-invariant inputs of the binding.
    val typeFactory = agg.getCluster.getTypeFactory
    val inputRowType = agg.getInput.getRowType

    agg.getAggCallList.zipWithIndex.flatMap {
      case (aggCall, index) =>
        val aggCallBinding = new Aggregate.AggCallBinding(
          typeFactory,
          aggCall.getAggregation,
          SqlTypeUtil.projectTypes(inputRowType, aggCall.getArgList),
          0,
          aggCall.hasFilter)
        val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding)

        if (aggCall.`type` != inferredType) {
          // Explicit pair instead of relying on argument auto-tupling in `Some(a, b)`.
          Some(index -> inferredType)
        } else {
          None
        }
    }.toMap
  }
}

private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = {
val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject]
val groupKeys = agg.getGroupSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1409,4 +1409,111 @@ Calc(select=[w$end AS EXPR$0])
]]>
</Resource>
</TestCase>
<TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=AUTO]">
<Resource name="sql">
<![CDATA[
SELECT
SUM(correct) AS s,
AVG(correct) AS a,
TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
FROM (
SELECT CASE a
WHEN 1 THEN 1
ELSE 99
END AS correct, b
FROM MyTable
)
GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
+- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+- Exchange(distribution=[single])
+- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
<TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=ONE_PHASE]">
<Resource name="sql">
<![CDATA[
SELECT
SUM(correct) AS s,
AVG(correct) AS a,
TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
FROM (
SELECT CASE a
WHEN 1 THEN 1
ELSE 99
END AS correct, b
FROM MyTable
)
GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
+- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1])
+- Exchange(distribution=[single])
+- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
<TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=TWO_PHASE]">
<Resource name="sql">
<![CDATA[
SELECT
SUM(correct) AS s,
AVG(correct) AS a,
TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
FROM (
SELECT CASE a
WHEN 1 THEN 1
ELSE 99
END AS correct, b
FROM MyTable
)
GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
+- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
+- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+- Exchange(distribution=[single])
+- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
Expand Up @@ -467,4 +467,39 @@ Calc(select=[EXPR$0, wAvg, w$start AS EXPR$2, w$end AS EXPR$3])
]]>
</Resource>
</TestCase>
<TestCase name="testReturnTypeInferenceForWindowAgg">
<Resource name="sql">
<![CDATA[
SELECT
SUM(correct) AS s,
AVG(correct) AS a,
TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
FROM (
SELECT CASE a
WHEN 1 THEN 1
ELSE 99
END AS correct, rowtime
FROM MyTable
)
GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
+- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+- LogicalProject($f0=[TUMBLE($4, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
+- GroupWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime, w$proctime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime])
+- Exchange(distribution=[single])
+- Calc(select=[rowtime, CASE(=(a, 1), 1, 99) AS $f1])
+- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ class WindowAggregateTest(aggStrategy: AggregatePhaseStrategy) extends TableTest
""".stripMargin
util.verifyPlan(sql)
}

/**
 * Plan test for SUM and AVG over a CASE expression produced by a subquery, grouped by a
 * 15-minute tumbling window on `b`. The query is handed to `util.verifyPlan`.
 */
@Test
def testReturnTypeInferenceForWindowAgg(): Unit = {
  // The result type is spelled out as Unit: JUnit 4 test methods must be void, and an
  // inferred type would silently follow whatever the last expression returns.
  val sql =
    """
|SELECT
| SUM(correct) AS s,
| AVG(correct) AS a,
| TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
|FROM (
| SELECT CASE a
| WHEN 1 THEN 1
| ELSE 99
| END AS correct, b
| FROM MyTable
|)
|GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
""".stripMargin

  util.verifyPlan(sql)
}
}

object WindowAggregateTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,25 @@ class WindowAggregateTest extends TableTestBase {
util.verifyPlan(sql)
}

/**
 * Plan test for SUM and AVG over a CASE expression produced by a subquery, grouped by a
 * 15-minute tumbling window on `rowtime`. The query is handed to `util.verifyPlan`.
 */
@Test
def testReturnTypeInferenceForWindowAgg(): Unit = {
  // The result type is spelled out as Unit: JUnit 4 test methods must be void, and an
  // inferred type would silently follow whatever the last expression returns.
  val sql =
    """
|SELECT
| SUM(correct) AS s,
| AVG(correct) AS a,
| TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
|FROM (
| SELECT CASE a
| WHEN 1 THEN 1
| ELSE 99
| END AS correct, rowtime
| FROM MyTable
|)
|GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
""".stripMargin

  util.verifyPlan(sql)
}
}
Loading

0 comments on commit 305051c

Please sign in to comment.