[FLINK-4937] [table] Add incremental group window aggregation for streaming Table API.

This closes apache#2792.
sunjincheng121 authored and fhueske committed Nov 23, 2016
1 parent 06d252e commit 74e0971
Showing 16 changed files with 848 additions and 265 deletions.
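
At a high level, the change makes group-window aggregates on streaming tables eligible for eager evaluation: whenever every aggregate function supports partial aggregation, a ReduceFunction folds each arriving record into a single accumulator per key and window, and the window function only finalizes the result when the window fires, instead of buffering all rows and group-reducing them at that point. Below is a standalone sketch of the incremental mode against the DataStream Scala API as it existed around this commit; the data, key, and window size are illustrative and not taken from the commit.

import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector

object IncrementalWindowSketch {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val input = env.fromElements(("a", 1), ("b", 2), ("a", 3))

    val sums = input
      .keyBy(_._1)
      .timeWindow(Time.seconds(10))
      // Incremental mode: the pre-aggregator merges each record into one
      // accumulator per key and window as it arrives; the window function
      // runs once per firing, here only to forward the final value.
      .apply(
        (a: (String, Int), b: (String, Int)) => (a._1, a._2 + b._2),
        (key: String, w: TimeWindow, in: Iterable[(String, Int)],
            out: Collector[(String, Int)]) => in.foreach(out.collect))

    sums.print()
    env.execute("incremental window aggregation sketch")
  }
}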
DataSetAggregate.scala
@@ -85,17 +85,22 @@ class DataSetAggregate(
}

override def translateToPlan(
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {

val config = tableEnv.getConfig

val groupingKeys = grouping.indices.toArray
// add grouping fields, position keys in the input, and input type
val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(

val mapFunction = AggregateUtil.createPrepareMapFunction(
namedAggregates,
grouping,
inputType)

val groupReduceFunction = AggregateUtil.createAggregateGroupReduceFunction(
namedAggregates,
inputType,
getRowType,
rowRelDataType,
grouping)

val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
@@ -111,10 +116,9 @@ class DataSetAggregate(
val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
val prepareOpName = s"prepare select: ($aggString)"
val mappedInput = inputDS
.map(aggregateResult._1)
.map(mapFunction)
.name(prepareOpName)

val groupReduceFunction = aggregateResult._2
val rowTypeInfo = new RowTypeInfo(fieldTypes)

val result = {
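The batch path gets the same factoring: AggregateUtil now hands out the prepare MapFunction and the final GroupReduceFunction as separate functions instead of a function pair from createOperatorFunctionsForAggregates. The shape of the resulting two-phase job, as a minimal standalone sketch; the tuple layout and the sum are placeholders for the generated functions, not code from the commit.

import org.apache.flink.api.scala._

object TwoPhaseBatchSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val input = env.fromElements(("a", 1), ("b", 2), ("a", 3))

    // phase 1 ("prepare select"): project the grouping keys and the fields
    // needed by the aggregates into an intermediate record
    val prepared = input.map(t => (t._1, t._2))

    // phase 2: group by the key positions and fold the partial records
    val result = prepared
      .groupBy(0)
      .reduceGroup { rows =>
        val buffered = rows.toList
        (buffered.head._1, buffered.map(_._2).sum)
      }

    result.print()
  }
}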
DataStreamAggregate.scala
@@ -22,7 +22,6 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.flink.api.common.functions.RichGroupReduceFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty
@@ -31,12 +30,11 @@ import org.apache.flink.api.table.plan.logical._
import org.apache.flink.api.table.plan.nodes.FlinkAggregate
import org.apache.flink.api.table.plan.nodes.datastream.DataStreamAggregate._
import org.apache.flink.api.table.runtime.aggregate.AggregateUtil._
import org.apache.flink.api.table.runtime.aggregate._
import org.apache.flink.api.table.runtime.aggregate.{Aggregate, _}
import org.apache.flink.api.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, RowTypeInfo, TimeIntervalTypeInfo, TypeConverter}
import org.apache.flink.api.table.{TableException, FlinkTypeFactory, Row, StreamTableEnvironment}
import org.apache.flink.api.table.{FlinkTypeFactory, Row, StreamTableEnvironment}
import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream}
import org.apache.flink.streaming.api.functions.windowing.{WindowFunction, AllWindowFunction}
import org.apache.flink.streaming.api.windowing.assigners._
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
@@ -103,30 +101,24 @@ class DataStreamAggregate(
}

override def translateToPlan(
tableEnv: StreamTableEnvironment,
expectedType: Option[TypeInformation[Any]])
: DataStream[Any] = {

val config = tableEnv.getConfig
tableEnv: StreamTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataStream[Any] = {

val config = tableEnv.getConfig
val groupingKeys = grouping.indices.toArray
// add grouping fields, position keys in the input, and input type
val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(
namedAggregates,
inputType,
getRowType,
grouping)

val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(
tableEnv,
// tell the input operator that this operator currently only supports Rows as input
Some(TypeConverter.DEFAULT_ROW_TYPE))

// get the output types
val fieldTypes: Array[TypeInformation[_]] = getRowType.getFieldList.asScala
val fieldTypes: Array[TypeInformation[_]] =
getRowType.getFieldList.asScala
.map(field => FlinkTypeFactory.toTypeInfo(field.getType))
.toArray

val rowTypeInfo = new RowTypeInfo(fieldTypes)

val aggString = aggregationToString(
inputType,
grouping,
@@ -135,50 +127,118 @@
namedProperties)

val prepareOpName = s"prepare select: ($aggString)"
val mappedInput = inputDS
.map(aggregateResult._1)
.name(prepareOpName)

val groupReduceFunction = aggregateResult._2
val rowTypeInfo = new RowTypeInfo(fieldTypes)
val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
s"window: ($window), " +
s"select: ($aggString)"
val nonKeyedAggOpName = s"window: ($window), select: ($aggString)"

val result = {
// grouped / keyed aggregation
if (groupingKeys.length > 0) {
val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
s"window: ($window), " +
s"select: ($aggString)"
val aggregateFunction =
createWindowAggregationFunction(window, namedProperties, groupReduceFunction)

val keyedStream = mappedInput.keyBy(groupingKeys: _*)

val windowedStream = createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]

windowedStream
.apply(aggregateFunction)
val mapFunction = AggregateUtil.createPrepareMapFunction(
namedAggregates,
grouping,
inputType)

val mappedInput = inputDS.map(mapFunction).name(prepareOpName)

val result: DataStream[Any] = {
// check whether all aggregates support partial aggregate
if (AggregateUtil.doAllSupportPartialAggregation(
namedAggregates.map(_.getKey),
inputType,
grouping.length)) {
// do Incremental Aggregation
val reduceFunction = AggregateUtil.createIncrementalAggregateReduceFunction(
namedAggregates,
inputType,
getRowType,
grouping)
// grouped / keyed aggregation
if (groupingKeys.length > 0) {
val windowFunction = AggregateUtil.createWindowIncrementalAggregationFunction(
window,
namedAggregates,
inputType,
rowRelDataType,
grouping,
namedProperties)

val keyedStream = mappedInput.keyBy(groupingKeys: _*)
val windowedStream =
createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]

windowedStream
.apply(reduceFunction, windowFunction)
.returns(rowTypeInfo)
.name(keyedAggOpName)
.asInstanceOf[DataStream[Any]]
}
// global / non-keyed aggregation
else {
val windowFunction = AggregateUtil.createAllWindowIncrementalAggregationFunction(
window,
namedAggregates,
inputType,
rowRelDataType,
grouping,
namedProperties)

val windowedStream =
createNonKeyedWindowedStream(window, mappedInput)
.asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]

windowedStream
.apply(reduceFunction, windowFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.name(nonKeyedAggOpName)
.asInstanceOf[DataStream[Any]]
}
}
// global / non-keyed aggregation
else {
val aggOpName = s"window: ($window), select: ($aggString)"
val aggregateFunction =
createAllWindowAggregationFunction(window, namedProperties, groupReduceFunction)

val windowedStream = createNonKeyedWindowedStream(window, mappedInput)
.asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]

windowedStream
.apply(aggregateFunction)
// do non-Incremental Aggregation
// grouped / keyed aggregation
if (groupingKeys.length > 0) {

val windowFunction = AggregateUtil.createWindowAggregationFunction(
window,
namedAggregates,
inputType,
rowRelDataType,
grouping,
namedProperties)

val keyedStream = mappedInput.keyBy(groupingKeys: _*)
val windowedStream =
createKeyedWindowedStream(window, keyedStream)
.asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]

windowedStream
.apply(windowFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.name(keyedAggOpName)
.asInstanceOf[DataStream[Any]]
}
// global / non-keyed aggregation
else {
val windowFunction = AggregateUtil.createAllWindowAggregationFunction(
window,
namedAggregates,
inputType,
rowRelDataType,
grouping,
namedProperties)

val windowedStream =
createNonKeyedWindowedStream(window, mappedInput)
.asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]

windowedStream
.apply(windowFunction)
.returns(rowTypeInfo)
.name(nonKeyedAggOpName)
.asInstanceOf[DataStream[Any]]
}
}
}

// if the expected type is not a Row, inject a mapper to convert to the expected type
expectedType match {
case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
@@ -196,72 +256,8 @@
}
}
}

object DataStreamAggregate {

private def createAllWindowAggregationFunction(
window: LogicalWindow,
properties: Seq[NamedWindowProperty],
aggFunction: RichGroupReduceFunction[Row, Row])
: AllWindowFunction[Row, Row, DataStreamWindow] = {

if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
new AggregateAllTimeWindowFunction(aggFunction, startPos, endPos)
.asInstanceOf[AllWindowFunction[Row, Row, DataStreamWindow]]
} else {
new AggregateAllWindowFunction(aggFunction)
}

}

private def createWindowAggregationFunction(
window: LogicalWindow,
properties: Seq[NamedWindowProperty],
aggFunction: RichGroupReduceFunction[Row, Row])
: WindowFunction[Row, Row, Tuple, DataStreamWindow] = {

if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
new AggregateTimeWindowFunction(aggFunction, startPos, endPos)
.asInstanceOf[WindowFunction[Row, Row, Tuple, DataStreamWindow]]
} else {
new AggregateWindowFunction(aggFunction)
}

}

private def isTimeWindow(window: LogicalWindow) = {
window match {
case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType)
case ProcessingTimeSlidingGroupWindow(_, size, _) => isTimeInterval(size.resultType)
case ProcessingTimeSessionGroupWindow(_, _) => true
case EventTimeTumblingGroupWindow(_, _, size) => isTimeInterval(size.resultType)
case EventTimeSlidingGroupWindow(_, _, size, _) => isTimeInterval(size.resultType)
case EventTimeSessionGroupWindow(_, _, _) => true
}
}

def computeWindowStartEndPropertyPos(properties: Seq[NamedWindowProperty])
: (Option[Int], Option[Int]) = {

val propPos = properties.foldRight((None: Option[Int], None: Option[Int], 0)) {
(p, x) => p match {
case NamedWindowProperty(name, prop) =>
prop match {
case WindowStart(_) if x._1.isDefined =>
throw new TableException("Duplicate WindowStart property encountered. This is a bug.")
case WindowStart(_) =>
(Some(x._3), x._2, x._3 - 1)
case WindowEnd(_) if x._2.isDefined =>
throw new TableException("Duplicate WindowEnd property encountered. This is a bug.")
case WindowEnd(_) =>
(x._1, Some(x._3), x._3 - 1)
}
}
}
(propPos._1, propPos._2)
}

private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple])
: WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match {
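Taken together, the new streaming translateToPlan branches along two axes: whether all aggregates support partial aggregation (checked with AggregateUtil.doAllSupportPartialAggregation) and whether the query is keyed. Only the function construction differs per branch; the prepare map, window assignment, and result type are shared. An illustrative distillation of the four call chains, not planner code:

// Summary of the four code paths in translateToPlan above:
def pickPath(allPartial: Boolean, keyed: Boolean): String =
  (allPartial, keyed) match {
    case (true, true)   => "keyBy -> window -> apply(reduceFn, windowFn)" // incremental, keyed
    case (true, false)  => "windowAll -> apply(reduceFn, allWindowFn)"    // incremental, global
    case (false, true)  => "keyBy -> window -> apply(windowFn)"           // buffering, keyed
    case (false, false) => "windowAll -> apply(allWindowFn)"              // buffering, global
  }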
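The computeWindowStartEndPropertyPos fold above is easy to misread: foldRight visits the window properties from the rightmost one backwards and hands out offsets 0, -1, -2, and so on, which the window-property collectors later resolve against the arity of the output row. A standalone re-implementation of just the fold, with simplified stand-in property types:

sealed trait Prop
case object Start extends Prop
case object End extends Prop

def positions(props: Seq[Prop]): (Option[Int], Option[Int]) =
  props.foldRight((Option.empty[Int], Option.empty[Int], 0)) {
    case (Start, (None, e, i)) => (Some(i), e, i - 1)
    case (End,   (s, None, i)) => (s, Some(i), i - 1)
    case _ => throw new IllegalStateException("duplicate window property")
  } match { case (s, e, _) => (s, e) }

// positions(Seq(Start, End)) == (Some(-1), Some(0)): the rightmost
// property gets offset 0, counted back from the end of the output row.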
AggregateAllTimeWindowFunction.scala
@@ -31,14 +31,13 @@ class AggregateAllTimeWindowFunction(
groupReduceFunction: RichGroupReduceFunction[Row, Row],
windowStartPos: Option[Int],
windowEndPos: Option[Int])

extends RichAllWindowFunction[Row, Row, TimeWindow] {
extends AggregateAllWindowFunction[TimeWindow](groupReduceFunction) {

private var collector: TimeWindowPropertyCollector = _

override def open(parameters: Configuration): Unit = {
groupReduceFunction.open(parameters)
collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos)
super.open(parameters)
}

override def apply(window: TimeWindow, input: Iterable[Row], out: Collector[Row]): Unit = {
@@ -48,6 +47,6 @@ class AggregateAllTimeWindowFunction(
collector.timeWindow = window

// call wrapped reduce function with property collector
groupReduceFunction.reduce(input, collector)
super.apply(window, input, collector)
}
}
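
The rewrite of AggregateAllTimeWindowFunction is a decorate-and-delegate refactor: the reduction itself now lives only in the generic superclass, and the time-window subclass merely swaps in a collector that appends the window start and end properties before forwarding each row. The collector side of that pattern, as a generic sketch (the real TimeWindowPropertyCollector additionally carries the current window and the property positions):

import org.apache.flink.util.Collector

// Generic forwarding collector: decorate each record, then delegate.
class DecoratingCollector[T](wrapped: Collector[T], decorate: T => T)
    extends Collector[T] {
  override def collect(record: T): Unit = wrapped.collect(decorate(record))
  override def close(): Unit = wrapped.close()
}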
AggregateAllWindowFunction.scala
@@ -27,14 +27,15 @@ import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction
import org.apache.flink.streaming.api.windowing.windows.Window
import org.apache.flink.util.Collector

class AggregateAllWindowFunction(groupReduceFunction: RichGroupReduceFunction[Row, Row])
extends RichAllWindowFunction[Row, Row, Window] {
class AggregateAllWindowFunction[W <: Window](
groupReduceFunction: RichGroupReduceFunction[Row, Row])
extends RichAllWindowFunction[Row, Row, W] {

override def open(parameters: Configuration): Unit = {
groupReduceFunction.open(parameters)
}

override def apply(window: Window, input: Iterable[Row], out: Collector[Row]): Unit = {
override def apply(window: W, input: Iterable[Row], out: Collector[Row]): Unit = {
groupReduceFunction.reduce(input, out)
}
}
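
Widening AggregateAllWindowFunction with a window type parameter is what makes the delegation above type-check: with the old signature apply(window: Window, ...), a subclass handling TimeWindow could not override apply and still see the concrete window type. A minimal illustration of the point, with assumed class names:

import org.apache.flink.streaming.api.windowing.windows.{TimeWindow, Window}

class Base[W <: Window] {
  def describe(w: W): String = w.toString
}

// Fixing W = TimeWindow lets the override use TimeWindow-only members:
class TimedBase extends Base[TimeWindow] {
  override def describe(w: TimeWindow): String =
    s"window [${w.getStart}, ${w.getEnd})"
}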
AggregateMapFunction.scala
@@ -29,7 +29,8 @@ class AggregateMapFunction[IN, OUT](
private val aggFields: Array[Int],
private val groupingKeys: Array[Int],
@transient private val returnType: TypeInformation[OUT])
extends RichMapFunction[IN, OUT] with ResultTypeQueryable[OUT] {
extends RichMapFunction[IN, OUT]
with ResultTypeQueryable[OUT] {

private var output: Row = _

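AggregateMapFunction implements ResultTypeQueryable because its output is a generic Row whose field types Flink's type extractor cannot derive on its own; the function hands the planner-built RowTypeInfo back to the framework itself. A stripped-down sketch of the same idiom:

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable

// Identity mapper that must announce its (externally supplied) output type.
class TypedPassThrough[T](@transient private val returnType: TypeInformation[T])
    extends RichMapFunction[T, T]
    with ResultTypeQueryable[T] {

  override def map(value: T): T = value
  override def getProducedType: TypeInformation[T] = returnType
}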