[SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing

Implementation of significance testing using Streaming API.

Author: Feynman Liang <[email protected]>

Closes #4716 from feynmanliang/ab_testing.
Feynman Liang authored and mengxr committed Sep 21, 2015
1 parent ba882db commit aeef44a
Showing 5 changed files with 667 additions and 0 deletions.
@@ -0,0 +1,90 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.spark.SparkConf
import org.apache.spark.mllib.stat.test.StreamingTest
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.util.Utils

/**
 * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data
 * stream arrives as text files in a directory. Stops when the difference between the two groups
 * becomes statistically significant (p-value < 0.05) or after a user-specified timeout, given as
 * a number of batches, is exceeded.
 *
 * The rows of the text files must be in the form `Boolean, Double`. For example:
 *   false, -3.92
 *   true, 99.32
 *
 * Usage:
 *   StreamingTestExample <dataDir> <batchDuration> <numBatchesTimeout>
 *
 * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and
 * a timeout after 100 insignificant batches, call:
 *    $ bin/run-example mllib.StreamingTestExample dataDir 5 100
 *
 * As you add text files to `dataDir`, the significance test will be updated every
 * `batchDuration` seconds until either the test becomes significant (p-value < 0.05) or the
 * number of batches processed exceeds `numBatchesTimeout`.
 */
object StreamingTestExample {

  def main(args: Array[String]) {
    if (args.length != 3) {
      // scalastyle:off println
      System.err.println(
        "Usage: StreamingTestExample " +
          "<dataDir> <batchDuration> <numBatchesTimeout>")
      // scalastyle:on println
      System.exit(1)
    }
    val dataDir = args(0)
    val batchDuration = Seconds(args(1).toLong)
    val numBatchesTimeout = args(2).toInt

    val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample")
    val ssc = new StreamingContext(conf, batchDuration)
    ssc.checkpoint({
      val dir = Utils.createTempDir()
      dir.toString
    })

    // Parse each line as (group membership, observed value); lines are assumed well formed
    val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
      case Array(label, value) => (label.toBoolean, value.toDouble)
    })

    val streamingTest = new StreamingTest()
      .setPeacePeriod(0)
      .setWindowSize(0)
      .setTestMethod("welch")

    val out = streamingTest.registerStream(data)
    out.print()

    // Stop processing if test becomes significant or we time out
    var timeoutCounter = numBatchesTimeout
    out.foreachRDD { rdd =>
      timeoutCounter -= 1
      val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _)
      if (timeoutCounter == 0 || anySignificant) rdd.context.stop()
    }

    ssc.start()
    ssc.awaitTermination()
  }
}
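
Note that Spark's `textFileStream` only picks up files that appear atomically in the monitored directory, so input batches are best written elsewhere and then moved into `dataDir`. A minimal sketch of a batch writer that the example above could consume (the helper object, file naming, and sample values are made up for illustration):

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths, StandardCopyOption}

// Hypothetical helper: writes one batch of labeled observations to a temp file,
// then moves it into dataDir so textFileStream sees a complete file.
// true = experiment group, false = control group; the values are arbitrary.
object WriteBatch {
  def main(args: Array[String]): Unit = {
    val lines = Seq("true, 99.32", "false, -3.92", "true, 101.10", "false, -2.50")
    val tmp = Files.createTempFile("batch-", ".txt")
    Files.write(tmp, lines.mkString("\n").getBytes(StandardCharsets.UTF_8))
    Files.move(tmp, Paths.get("dataDir", tmp.getFileName.toString),
      StandardCopyOption.ATOMIC_MOVE)
  }
}
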
@@ -0,0 +1,145 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.stat.test

import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter

/**
 * :: Experimental ::
 * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
 * Boolean identifies which sample each observation comes from, and the Double is the numeric value
 * of the observation.
 *
 * To address novelty effects, the `peacePeriod` specifies a set number of initial
 * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing.
 *
 * The `windowSize` sets the number of batches each significance test is to be performed over. The
 * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform
 * cumulative processing, using all batches seen so far.
 *
 * Different tests may be used for assessing statistical significance depending on the assumptions
 * satisfied by the data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies
 * which test will be used.
 *
 * Use a builder pattern to construct a streaming test in an application, for example:
 * {{{
 *   val model = new StreamingTest()
 *     .setPeacePeriod(10)
 *     .setWindowSize(0)
 *     .setTestMethod("welch")
 *     .registerStream(DStream)
 * }}}
 */
@Experimental
@Since("1.6.0")
class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
  private var peacePeriod: Int = 0
  private var windowSize: Int = 0
  private var testMethod: StreamingTestMethod = WelchTTest

  /** Set the number of initial batches to ignore. Default: 0. */
  @Since("1.6.0")
  def setPeacePeriod(peacePeriod: Int): this.type = {
    this.peacePeriod = peacePeriod
    this
  }

  /**
   * Set the number of batches to compute significance tests over. Default: 0.
   * A value of 0 will use all batches seen so far.
   */
  @Since("1.6.0")
  def setWindowSize(windowSize: Int): this.type = {
    this.windowSize = windowSize
    this
  }

  /** Set the statistical method used for significance testing. Default: "welch". */
  @Since("1.6.0")
  def setTestMethod(method: String): this.type = {
    this.testMethod = StreamingTestMethod.getTestMethodFromName(method)
    this
  }

  /**
   * Register a [[DStream]] of values for significance testing.
   *
   * @param data stream of (key, value) pairs where the key denotes group membership (true =
   *             experiment, false = control) and the value is the numerical metric to test for
   *             significance
   * @return stream of significance testing results
   */
  @Since("1.6.0")
  def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
    val dataAfterPeacePeriod = dropPeacePeriod(data)
    val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
    val pairedSummaries = pairSummaries(summarizedData)

    testMethod.doTest(pairedSummaries)
  }

  /** Drop all batches inside the peace period. */
  private[stat] def dropPeacePeriod(
      data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
    data.transform { (rdd, time) =>
      if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
        rdd
      } else {
        data.context.sparkContext.parallelize(Seq())
      }
    }
  }

  /** Compute summary statistics over each key and the specified test window size. */
  private[stat] def summarizeByKeyAndWindow(
      data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
    if (this.windowSize == 0) {
      data.updateStateByKey[StatCounter](
        (newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
          val newSummary = oldSummary.getOrElse(new StatCounter())
          newSummary.merge(newValues)
          Some(newSummary)
        })
    } else {
      val windowDuration = data.slideDuration * this.windowSize
      data
        .groupByKeyAndWindow(windowDuration)
        .mapValues { values =>
          val summary = new StatCounter()
          values.foreach(value => summary.merge(value))
          summary
        }
    }
  }

  /**
   * Transform a stream of summaries into pairs representing summary statistics for the control
   * group and the experiment group up to this batch.
   */
  private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)])
    : DStream[(StatCounter, StatCounter)] = {
    summarizedData
      .map[(Int, StatCounter)](x => (0, x._2))
      .groupByKey() // should be length two (control/experiment group)
      .map(x => (x._2.head, x._2.last))
  }
}
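
For intuition about what `WelchTTest` (the default `testMethod` above) computes from each `(StatCounter, StatCounter)` pair: Welch's t-statistic and the Welch–Satterthwaite degrees of freedom follow directly from each group's mean, sample variance, and count. A back-of-the-envelope sketch for reference, not the actual `StreamingTestMethod` implementation:

import org.apache.spark.util.StatCounter

// Welch's t-statistic and approximate degrees of freedom from two per-group
// summaries; a sketch for intuition only, not MLlib's code path.
def welchT(a: StatCounter, b: StatCounter): (Double, Double) = {
  val va = a.sampleVariance / a.count // s_a^2 / n_a
  val vb = b.sampleVariance / b.count // s_b^2 / n_b
  val t = (a.mean - b.mean) / math.sqrt(va + vb)
  val df = math.pow(va + vb, 2) /
    (math.pow(va, 2) / (a.count - 1) + math.pow(vb, 2) / (b.count - 1))
  (t, df)
}

The p-value then comes from the tail probability of a Student's t distribution with `df` degrees of freedom evaluated at `t`.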