[FLINK-8456] Add Scala API for Connected Streams with Broadcast State.
kl0u committed Feb 9, 2018
1 parent 8395508 commit 9628dc8
Showing 5 changed files with 293 additions and 3 deletions.
org/apache/flink/streaming/api/datastream/DataStream.java
@@ -399,8 +399,8 @@ public DataStream<T> broadcast() {
/**
* Sets the partitioning of the {@link DataStream} so that the output elements
* are broadcasted to every parallel instance of the next operation. In addition,
- * it implicitly creates a {@link org.apache.flink.api.common.state.BroadcastState broadcast state}
- * which can be used to store the element of the stream.
+ * it implicitly creates as many {@link org.apache.flink.api.common.state.BroadcastState broadcast states}
+ * as the specified descriptors, which can be used to store the elements of the stream.
*
* @param broadcastStateDescriptors the descriptors of the broadcast states to create.
* @return A {@link BroadcastStream} which can be used in the {@link #connect(BroadcastStream)} to
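As a minimal sketch of the intended usage from the new Scala API (the stream and descriptor names are hypothetical, not part of this commit):

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.streaming.api.scala._

val env = StreamExecutionEnvironment.getExecutionEnvironment

// One MapStateDescriptor per broadcast state; the same descriptor is used
// again inside the process function to look the state up.
val rulesDescriptor = new MapStateDescriptor[String, String](
  "rules",
  BasicTypeInfo.STRING_TYPE_INFO,
  BasicTypeInfo.STRING_TYPE_INFO)

// Broadcasts the rule stream and implicitly creates the declared state.
val broadcastRules = env.fromElements("rule-1", "rule-2").broadcast(rulesDescriptor)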
org/apache/flink/streaming/api/scala/BroadcastConnectedStream.scala
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.scala

import org.apache.flink.annotation.PublicEvolving
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.datastream.{BroadcastConnectedStream => JavaBCStream}
import org.apache.flink.streaming.api.functions.co.{BroadcastProcessFunction, KeyedBroadcastProcessFunction}

class BroadcastConnectedStream[IN1, IN2](javaStream: JavaBCStream[IN1, IN2]) {

/**
* Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]] and a
* [[KeyedStream]] and applies the given [[KeyedBroadcastProcessFunction]] on them, thereby
* creating a transformed output stream.
*
* @param function The [[KeyedBroadcastProcessFunction]] applied to each element in the stream.
* @tparam KS The type of the keys in the keyed stream.
* @tparam OUT The type of the output elements.
* @return The transformed [[DataStream]].
*/
@PublicEvolving
def process[KS, OUT: TypeInformation](
function: KeyedBroadcastProcessFunction[KS, IN1, IN2, OUT])
: DataStream[OUT] = {

if (function == null) {
throw new NullPointerException("KeyedBroadcastProcessFunction function must not be null.")
}

val outputTypeInfo: TypeInformation[OUT] = implicitly[TypeInformation[OUT]]
asScalaStream(javaStream.process(function, outputTypeInfo))
}

/**
* Assumes as inputs a [[org.apache.flink.streaming.api.datastream.BroadcastStream]]
* and a non-keyed [[DataStream]] and applies the given
* [[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]]
* on them, thereby creating a transformed output stream.
*
* @param function The [[BroadcastProcessFunction]] applied to each element in the stream.
* @tparam OUT The type of the output elements.
* @return The transformed [[DataStream]].
*/
@PublicEvolving
def process[OUT: TypeInformation](
function: BroadcastProcessFunction[IN1, IN2, OUT])
: DataStream[OUT] = {

if (function == null) {
throw new NullPointerException("BroadcastProcessFunction function must not be null.")
}

val outputTypeInfo: TypeInformation[OUT] = implicitly[TypeInformation[OUT]]
asScalaStream(javaStream.process(function, outputTypeInfo))
}

/**
* Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning
* is not disabled in the [[org.apache.flink.api.common.ExecutionConfig]].
*/
private[flink] def clean[F <: AnyRef](f: F) = {
new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f)
}
}
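For the non-keyed variant, a sketch of a matching BroadcastProcessFunction (the class and its logic are illustrative only; it assumes the context types are named Context and ReadOnlyContext, the non-keyed counterparts of the Keyed* contexts used in the test below):

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction
import org.apache.flink.util.Collector

// Illustrative only: tags each element with the most recent broadcast rule.
class EnrichFunction(desc: MapStateDescriptor[String, String])
  extends BroadcastProcessFunction[String, String, String] {

  override def processElement(
      value: String,
      ctx: BroadcastProcessFunction[String, String, String]#ReadOnlyContext,
      out: Collector[String]): Unit = {
    // The element side gets read-only access to the broadcast state.
    out.collect(value + "/" + ctx.getBroadcastState(desc).get("latest"))
  }

  override def processBroadcastElement(
      value: String,
      ctx: BroadcastProcessFunction[String, String, String]#Context,
      out: Collector[String]): Unit = {
    // Only the broadcast side may update the broadcast state.
    ctx.getBroadcastState(desc).put("latest", value)
  }
}

Continuing the sketch above, this would be wired up as
elements.connect(broadcastRules).process(new EnrichFunction(rulesDescriptor)).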
org/apache/flink/streaming/api/scala/DataStream.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.functions.{FilterFunction, FlatMapFunction, M
import org.apache.flink.api.common.io.OutputFormat
import org.apache.flink.api.common.operators.ResourceSpec
import org.apache.flink.api.common.serialization.SerializationSchema
import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.tuple.{Tuple => JavaTuple}
@@ -363,6 +364,27 @@ class DataStream[T](stream: JavaStream[T]) {
def connect[T2](dataStream: DataStream[T2]): ConnectedStreams[T, T2] =
asScalaStream(stream.connect(dataStream.javaStream))

/**
* Creates a new [[BroadcastConnectedStream]] by connecting the current
* [[DataStream]] or [[KeyedStream]] with a [[BroadcastStream]].
*
* The latter can be created using the [[broadcast(MapStateDescriptor[])]] method.
*
* The resulting stream can be further processed using the
* ``broadcastConnectedStream.process(myFunction)``
* method, where ``myFunction`` can be either a
* [[org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction]]
* or a [[org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction]]
* depending on the current stream being a [[KeyedStream]] or not.
*
* @param broadcastStream The broadcast stream with the broadcast state to be
* connected with this stream.
* @return The [[BroadcastConnectedStream]].
*/
@PublicEvolving
def connect[R](broadcastStream: BroadcastStream[R]): BroadcastConnectedStream[T, R] =
asScalaStream(stream.connect(broadcastStream))

/**
* Groups the elements of a DataStream by the given key positions (for tuple/array types) to
* be used with grouped operators like grouped reduce or grouped aggregations.
@@ -441,6 +463,26 @@ class DataStream[T](stream: JavaStream[T]) {
*/
def broadcast: DataStream[T] = asScalaStream(stream.broadcast())

/**
* Sets the partitioning of the [[DataStream]] so that the output elements
* are broadcasted to every parallel instance of the next operation. In addition,
* it implicitly creates as many
* [[org.apache.flink.api.common.state.BroadcastState broadcast states]]
* as the specified descriptors, which can be used to store the elements of the stream.
*
* @param broadcastStateDescriptors the descriptors of the broadcast states to create.
* @return A [[BroadcastStream]] which can be used in the
* [[DataStream.connect(BroadcastStream)]] to create a
* [[BroadcastConnectedStream]] for further processing of the elements.
*/
@PublicEvolving
def broadcast(broadcastStateDescriptors: MapStateDescriptor[_, _]*): BroadcastStream[T] = {
if (broadcastStateDescriptors == null) {
throw new NullPointerException("Map function must not be null.")
}
stream.broadcast(broadcastStateDescriptors: _*)
}

/**
* Sets the partitioning of the DataStream so that the output values all go to
* the first instance of the next processing operator. Use this setting with care
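Taken together, the two new methods wire up as below. This condensed sketch reuses the descriptor and the TestBroadcastProcessFunction from the test case at the end of this commit, but omits the watermark assignment the test needs to actually fire its timers:

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.streaming.api.scala._

val env = StreamExecutionEnvironment.getExecutionEnvironment

val descriptor = new MapStateDescriptor[Long, String](
  "broadcast-state",
  BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
  BasicTypeInfo.STRING_TYPE_INFO)

// The keyed side connects to the broadcast side; process(...) then expects
// the KeyedBroadcastProcessFunction overload of BroadcastConnectedStream.
val output: DataStream[String] = env
  .generateSequence(0L, 5L)
  .keyBy((value: Long) => value)
  .connect(env.fromElements("test:0").broadcast(descriptor))
  .process(new TestBroadcastProcessFunction(100000L, Map(0L -> "test:0")))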
org/apache/flink/streaming/api/scala/package.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.scala.typeutils.{CaseClassTypeInfo, TypeUtils}
import org.apache.flink.streaming.api.datastream.{ DataStream => JavaStream }
import org.apache.flink.streaming.api.datastream.{ SplitStream => SplitJavaStream }
import org.apache.flink.streaming.api.datastream.{ ConnectedStreams => ConnectedJavaStreams }
import org.apache.flink.streaming.api.datastream.{ BroadcastConnectedStream => BroadcastConnectedJavaStreams }
import org.apache.flink.streaming.api.datastream.{ KeyedStream => KeyedJavaStream }

import language.implicitConversions
@@ -61,8 +62,13 @@ package object scala {
*/
private[flink] def asScalaStream[IN1, IN2](stream: ConnectedJavaStreams[IN1, IN2])
= new ConnectedStreams[IN1, IN2](stream)

/**
* Converts an [[org.apache.flink.streaming.api.datastream.BroadcastConnectedStream]] to a
* [[org.apache.flink.streaming.api.scala.BroadcastConnectedStream]].
*/
private[flink] def asScalaStream[IN1, IN2](stream: BroadcastConnectedJavaStreams[IN1, IN2])
= new BroadcastConnectedStream[IN1, IN2](stream)


private[flink] def fieldNames2Indices(
typeInfo: TypeInformation[_],
fields: Array[String]): Array[Int] = {
org/apache/flink/streaming/api/scala/BroadcastStateITCase.scala
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.scala

import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.test.util.AbstractTestBase
import org.apache.flink.util.Collector
import org.junit.Assert.assertEquals
import org.junit.{Assert, Test}

/**
* ITCase for the [[org.apache.flink.api.common.state.BroadcastState]].
*/
class BroadcastStateITCase extends AbstractTestBase {

@Test
@throws[Exception]
def testConnectWithBroadcastTranslation(): Unit = {

val timerTimestamp = 100000L

val DESCRIPTOR = new MapStateDescriptor[Long, String](
"broadcast-state",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.STRING_TYPE_INFO)

val expected = Map[Long, String](
0L -> "test:0",
1L -> "test:1",
2L -> "test:2",
3L -> "test:3",
4L -> "test:4",
5L -> "test:5")

val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

val srcOne = env
.generateSequence(0L, 5L)
.assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[Long]() {

override def extractTimestamp(element: Long, previousElementTimestamp: Long): Long =
element

override def checkAndGetNextWatermark(lastElement: Long, extractedTimestamp: Long) =
new Watermark(extractedTimestamp)

})
.keyBy((value: Long) => value)

val srcTwo = env
.fromCollection(expected.values.toSeq)
.assignTimestampsAndWatermarks(new AssignerWithPunctuatedWatermarks[String]() {

override def extractTimestamp(element: String, previousElementTimestamp: Long): Long =
element.split(":")(1).toLong

override def checkAndGetNextWatermark(lastElement: String, extractedTimestamp: Long) =
new Watermark(extractedTimestamp)
})

val broadcast = srcTwo.broadcast(DESCRIPTOR)
// the timestamp should be high enough to trigger the timer after all the elements arrive.
val output = srcOne.connect(broadcast)
.process(new TestBroadcastProcessFunction(timerTimestamp, expected))

output
.addSink(new TestSink(expected.size))
.setParallelism(1)
env.execute()
}
}

class TestBroadcastProcessFunction(
expectedTimestamp: Long,
expectedBroadcastState: Map[Long, String])
extends KeyedBroadcastProcessFunction[Long, Long, String, String] {

val localDescriptor = new MapStateDescriptor[Long, String](
"broadcast-state",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.STRING_TYPE_INFO)

@throws[Exception]
override def processElement(
value: Long,
ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedReadOnlyContext,
out: Collector[String]): Unit = {

ctx.timerService.registerEventTimeTimer(expectedTimestamp)
}

@throws[Exception]
override def processBroadcastElement(
value: String,
ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#KeyedContext,
out: Collector[String]): Unit = {

val key = value.split(":")(1).toLong
ctx.getBroadcastState(localDescriptor).put(key, value)
}

@throws[Exception]
override def onTimer(
timestamp: Long,
ctx: KeyedBroadcastProcessFunction[Long, Long, String, String]#OnTimerContext,
out: Collector[String]): Unit = {

var map = Map[Long, String]()

import scala.collection.JavaConversions._
for (entry <- ctx.getBroadcastState(localDescriptor).immutableEntries()) {
val v = expectedBroadcastState(entry.getKey)
assertEquals(v, entry.getValue)
map += (entry.getKey -> entry.getValue)
}

assertEquals(expectedBroadcastState, map)

out.collect(timestamp.toString)
}
}

class TestSink(val expectedOutputCounter: Int) extends RichSinkFunction[String] {

var outputCounter: Int = 0

override def invoke(value: String): Unit = {
outputCounter += 1
}

@throws[Exception]
override def close(): Unit = {
super.close()

// make sure that all the timers fired
assertEquals(expectedOutputCounter, outputCounter)
}
}
