[FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction

This closes apache#7253
dianfu authored and sunjincheng121 committed Dec 17, 2018
1 parent ff848d0 commit 9a45fca
Showing 3 changed files with 195 additions and 13 deletions.
CollectAggFunction.scala
@@ -112,10 +112,12 @@ class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
   def retract(acc: CollectAccumulator[E], value: E): Unit = {
     if (value != null) {
       val count = acc.map.get(value)
-      if (count == 1) {
-        acc.map.remove(value)
-      } else {
-        acc.map.put(value, count - 1)
+      if (count != null) {
+        if (count == 1) {
+          acc.map.remove(value)
+        } else {
+          acc.map.put(value, count - 1)
+        }
       }
     }
   }
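
The new null check matters because MapView.get returns null for an absent key (for example, after the state backend has dropped the entry under state-retention/TTL cleanup), and the old code would then throw a NullPointerException when "count - 1" unboxed that null. A minimal sketch of the call sequence the guard makes safe (the function instance and element values here are illustrative, not part of the commit):

    // Hypothetical usage, assuming a CollectAggFunction over Int elements.
    val collect = new CollectAggFunction[Int](Types.INT)
    val acc = collect.createAccumulator()
    collect.accumulate(acc, 1)
    // acc.map has no entry for 2, so acc.map.get(2) returns null:
    // previously this unboxed null in "count - 1" and threw an NPE;
    // with the guard, the retract is simply a no-op.
    collect.retract(acc, 2)
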
AggFunctionHarnessTest.scala
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.harness

import java.lang.{Integer => JInt}
import java.util.concurrent.ConcurrentLinkedQueue

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.StateMapView
import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
import org.junit.Assert.assertTrue
import org.junit.Test

import scala.collection.JavaConverters._
import scala.collection.mutable

class AggFunctionHarnessTest extends HarnessTestBase {
  private val queryConfig = new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))

  @Test
  def testCollectAggregate(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    val data = new mutable.MutableList[(JInt, String)]
    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
    tEnv.registerTable("T", t)
    val sqlQuery = tEnv.sqlQuery(
      s"""
         |SELECT
         | b, collect(a)
         |FROM (
         | SELECT a, b
         | FROM T
         | GROUP BY a, b
         |) GROUP BY b
         |""".stripMargin)

    val testHarness = createHarnessTester[String, CRow, CRow](
      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")

    testHarness.setStateBackend(getStateBackend)
    testHarness.open()

    val operator = getOperator(testHarness)
    val state = getState(
      operator,
      "function",
      classOf[GroupAggProcessFunction],
      "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
    assertTrue(state.isInstanceOf[StateMapView[_, _]])
    assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])

    val expectedOutput = new ConcurrentLinkedQueue[Object]()

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "bbb"), 1))
    expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 1).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(2: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 2).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))

    // remove some state: state may be cleaned up by the state backend
    // if not accessed beyond ttl time
    operator.setCurrentKey(Row.of("aaa"))
    state.remove(2)

    // retract after state has been cleaned up
    testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, "aaa"), 1))

    val result = testHarness.getOutput

    verify(expectedOutput, result)

    testHarness.close()
  }
}
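
The assertions above check that the operator really runs on RocksDB; getStateBackend is inherited from StreamingWithStateTestBase via the reworked HarnessTestBase (next file). As a rough sketch, pointing a harness at RocksDB directly could look like the following (the checkpoint URI is a placeholder assumption, not part of this commit):

    import org.apache.flink.contrib.streaming.state.RocksDBStateBackend

    // Hypothetical standalone setup; the test above instead obtains its
    // backend from StreamingWithStateTestBase.getStateBackend.
    val rocksDbBackend = new RocksDBStateBackend("file:///tmp/harness-checkpoints")
    testHarness.setStateBackend(rocksDbBackend)
    testHarness.open()
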
HarnessTestBase.scala
@@ -25,27 +25,27 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, OneInputStreamOperator}
 import org.apache.flink.streaming.api.scala.DataStream
+import org.apache.flink.streaming.api.transformations._
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
-import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.table.api.dataview.DataView
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.GeneratedAggregationsFunction
 import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
 import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
 import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
 import org.apache.flink.table.utils.EncodingUtils
-import org.junit.Rule
-import org.junit.rules.ExpectedException

-class HarnessTestBase {
-  // used for accurate exception information checking.
-  val expectedException = ExpectedException.none()
+import _root_.scala.collection.JavaConversions._

-  @Rule
-  def thrown = expectedException
+class HarnessTestBase extends StreamingWithStateTestBase {

   val longMinWithRetractAggFunction: String =
     EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)
@@ -491,13 +491,83 @@ class HarnessTestBase {
     distinctCountFuncName,
     distinctCountAggCode)

+  def createHarnessTester[KEY, IN, OUT](
+      dataStream: DataStream[_],
+      prefixOperatorName: String)
+    : KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
+
+    val transformation = extractExpectedTransformation(
+      dataStream.javaStream.getTransformation,
+      prefixOperatorName).asInstanceOf[OneInputTransformation[_, _]]
+    if (transformation == null) {
+      throw new Exception("Can not find the expected transformation")
+    }
+
+    val processOperator = transformation.getOperator.asInstanceOf[OneInputStreamOperator[IN, OUT]]
+    val keySelector = transformation.getStateKeySelector.asInstanceOf[KeySelector[IN, KEY]]
+    val keyType = transformation.getStateKeyType.asInstanceOf[TypeInformation[KEY]]
+
+    createHarnessTester(processOperator, keySelector, keyType)
+      .asInstanceOf[KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT]]
+  }
+
+  private def extractExpectedTransformation(
+      transformation: StreamTransformation[_],
+      prefixOperatorName: String): StreamTransformation[_] = {
+    def extractFromInputs(inputs: StreamTransformation[_]*): StreamTransformation[_] = {
+      for (input <- inputs) {
+        val t = extractExpectedTransformation(input, prefixOperatorName)
+        if (t != null) {
+          return t
+        }
+      }
+      null
+    }
+
+    transformation match {
+      case one: OneInputTransformation[_, _] =>
+        if (one.getName.startsWith(prefixOperatorName)) {
+          one
+        } else {
+          extractExpectedTransformation(one.getInput, prefixOperatorName)
+        }
+      case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
+      case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
+      case _: SourceTransformation[_] => null
+      case _ => throw new UnsupportedOperationException("This should not happen.")
+    }
+  }
+
+  def getState(
+      operator: AbstractUdfStreamOperator[_, _],
+      funcName: String,
+      funcClass: Class[_],
+      stateFieldName: String): DataView = {
+    val function = funcClass.getDeclaredField(funcName)
+    function.setAccessible(true)
+    val generatedAggregation =
+      function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
+    val cls = generatedAggregation.getClass
+    val stateField = cls.getDeclaredField(stateFieldName)
+    stateField.setAccessible(true)
+    stateField.get(generatedAggregation).asInstanceOf[DataView]
+  }
+
   def createHarnessTester[IN, OUT, KEY](
       operator: OneInputStreamOperator[IN, OUT],
       keySelector: KeySelector[IN, KEY],
       keyType: TypeInformation[KEY]): KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
     new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
   }
+
+  def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
+    : AbstractUdfStreamOperator[_, _] = {
+    val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
+      .getDeclaredField("oneInputOperator")
+    operatorField.setAccessible(true)
+    operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
+  }

   def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
     verify(expected, actual, new RowResultSortComparator)
   }
