
Commit 5cb4695

hhbyyh authored and jkbradley committed
[SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs
jira: https://issues.apache.org/jira/browse/SPARK-11605

Check Java compatibility for MLlib for this release.

Fixes:
1. `StreamingTest.registerStream` needs a Java-friendly interface.
2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` have a Java compatibility issue. Mark them as `DeveloperApi`.

TBD: [updated] no fix for now, per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel`'s `public scala.Option<java.lang.Object> getThreshold();` has the wrong return type for Java invocation, and `SVMModel` has a similar issue. Adding a `scala.Option<java.lang.Double> getThreshold()` would result in an overloading error, since both methods would have the same signature after erasure, and adding a new function with a different name does not seem necessary.

cc jkbradley feynmanliang

Author: Yuhao Yang <[email protected]>

Closes apache#10102 from hhbyyh/javaAPI.
1 parent 4bcb894 commit 5cb4695
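To make the `getThreshold` point above concrete, here is a minimal, hypothetical Java sketch (not part of this commit; the class name, toy dataset, and local-mode config are assumptions) of what a Java caller currently sees: Scala's `Option[Double]` erases to `scala.Option<Object>` across the Java boundary, so the value has to be cast to `java.lang.Double` and unboxed by hand.

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class GetThresholdJavaSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("GetThresholdJavaSketch");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // Toy training data, just enough to obtain a model object.
    JavaRDD<LabeledPoint> data = sc.parallelize(Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)),
        new LabeledPoint(0.0, Vectors.dense(-1.0, -2.0))));
    LogisticRegressionModel model = new LogisticRegressionWithLBFGS().run(data.rdd());

    // Scala's Option[Double] appears as scala.Option<Object> in the Java view of the API,
    // so the boxed value must be cast and unboxed manually.
    scala.Option<Object> threshold = model.getThreshold();
    if (threshold.isDefined()) {
      double t = (Double) threshold.get();
      System.out.println("decision threshold = " + t);
    }

    sc.stop();
  }
}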

File tree: 5 files changed (+96, -27 lines)


examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala

+2 -2
@@ -18,7 +18,7 @@
 package org.apache.spark.examples.mllib

 import org.apache.spark.SparkConf
-import org.apache.spark.mllib.stat.test.StreamingTest
+import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}
 import org.apache.spark.streaming.{Seconds, StreamingContext}
 import org.apache.spark.util.Utils

@@ -66,7 +66,7 @@ object StreamingTestExample {

     // $example on$
     val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
-      case Array(label, value) => (label.toBoolean, value.toDouble)
+      case Array(label, value) => BinarySample(label.toBoolean, value.toDouble)
     })

     val streamingTest = new StreamingTest()

mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala

+41 -9
@@ -17,12 +17,30 @@

 package org.apache.spark.mllib.stat.test

+import scala.beans.BeanInfo
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.JavaDStream
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.StatCounter

+/**
+ * Class that represents the group and value of a sample.
+ *
+ * @param isExperiment if the sample is of the experiment group.
+ * @param value numeric value of the observation.
+ */
+@Since("1.6.0")
+@BeanInfo
+case class BinarySample @Since("1.6.0") (
+    @Since("1.6.0") isExperiment: Boolean,
+    @Since("1.6.0") value: Double) {
+  override def toString: String = {
+    s"($isExperiment, $value)"
+  }
+}
+
 /**
  * :: Experimental ::
  * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
@@ -83,23 +101,36 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
   /**
    * Register a [[DStream]] of values for significance testing.
    *
-   * @param data stream of (key,value) pairs where the key denotes group membership (true =
-   *             experiment, false = control) and the value is the numerical metric to test for
-   *             significance
+   * @param data stream of BinarySample(key,value) pairs where the key denotes group membership
+   *             (true = experiment, false = control) and the value is the numerical metric to
+   *             test for significance
    * @return stream of significance testing results
    */
   @Since("1.6.0")
-  def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
+  def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = {
     val dataAfterPeacePeriod = dropPeacePeriod(data)
     val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
     val pairedSummaries = pairSummaries(summarizedData)

     testMethod.doTest(pairedSummaries)
   }

+  /**
+   * Register a [[JavaDStream]] of values for significance testing.
+   *
+   * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes
+   *             group (true = experiment, false = control) and the value is the numerical metric
+   *             to test for significance
+   * @return stream of significance testing results
+   */
+  @Since("1.6.0")
+  def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = {
+    JavaDStream.fromDStream(registerStream(data.dstream))
+  }
+
   /** Drop all batches inside the peace period. */
   private[stat] def dropPeacePeriod(
-      data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
+      data: DStream[BinarySample]): DStream[BinarySample] = {
     data.transform { (rdd, time) =>
       if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
         rdd
@@ -111,17 +142,18 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {

   /** Compute summary statistics over each key and the specified test window size. */
   private[stat] def summarizeByKeyAndWindow(
-      data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
+      data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = {
+    val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value))
     if (this.windowSize == 0) {
-      data.updateStateByKey[StatCounter](
+      categoryValuePair.updateStateByKey[StatCounter](
         (newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
           val newSummary = oldSummary.getOrElse(new StatCounter())
           newSummary.merge(newValues)
           Some(newSummary)
         })
     } else {
       val windowDuration = data.slideDuration * this.windowSize
-      data
+      categoryValuePair
         .groupByKeyAndWindow(windowDuration)
         .mapValues { values =>
           val summary = new StatCounter()
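
For reference, a minimal, self-contained Java sketch (not part of this diff; the class name, queue-backed input, batch interval, sample values, checkpoint directory, and timeout are all illustrative assumptions) of how the new `registerStream(JavaDStream<BinarySample>)` overload added above might be driven end to end:

import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.StreamingTest;
import org.apache.spark.mllib.stat.test.StreamingTestResult;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class StreamingTestJavaSketch {
  public static void main(String[] args) throws InterruptedException {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("StreamingTestJavaSketch");
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaStreamingContext jssc = new JavaStreamingContext(sc, new Duration(1000));
    // windowSize == 0 uses updateStateByKey internally, which needs a checkpoint directory.
    jssc.checkpoint(System.getProperty("java.io.tmpdir"));

    // One micro-batch of observations: isExperiment == true marks the experiment group,
    // false marks the control group; the double is the observed metric.
    JavaRDD<BinarySample> batch = sc.parallelize(Arrays.asList(
        new BinarySample(true, 1.0),
        new BinarySample(true, 1.5),
        new BinarySample(false, 2.0),
        new BinarySample(false, 2.5)));
    Queue<JavaRDD<BinarySample>> queue = new LinkedList<>();
    queue.add(batch);
    JavaDStream<BinarySample> observations = jssc.queueStream(queue);

    StreamingTest test = new StreamingTest()
        .setPeacePeriod(0)
        .setWindowSize(0)
        .setTestMethod("welch");
    JavaDStream<StreamingTestResult> results = test.registerStream(observations);
    results.print();

    jssc.start();
    jssc.awaitTerminationOrTimeout(3000);
    jssc.stop();
  }
}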

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

+5 -1
@@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

 import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.annotation.Since
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") (
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {

   /**
+   * :: DeveloperApi ::
    * Compute the initial predictions and errors for a dataset for the first
    * iteration of gradient boosting.
    * @param data: training data.
@@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
    *         corresponding to every sample.
    */
   @Since("1.4.0")
+  @DeveloperApi
   def computeInitialPredictionAndError(
       data: RDD[LabeledPoint],
       initTreeWeight: Double,
@@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
   }

   /**
+   * :: DeveloperApi ::
    * Update a zipped predictionError RDD
    * (as obtained with computeInitialPredictionAndError)
    * @param data: training data.
@@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
    *         corresponding to each sample.
    */
   @Since("1.4.0")
+  @DeveloperApi
   def updatePredictionError(
       data: RDD[LabeledPoint],
       predictionAndError: RDD[(Double, Double)],

mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java

+35 -3
@@ -18,34 +18,49 @@
 package org.apache.spark.mllib.stat;

 import java.io.Serializable;
-
 import java.util.Arrays;
+import java.util.List;

 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;

+import static org.apache.spark.streaming.JavaTestUtils.*;
 import static org.junit.Assert.assertEquals;

+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaDoubleRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.BinarySample;
 import org.apache.spark.mllib.stat.test.ChiSqTestResult;
 import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
+import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;

 public class JavaStatisticsSuite implements Serializable {
   private transient JavaSparkContext sc;
+  private transient JavaStreamingContext ssc;

   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaStatistics");
+    SparkConf conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("JavaStatistics")
+      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+    sc = new JavaSparkContext(conf);
+    ssc = new JavaStreamingContext(sc, new Duration(1000));
+    ssc.checkpoint("checkpoint");
   }

   @After
   public void tearDown() {
-    sc.stop();
+    ssc.stop();
+    ssc = null;
     sc = null;
   }

@@ -76,4 +91,21 @@ public void chiSqTest() {
       new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
     ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
   }
+
+  @Test
+  public void streamingTest() {
+    List<BinarySample> trainingBatch = Arrays.asList(
+      new BinarySample(true, 1.0),
+      new BinarySample(false, 2.0));
+    JavaDStream<BinarySample> training =
+      attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
+    int numBatches = 2;
+    StreamingTest model = new StreamingTest()
+      .setWindowSize(0)
+      .setPeacePeriod(0)
+      .setTestMethod("welch");
+    model.registerStream(training);
+    attachTestOutputStream(training);
+    runStreams(ssc, numBatches, numBatches);
+  }
 }

mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala

+13 -12
@@ -18,7 +18,8 @@
 package org.apache.spark.mllib.stat

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest}
+import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest,
+  WelchTTest, BinarySample}
 import org.apache.spark.streaming.TestSuiteBase
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.StatCounter
@@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
     val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)

     assert(outputBatches.flatten.forall(res =>
@@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
     val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)

     assert(outputBatches.flatten.forall(res =>
@@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
     val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)


@@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
     val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)

     assert(outputBatches.flatten.forall(res =>
@@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
     // setup and run the model
     val ssc = setupStreams(
       input,
-      (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream))
+      (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream))
     val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches)
     val outputCounts = outputBatches.flatten.map(_._2.count)

@@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream))
     val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches)

     assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch)
@@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
       .setPeacePeriod(0)

     val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
-      .map(batch => batch.filter(_._1)) // only keep one test group
+      .map(batch => batch.filter(_.isExperiment)) // only keep one test group

     // setup and run the model
     val ssc = setupStreams(
-      input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+      input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
     val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)

     assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001))
@@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
       stdevA: Double,
       meanB: Double,
       stdevB: Double,
-      seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = {
+      seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = {
     val rand = new XORShiftRandom(seed)
     val numTrues = pointsPerBatch / 2
     val data = (0 until numBatches).map { i =>
-      (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++
+      (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++
         (pointsPerBatch / 2 until pointsPerBatch).map { idx =>
-          (false, meanB + stdevB * rand.nextGaussian())
+          BinarySample(false, meanB + stdevB * rand.nextGaussian())
         }
     }

