[SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature

SPARK-12244:

Based on feedback from early users and personal experience attempting to explain it, the name trackStateByKey had two problems:
- "trackState" is a completely new term that gives no intuition about what the operation actually does.
- The resultant data stream of objects returned by the function is referred to in the docs as the "emitted" data, for lack of a better name.
"mapWithState" solves both problems: the API is essentially a mapping function of the form (Key, Value) => T, with State as an additional parameter, and the resultant data stream is simply the "mapped data".

SPARK-12245:

From initial experience, not having the key available in the function makes it hard to return mapped records, because the full information of the record is not there; the user is effectively restricted to something like mapValues() instead of map(). So the key is added as a parameter to the mapping function.
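
For illustration, a minimal Scala sketch of the new shape, mirroring the StatefulNetworkWordCount example changed below; wordDstream is assumed to be an existing DStream[(String, Int)] such as words.map(x => (x, 1)).

import org.apache.spark.streaming.{State, StateSpec}

// The mapping function now receives the key as its first argument, so the
// returned "mapped data" can be a full record rather than just a value.
val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
  val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
  state.update(sum)   // update the per-key state
  (word, sum)         // returned directly; no Option wrapper needed
}

// wordDstream: DStream[(String, Int)]
val stateDstream = wordDstream.mapWithState(StateSpec.function(mappingFunc))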

Author: Tathagata Das <[email protected]>

Closes apache#10224 from tdas/rename.
tdas authored and zsxwing committed Dec 10, 2015
1 parent 2166c2a commit bd2cd4f
Showing 13 changed files with 389 additions and 382 deletions.
@@ -65,7 +65,7 @@ public static void main(String[] args) {
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
     ssc.checkpoint(".");
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD input to mapWithState
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
         new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public Tuple2<String, Integer> call(String s) {
     });
 
     // Update the cumulative count function
-    final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+    final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
 
       @Override
-      public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+      public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
         int sum = one.or(0) + (state.exists() ? state.get() : 0);
         Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
         state.update(sum);
-        return Optional.of(output);
+        return output;
       }
     };
 
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
-        wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+    // DStream made of get cumulative counts that get updated in every batch
+    JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+        wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));
 
     stateDstream.print();
     ssc.start();
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD for mapWithState operation
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
 
     // Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
     val words = lines.flatMap(_.split(" "))
     val wordDstream = words.map(x => (x, 1))
 
-    // Update the cumulative count using updateStateByKey
+    // Update the cumulative count using mapWithState
     // This will give a DStream made of state (which is the cumulative count of the words)
-    val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
       val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
       val output = (word, sum)
       state.update(sum)
-      Some(output)
+      output
     }
 
-    val stateDstream = wordDstream.trackStateByKey(
-      StateSpec.function(trackStateFunc).initialState(initialRDD))
+    val stateDstream = wordDstream.mapWithState(
+      StateSpec.function(mappingFunc).initialState(initialRDD))
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()
@@ -32,12 +32,10 @@
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
 
 /**
  * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -863,12 +861,12 @@ public void testFlatMapValues() {
   /**
    * This test is only for testing the APIs. It's not necessary to run it.
    */
-  public void testTrackStateByAPI() {
+  public void testMapWithStateAPI() {
     JavaPairRDD<String, Boolean> initialRDD = null;
     JavaPairDStream<String, Integer> wordsDstream = null;
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.mapWithState(
           StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
             // Use all State's methods here
             state.exists();
@@ -884,9 +882,9 @@ StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state)
 
     JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-          StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.mapWithState(
+          StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> {
             state.exists();
             state.get();
             state.isTimingOut();
@@ -898,6 +896,6 @@ StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state)
             .partitioner(new HashPartitioner(10))
             .timeout(Durations.seconds(10)));
 
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+    JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots();
   }
 }
20 changes: 11 additions & 9 deletions streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Scala example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *    def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
 *      // Check if state exists
 *      if (state.exists) {
 *        val existingState = state.get  // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
  *
  * Java example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *    Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- *        new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *        new Function3<String, Optional<Integer>, State<Integer>, String>() {
 *
 *          @Override
- *          public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ *          public String call(String key, Optional<Integer> value, State<Integer> state) {
 *            if (state.exists()) {
 *              int existingState = state.get(); // Get the existing state
 *              boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
 *            }
 *        };
 *    }}}
+ *
+ * @tparam S Class of the state
  */
 @Experimental
 sealed abstract class State[S] {
