Skip to content

Commit

Permalink
[minor] update streaming linear algorithms
Browse files Browse the repository at this point in the history
Author: Xiangrui Meng <[email protected]>

Closes apache#4329 from mengxr/streaming-lr and squashes the following commits:

78731e1 [Xiangrui Meng] update streaming linear algorithms
  • Loading branch information
mengxr committed Feb 3, 2015
1 parent 980764f commit 659329f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (

/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Option(algorithm.createModel(initialWeights, 0.0))
this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.streaming.dstream.DStream

/**
Expand Down Expand Up @@ -58,7 +58,7 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {

/** The model to be updated and used for prediction. */
protected var model: Option[M] = null
protected var model: Option[M] = None

/** The algorithm to use for updating. */
protected val algorithm: A
Expand All @@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
if (Option(model) == None) {
logError("Model must be initialized before starting training")
throw new IllegalArgumentException
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
model = Option(algorithm.run(rdd, model.get.weights))
logInfo("Model updated at time %s".format(time.toString))
val display = model.get.weights.size match {
case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
case _ => model.get.weights.toArray.mkString("[", ",", "]")
val initialWeights =
model match {
case Some(m) =>
m.weights
case None =>
val numFeatures = rdd.first().features.size
Vectors.dense(numFeatures)
}
logInfo("Current model: weights, %s".format (display))
model = Some(algorithm.run(rdd, initialWeights))
logInfo("Model updated at time %s".format(time.toString))
val display = model.get.weights.size match {
case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
case _ => model.get.weights.toArray.mkString("[", ",", "]")
}
logInfo("Current model: weights, %s".format (display))
}
}

Expand All @@ -99,10 +106,8 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing predictions
*/
def predictOn(data: DStream[Vector]): DStream[Double] = {
if (Option(model) == None) {
val msg = "Model must be initialized before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction.")
}
data.map(model.get.predict)
}
Expand All @@ -114,10 +119,8 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing the input keys and the predictions as values
*/
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
if (Option(model) == None) {
val msg = "Model must be initialized before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction")
}
data.mapValues(model.get.predict)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (

/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Option(algorithm.createModel(initialWeights, 0.0))
this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
}

Expand Down

0 comments on commit 659329f

Please sign in to comment.