Skip to content

Commit

Permalink
[FLINK-7338] [table] Fix retrieval of OVER window lower bound.
Browse files Browse the repository at this point in the history
  • Loading branch information
fhueske committed Nov 2, 2017
1 parent 78c8ea2 commit 16b0882
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ trait OverAggregate {
}

private[flink] def orderingToString(
inputType: RelDataType,
orderFields: java.util.List[RelFieldCollation]): String = {
inputType: RelDataType,
orderFields: java.util.List[RelFieldCollation]): String = {

val inFields = inputType.getFieldList.asScala

Expand All @@ -48,9 +48,9 @@ trait OverAggregate {
}

private[flink] def windowRange(
logicWindow: Window,
overWindow: Group,
input: RelNode): String = {
logicWindow: Window,
overWindow: Group,
input: RelNode): String = {
if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) {
s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " +
s"AND ${overWindow.upperBound}"
Expand All @@ -63,8 +63,7 @@ trait OverAggregate {
inputType: RelDataType,
constants: Seq[RexLiteral],
rowType: RelDataType,
namedAggregates: Seq[CalcitePair[AggregateCall, String]])
: String = {
namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = {

val inFields = inputType.getFieldNames.asScala
val outFields = rowType.getFieldNames.asScala
Expand Down Expand Up @@ -97,12 +96,12 @@ trait OverAggregate {
}

private[flink] def getLowerBoundary(
logicWindow: Window,
overWindow: Group,
input: RelNode): Long = {
logicWindow: Window,
overWindow: Group,
input: RelNode): Long = {

val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex
val lowerBoundIndex = ref.getIndex - input.getRowType.getFieldCount
val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
lowerBound match {
case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.table.runtime.stream.sql

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.java.tuple.Tuple1
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.source.SourceFunction
Expand All @@ -28,6 +29,7 @@ import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
import org.apache.flink.table.api.{StreamQueryConfig, TableEnvironment}
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
import org.apache.flink.types.Row
import org.junit.Assert._
Expand Down Expand Up @@ -293,30 +295,34 @@ class OverWindowITCase extends StreamingWithStateTestBase {
.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

tEnv.registerTable("T1", t1)
tEnv.registerFunction("LTCNT", new LargerThanCount)

val sqlQuery = "SELECT " +
" c, b, " +
" LTCNT(a, CAST('4' AS BIGINT)) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
" BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " +
" COUNT(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
" BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW), " +
" SUM(a) OVER (PARTITION BY c ORDER BY rowtime RANGE " +
" BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW)" +
" BETWEEN INTERVAL '1' SECOND PRECEDING AND CURRENT ROW) " +
" FROM T1"

val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = List(
"Hello,1,1,1", "Hello,15,2,2", "Hello,16,3,3",
"Hello,2,6,9", "Hello,3,6,9", "Hello,2,6,9",
"Hello,3,4,9",
"Hello,4,2,7",
"Hello,5,2,9",
"Hello,6,2,11", "Hello,65,2,12",
"Hello,9,2,12", "Hello,9,2,12", "Hello,18,3,18",
"Hello World,7,1,7", "Hello World,17,3,21", "Hello World,77,3,21", "Hello World,18,1,7",
"Hello World,8,2,15",
"Hello World,20,1,20")
"Hello,1,0,1,1", "Hello,15,0,2,2", "Hello,16,0,3,3",
"Hello,2,0,6,9", "Hello,3,0,6,9", "Hello,2,0,6,9",
"Hello,3,0,4,9",
"Hello,4,0,2,7",
"Hello,5,1,2,9",
"Hello,6,2,2,11", "Hello,65,2,2,12",
"Hello,9,2,2,12", "Hello,9,2,2,12", "Hello,18,3,3,18",
"Hello World,7,1,1,7", "Hello World,17,3,3,21", "Hello World,77,3,3,21",
"Hello World,18,1,1,7",
"Hello World,8,2,2,15",
"Hello World,20,1,1,20")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

Expand Down Expand Up @@ -354,9 +360,12 @@ class OverWindowITCase extends StreamingWithStateTestBase {
.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

tEnv.registerTable("T1", t1)
tEnv.registerFunction("LTCNT", new LargerThanCount)

val sqlQuery = "SELECT " +
" c, a, " +
" LTCNT(a, CAST('4' AS BIGINT)) " +
" OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " +
" COUNT(a) " +
" OVER (PARTITION BY c ORDER BY rowtime ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), " +
" SUM(a) " +
Expand All @@ -368,12 +377,12 @@ class OverWindowITCase extends StreamingWithStateTestBase {
env.execute()

val expected = List(
"Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
"Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
"Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
"Hello,6,3,15",
"Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
"Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
"Hello,1,0,1,1", "Hello,1,0,2,2", "Hello,1,0,3,3",
"Hello,2,0,3,4", "Hello,2,0,3,5", "Hello,2,0,3,6",
"Hello,3,0,3,7", "Hello,4,0,3,9", "Hello,5,1,3,12",
"Hello,6,2,3,15",
"Hello World,7,1,1,7", "Hello World,7,2,2,14", "Hello World,7,3,3,21",
"Hello World,7,3,3,21", "Hello World,8,3,3,22", "Hello World,20,3,3,35")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

Expand Down Expand Up @@ -518,6 +527,8 @@ class OverWindowITCase extends StreamingWithStateTestBase {
StreamITCase.clear

val sqlQuery = "SELECT a, b, c, " +
" LTCNT(b, CAST('4' AS BIGINT)) OVER(" +
" PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " +
" SUM(b) OVER (" +
" PARTITION BY a ORDER BY rowtime RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), " +
" COUNT(b) OVER (" +
Expand Down Expand Up @@ -552,25 +563,26 @@ class OverWindowITCase extends StreamingWithStateTestBase {
.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

tEnv.registerTable("T1", t1)
tEnv.registerFunction("LTCNT", new LargerThanCount)

val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = List(
"1,1,Hello,6,3,2,3,1",
"1,2,Hello,6,3,2,3,1",
"1,3,Hello world,6,3,2,3,1",
"1,1,Hi,7,4,1,3,1",
"2,1,Hello,1,1,1,1,1",
"2,2,Hello world,6,3,2,3,1",
"2,3,Hello world,6,3,2,3,1",
"1,4,Hello world,11,5,2,4,1",
"1,5,Hello world,29,8,3,7,1",
"1,6,Hello world,29,8,3,7,1",
"1,7,Hello world,29,8,3,7,1",
"2,4,Hello world,15,5,3,5,1",
"2,5,Hello world,15,5,3,5,1")
"1,1,Hello,0,6,3,2,3,1",
"1,2,Hello,0,6,3,2,3,1",
"1,3,Hello world,0,6,3,2,3,1",
"1,1,Hi,0,7,4,1,3,1",
"2,1,Hello,0,1,1,1,1,1",
"2,2,Hello world,0,6,3,2,3,1",
"2,3,Hello world,0,6,3,2,3,1",
"1,4,Hello world,0,11,5,2,4,1",
"1,5,Hello world,3,29,8,3,7,1",
"1,6,Hello world,3,29,8,3,7,1",
"1,7,Hello world,3,29,8,3,7,1",
"2,4,Hello world,1,15,5,3,5,1",
"2,5,Hello world,1,15,5,3,5,1")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

Expand All @@ -583,6 +595,8 @@ class OverWindowITCase extends StreamingWithStateTestBase {
StreamITCase.testResults = mutable.MutableList()

val sqlQuery = "SELECT a, b, c, " +
"LTCNT(b, CAST('4' AS BIGINT)) over(" +
"partition by a order by rowtime rows between unbounded preceding and current row), " +
"SUM(b) over (" +
"partition by a order by rowtime rows between unbounded preceding and current row), " +
"count(b) over (" +
Expand Down Expand Up @@ -618,26 +632,27 @@ class OverWindowITCase extends StreamingWithStateTestBase {
.toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

tEnv.registerTable("T1", t1)
tEnv.registerFunction("LTCNT", new LargerThanCount)

val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = mutable.MutableList(
"1,2,Hello,2,1,2,2,2",
"1,3,Hello world,5,2,2,3,2",
"1,1,Hi,6,3,2,3,1",
"2,1,Hello,1,1,1,1,1",
"2,2,Hello world,3,2,1,2,1",
"3,1,Hello,1,1,1,1,1",
"3,2,Hello world,3,2,1,2,1",
"1,5,Hello world,11,4,2,5,1",
"1,6,Hello world,17,5,3,6,1",
"1,9,Hello world,26,6,4,9,1",
"1,8,Hello world,34,7,4,9,1",
"1,7,Hello world,41,8,5,9,1",
"2,5,Hello world,8,3,2,5,1",
"3,5,Hello world,8,3,2,5,1")
"1,2,Hello,0,2,1,2,2,2",
"1,3,Hello world,0,5,2,2,3,2",
"1,1,Hi,0,6,3,2,3,1",
"2,1,Hello,0,1,1,1,1,1",
"2,2,Hello world,0,3,2,1,2,1",
"3,1,Hello,0,1,1,1,1,1",
"3,2,Hello world,0,3,2,1,2,1",
"1,5,Hello world,1,11,4,2,5,1",
"1,6,Hello world,2,17,5,3,6,1",
"1,9,Hello world,3,26,6,4,9,1",
"1,8,Hello world,4,34,7,4,9,1",
"1,7,Hello world,5,41,8,5,9,1",
"2,5,Hello world,1,8,3,2,5,1",
"3,5,Hello world,1,8,3,2,5,1")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

Expand Down Expand Up @@ -852,3 +867,19 @@ object OverWindowITCase {
override def cancel(): Unit = ???
}
}

/** Counts how often the first argument was larger than the second argument. */
class LargerThanCount extends AggregateFunction[Long, Tuple1[Long]] {

def accumulate(acc: Tuple1[Long], a: Long, b: Long): Unit = {
if (a > b) acc.f0 += 1
}

def retract(acc: Tuple1[Long], a: Long, b: Long): Unit = {
if (a > b) acc.f0 -= 1
}

override def createAccumulator(): Tuple1[Long] = Tuple1.of(0L)

override def getValue(acc: Tuple1[Long]): Long = acc.f0
}

0 comments on commit 16b0882

Please sign in to comment.