diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala index 87ebd86f1e548..f9bf80344bd05 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala @@ -35,8 +35,8 @@ trait OverAggregate { } private[flink] def orderingToString( - inputType: RelDataType, - orderFields: java.util.List[RelFieldCollation]): String = { + inputType: RelDataType, + orderFields: java.util.List[RelFieldCollation]): String = { val inFields = inputType.getFieldList.asScala @@ -48,9 +48,9 @@ trait OverAggregate { } private[flink] def windowRange( - logicWindow: Window, - overWindow: Group, - input: RelNode): String = { + logicWindow: Window, + overWindow: Group, + input: RelNode): String = { if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) { s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " + s"AND ${overWindow.upperBound}" @@ -63,8 +63,7 @@ trait OverAggregate { inputType: RelDataType, constants: Seq[RexLiteral], rowType: RelDataType, - namedAggregates: Seq[CalcitePair[AggregateCall, String]]) - : String = { + namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = { val inFields = inputType.getFieldNames.asScala val outFields = rowType.getFieldNames.asScala @@ -97,12 +96,12 @@ trait OverAggregate { } private[flink] def getLowerBoundary( - logicWindow: Window, - overWindow: Group, - input: RelNode): Long = { + logicWindow: Window, + overWindow: Group, + input: RelNode): Long = { val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef] - val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex + val lowerBoundIndex = ref.getIndex - input.getRowType.getFieldCount val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2 lowerBound match { case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala index 4884513a5d3f5..9bfdc4cac2f88 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.common.time.Time +import org.apache.flink.api.java.tuple.Tuple1 import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.functions.source.SourceFunction @@ -28,6 +29,7 @@ import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment} +import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} import org.apache.flink.types.Row import org.junit.Assert._ @@ -293,13 +295,16 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) tEnv.registerTable("T1", t1) + tEnv.registerFunction("LTCNT", new LargerThanCount) val sqlQuery = "SELECT " + " c, b, " + + " LTCNT(a, CAST('4' AS BIGINT)) OVER (PARTITION BY c ORDER BY rowtime RANGE " + + " BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " + " COUNT(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " + " BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " + " SUM(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " + - " BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW)" + + " BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW) " + " FROM T1" val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] @@ -307,16 +312,17 @@ class OverWindowITCase extends StreamingWithStateTestBase { env.execute() val expected = List( - "Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3", - "Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9", - "Hello,3,4,9", - "Hello,4,2,7", - "Hello,5,2,9", - "Hello,6,2,11", "Hello,65,2,12", - "Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18", - "Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7", - "Hello World,8,2,15", - "Hello World,20,1,20") + "Hello,1,0,1,1", "Hello,15,0,2,2", "Hello,16,0,3,3", + "Hello,2,0,6,9", "Hello,3,0,6,9", "Hello,2,0,6,9", + "Hello,3,0,4,9", + "Hello,4,0,2,7", + "Hello,5,1,2,9", + "Hello,6,2,2,11", "Hello,65,2,2,12", + "Hello,9,2,2,12", "Hello,9,2,2,12", "Hello,18,3,3,18", + "Hello World,7,1,1,7", "Hello World,17,3,3,21", "Hello World,77,3,3,21", + "Hello World,18,1,1,7", + "Hello World,8,2,2,15", + "Hello World,20,1,1,20") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -354,9 +360,12 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) tEnv.registerTable("T1", t1) + tEnv.registerFunction("LTCNT", new LargerThanCount) val sqlQuery = "SELECT " + " c, a, " + + " LTCNT(a, CAST('4' AS BIGINT)) " + + " OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " + " COUNT(a) " + " OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " + " SUM(a) " + @@ -368,12 +377,12 @@ class OverWindowITCase extends StreamingWithStateTestBase { env.execute() val expected = List( - "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3", - "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6", - "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12", - "Hello,6,3,15", - "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21", - "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35") + "Hello,1,0,1,1", "Hello,1,0,2,2", "Hello,1,0,3,3", + "Hello,2,0,3,4", "Hello,2,0,3,5", "Hello,2,0,3,6", + "Hello,3,0,3,7", "Hello,4,0,3,9", "Hello,5,1,3,12", + "Hello,6,2,3,15", + "Hello World,7,1,1,7", "Hello World,7,2,2,14", "Hello World,7,3,3,21", + "Hello World,7,3,3,21", "Hello World,8,3,3,22", "Hello World,20,3,3,35") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -518,6 +527,8 @@ class OverWindowITCase extends StreamingWithStateTestBase { StreamITCase.clear val sqlQuery = "SELECT a, b, c, " + + " LTCNT(b, CAST('4' AS BIGINT)) OVER(" + + " PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " + " SUM(b) OVER (" + " PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " + " COUNT(b) OVER (" + @@ -552,25 +563,26 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) tEnv.registerTable("T1", t1) + tEnv.registerFunction("LTCNT", new LargerThanCount) val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = List( - "1,1,Hello,6,3,2,3,1", - "1,2,Hello,6,3,2,3,1", - "1,3,Hello world,6,3,2,3,1", - "1,1,Hi,7,4,1,3,1", - "2,1,Hello,1,1,1,1,1", - "2,2,Hello world,6,3,2,3,1", - "2,3,Hello world,6,3,2,3,1", - "1,4,Hello world,11,5,2,4,1", - "1,5,Hello world,29,8,3,7,1", - "1,6,Hello world,29,8,3,7,1", - "1,7,Hello world,29,8,3,7,1", - "2,4,Hello world,15,5,3,5,1", - "2,5,Hello world,15,5,3,5,1") + "1,1,Hello,0,6,3,2,3,1", + "1,2,Hello,0,6,3,2,3,1", + "1,3,Hello world,0,6,3,2,3,1", + "1,1,Hi,0,7,4,1,3,1", + "2,1,Hello,0,1,1,1,1,1", + "2,2,Hello world,0,6,3,2,3,1", + "2,3,Hello world,0,6,3,2,3,1", + "1,4,Hello world,0,11,5,2,4,1", + "1,5,Hello world,3,29,8,3,7,1", + "1,6,Hello world,3,29,8,3,7,1", + "1,7,Hello world,3,29,8,3,7,1", + "2,4,Hello world,1,15,5,3,5,1", + "2,5,Hello world,1,15,5,3,5,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -583,6 +595,8 @@ class OverWindowITCase extends StreamingWithStateTestBase { StreamITCase.testResults = mutable.MutableList() val sqlQuery = "SELECT a, b, c, " + + "LTCNT(b, CAST('4' AS BIGINT)) over(" + + "partition by a order by rowtime rows between unbounded preceding and current row), " + "SUM(b) over (" + "partition by a order by rowtime rows between unbounded preceding and current row), " + "count(b) over (" + @@ -618,26 +632,27 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) tEnv.registerTable("T1", t1) + tEnv.registerFunction("LTCNT", new LargerThanCount) val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] result.addSink(new StreamITCase.StringSink[Row]) env.execute() val expected = mutable.MutableList( - "1,2,Hello,2,1,2,2,2", - "1,3,Hello world,5,2,2,3,2", - "1,1,Hi,6,3,2,3,1", - "2,1,Hello,1,1,1,1,1", - "2,2,Hello world,3,2,1,2,1", - "3,1,Hello,1,1,1,1,1", - "3,2,Hello world,3,2,1,2,1", - "1,5,Hello world,11,4,2,5,1", - "1,6,Hello world,17,5,3,6,1", - "1,9,Hello world,26,6,4,9,1", - "1,8,Hello world,34,7,4,9,1", - "1,7,Hello world,41,8,5,9,1", - "2,5,Hello world,8,3,2,5,1", - "3,5,Hello world,8,3,2,5,1") + "1,2,Hello,0,2,1,2,2,2", + "1,3,Hello world,0,5,2,2,3,2", + "1,1,Hi,0,6,3,2,3,1", + "2,1,Hello,0,1,1,1,1,1", + "2,2,Hello world,0,3,2,1,2,1", + "3,1,Hello,0,1,1,1,1,1", + "3,2,Hello world,0,3,2,1,2,1", + "1,5,Hello world,1,11,4,2,5,1", + "1,6,Hello world,2,17,5,3,6,1", + "1,9,Hello world,3,26,6,4,9,1", + "1,8,Hello world,4,34,7,4,9,1", + "1,7,Hello world,5,41,8,5,9,1", + "2,5,Hello world,1,8,3,2,5,1", + "3,5,Hello world,1,8,3,2,5,1") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } @@ -852,3 +867,19 @@ object OverWindowITCase { override def cancel(): Unit = ??? } } + +/** Counts how often the first argument was larger than the second argument. */ +class LargerThanCount extends AggregateFunction[Long, Tuple1[Long]] { + + def accumulate(acc: Tuple1[Long], a: Long, b: Long): Unit = { + if (a > b) acc.f0 += 1 + } + + def retract(acc: Tuple1[Long], a: Long, b: Long): Unit = { + if (a > b) acc.f0 -= 1 + } + + override def createAccumulator(): Tuple1[Long] = Tuple1.of(0L) + + override def getValue(acc: Tuple1[Long]): Long = acc.f0 +}