[SPARK-11663][STREAMING] Add Java API for trackStateByKey
TODO
- [x] Add Java API
- [x] Add API tests
- [x] Add a functional test

Author: Shixiong Zhu <[email protected]>

Closes #9636 from zsxwing/java-track.
zsxwing authored and tdas committed Nov 13, 2015
1 parent 41bbd23 commit 0f1d00a
Showing 12 changed files with 485 additions and 52 deletions.
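Before the diff, a minimal sketch of what the new Java API looks like in use. This is an editor's illustration distilled from the tests below, not part of the commit; `wordCounts` is an assumed `JavaPairDStream<String, Integer>`, and imports (`StateSpec`, `State`, `Time`, Guava's `Optional`, `JavaTrackStateDStream`, `Tuple2`) are elided:

```java
// Track a running count per word; emit (word, cumulativeCount) each batch.
JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
    wordCounts.trackStateByKey(
        StateSpec.<String, Integer, Integer, Tuple2<String, Integer>>function(
            (time, word, one, state) -> {
              int sum = one.or(0) + (state.exists() ? state.get() : 0);
              state.update(sum);                           // persist the new count
              return Optional.of(new Tuple2<>(word, sum)); // emit it downstream
            }));
```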
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;

/**
* A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R.
*/
public interface Function4<T1, T2, T3, T4, R> extends Serializable {
public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception;
}
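As an aside (editor's illustration, not from the commit): because `Function4` has a single abstract method, it can be implemented as an anonymous class or, on Java 8, as a lambda. The names below are hypothetical:

```java
// Hypothetical Function4 that combines a label with the sum of three ints.
Function4<String, Integer, Integer, Integer, String> describeSum =
    new Function4<String, Integer, Integer, Integer, String>() {
      @Override
      public String call(String label, Integer a, Integer b, Integer c) {
        return label + ": " + (a + b + c);
      }
    };

// Java 8 lambda equivalent:
Function4<String, Integer, Integer, Integer, String> describeSumLambda =
    (label, a, b, c) -> label + ": " + (a + b + c);
```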
@@ -26,18 +26,15 @@
import com.google.common.base.Optional;
import com.google.common.collect.Lists;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.State;
import org.apache.spark.streaming.StateSpec;
import org.apache.spark.streaming.Time;
import org.apache.spark.streaming.api.java.*;

/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
@@ -63,25 +60,12 @@ public static void main(String[] args) {

StreamingExamples.setStreamingLogLevels();

// Update the cumulative count function
final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
@Override
public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
Integer newSum = state.or(0);
for (Integer value : values) {
newSum += value;
}
return Optional.of(newSum);
}
};

// Create the context with a 1 second batch size
SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint(".");

// Initial RDD input to updateStateByKey
// Initial RDD input to trackStateByKey
@SuppressWarnings("unchecked")
List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
new Tuple2<String, Integer>("world", 1));
@@ -105,9 +89,22 @@ public Tuple2<String, Integer> call(String s) {
}
});

// Update the cumulative count function
final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {

@Override
public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
int sum = one.or(0) + (state.exists() ? state.get() : 0);
Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
state.update(sum);
return Optional.of(output);
}
};

// This will give a DStream made of state (which is the cumulative count of the words)
JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));

stateDstream.print();
ssc.start();
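Worth noting about the rewritten example: `stateDstream` is the stream of records emitted by the tracking function (running (word, count) tuples). If the raw per-key state is wanted instead, the new API exposes it separately; a small sketch using the example's own variables:

```java
// Sketch: stateSnapshots() yields a (key, state) pair DStream holding the
// latest tracked state for every key, independent of what the function emits.
JavaPairDStream<String, Integer> snapshots = stateDstream.stateSnapshots();
snapshots.print();
```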
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")

// Initial RDD input to updateStateByKey
// Initial RDD input to trackStateByKey
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

// Create a ReceiverInputDStream on target ip:port and count the
@@ -31,9 +31,12 @@
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.Function4;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaTrackStateDStream;

/**
* Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -831,4 +834,44 @@ public void testFlatMapValues() {
Assert.assertEquals(expected, result);
}

/**
* This test is only for testing the APIs. It's not necessary to run it.
*/
public void testTrackStateByAPI() {
JavaPairRDD<String, Boolean> initialRDD = null;
JavaPairDStream<String, Integer> wordsDstream = null;

JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
wordsDstream.trackStateByKey(
StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
// Use all State's methods here
state.exists();
state.get();
state.isTimingOut();
state.remove();
state.update(true);
return Optional.of(2.0);
}).initialState(initialRDD)
.numPartitions(10)
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));

JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();

JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
wordsDstream.trackStateByKey(
StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
state.exists();
state.get();
state.isTimingOut();
state.remove();
state.update(true);
return 2.0;
}).initialState(initialRDD)
.numPartitions(10)
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));

JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
}
}
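One behavior the test above touches only implicitly: when a spec has `.timeout(...)`, an idle key gets one final call with `state.isTimingOut()` returning true, and my understanding of the `State` contract is that `update`/`remove` must not be called during that final invocation. A hedged sketch of a function that branches on it (the arithmetic is illustrative):

```java
// Editor's sketch, not from the commit: a 2-argument tracking function that
// handles the final timeout invocation separately.
Function2<Optional<Integer>, State<Integer>, Double> trackWithTimeout =
    (value, state) -> {
      if (state.isTimingOut()) {
        // Final call for an idle key: the state is being dropped, so only
        // emit; updating the state here would be invalid.
        return state.get() * 1.0;
      }
      int sum = value.or(0) + (state.exists() ? state.get() : 0);
      state.update(sum);
      return sum * 1.0;
    };
```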
25 changes: 23 additions & 2 deletions streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -50,9 +50,30 @@ import org.apache.spark.annotation.Experimental
*
* }}}
*
* Java example:
* Java example of using `State`:
* {{{
* TODO(@zsxwing)
* // A tracking function that maintains an integer state and returns a String
* Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
* new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
*
* @Override
* public Optional<String> call(Optional<Integer> one, State<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
* if (shouldRemove) {
* state.remove(); // Remove the state
* } else {
* int newState = ...;
* state.update(newState); // Set the new state
* }
* } else {
* int initialState = ...; // Set the initial state
* state.update(initialState);
* }
* // return something
* }
* };
* }}}
*/
@Experimental
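Since the Scaladoc example above leaves `...` placeholders and ends at `// return something`, here is one hedged completion (editor's sketch; the removal threshold and emitted strings are invented purely for illustration):

```java
// One possible completion of the doc example. The "over 100 means remove"
// rule is an illustrative assumption, not part of the commit.
Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
    new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
      @Override
      public Optional<String> call(Optional<Integer> one, State<Integer> state) {
        if (state.exists()) {
          int newState = state.get() + one.or(0);
          if (newState > 100) {        // invented removal condition
            state.remove();
            return Optional.of("removed");
          }
          state.update(newState);
          return Optional.of("count=" + newState);
        }
        int initialState = one.or(0);  // first sighting of this key
        state.update(initialState);
        return Optional.of("count=" + initialState);
      }
    };
```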
84 changes: 62 additions & 22 deletions streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -17,15 +17,14 @@

package org.apache.spark.streaming

import scala.reflect.ClassTag

import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}


/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
@@ -49,12 +48,12 @@
*
* Example in Java:
* {{{
* StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
* StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
* StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
* StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartitions(10);
*
* JavaDStream[EmittedDataType] emittedRecordDStream =
* javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
* JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
* javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
@@ -92,6 +91,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
* that is used for specifying the parameters of the DStream transformation `trackStateByKey`
* that is used for specifying the parameters of the DStream transformation
* `trackStateByKey` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
@@ -103,28 +103,27 @@
* ...
* }
*
* val spec = StateSpec.function(trackingFunction).numPartitions(10)
*
* val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
* val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
* StateSpec.function(trackingFunction).numPartitions(10))
* }}}
*
* Example in Java:
* {{{
* StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
* StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
* StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
* StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartitions(10);
*
* JavaDStream[EmittedDataType] emittedRecordDStream =
* javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
* JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
* javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
* `trackStateByKey` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* of the `trackStateByKey` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam KeyType Class of the keys
@@ -141,9 +140,9 @@

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
* `trackStateByKey` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* of the `trackStateByKey` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam ValueType Class of the values
@@ -160,6 +159,48 @@
}
new StateSpecImpl(wrappedFunction)
}

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
* the specifications of the `trackStateByKey` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
* @param javaTrackingFunction The function applied on every data item to manage the associated
* state and generate the emitted data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
* @tparam EmittedType Class of the emitted data
*/
def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
StateSpec[KeyType, ValueType, StateType, EmittedType] = {
val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
Option(t.orNull)
}
StateSpec.function(trackingFunc)
}

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
* of the `trackStateByKey` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
* @param javaTrackingFunction The function applied on every data item to manage the associated
* state and generate the emitted data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
* @tparam EmittedType Class of the emitted data
*/
def function[KeyType, ValueType, StateType, EmittedType](
javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
StateSpec[KeyType, ValueType, StateType, EmittedType] = {
val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
javaTrackingFunction.call(JavaUtils.optionToOptional(v), s) // avoid calling get on an empty Option
}
StateSpec.function(trackingFunc)
}
}


@@ -184,7 +225,6 @@ case class StateSpecImpl[K, V, S, T](
this
}


override def numPartitions(numPartitions: Int): this.type = {
this.partitioner(new HashPartitioner(numPartitions))
this
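A note on the 4-argument Java wrapper above: `Option(t.orNull)` maps a returned `Optional.absent()` to Scala's `None`, so a Java tracking function can suppress emission for a key in a given batch by returning `Optional.absent()`. A hedged sketch (threshold and names invented for illustration):

```java
// Editor's sketch: emit the running sum only once it exceeds a threshold.
Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Integer>> emitOverTen =
    (time, key, value, state) -> {
      int sum = value.or(0) + (state.exists() ? state.get() : 0);
      state.update(sum);
      // Returning absent() emits nothing for this key in this batch.
      return sum > 10 ? Optional.of(sum) : Optional.<Integer>absent();
    };
```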
