[SPARK-5986][MLLib] Add save/load for k-means
This PR adds save/load for K-means as described in SPARK-5986. The Python version will be added in another PR.

Author: Xusen Yin <[email protected]>

Closes apache#4951 from yinxusen/SPARK-5986 and squashes the following commits:

6dd74a0 [Xusen Yin] rewrite some functions and classes
cd390fd [Xusen Yin] add indexed point
b144216 [Xusen Yin] remove invalid comments
dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986
yinxusen authored and mengxr committed Mar 11, 2015
1 parent 2672374 commit 2d4e00e
Showing 2 changed files with 108 additions and 4 deletions.
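As quick orientation before the diffs, here is a minimal usage sketch of the API this commit introduces: build a KMeansModel, save it, load it back, and compare centers. This example is not part of the commit; the local master URL and the /tmp output path are hypothetical, and it assumes the Spark 1.3-era MLlib API shown in the diffs below.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

object KMeansSaveLoadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "KMeansSaveLoadExample")

    // A toy model with three 2-dimensional centers.
    val model = new KMeansModel(Array(
      Vectors.dense(0.0, 0.0),
      Vectors.dense(1.0, 1.0),
      Vectors.dense(9.0, 9.0)))

    val path = "/tmp/kmeans-model" // hypothetical output location
    model.save(sc, path)

    // Round trip: the loaded model should have the same k and the same centers.
    val sameModel = KMeansModel.load(sc, path)
    assert(sameModel.k == model.k)
    model.clusterCenters.zip(sameModel.clusterCenters).foreach { case (a, b) =>
      assert(a.toArray.sameElements(b.toArray))
    }

    sc.stop()
  }
}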
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,15 +17,22 @@

 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.Row
 
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
+class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {

   private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
     clusterCenters.map(new VectorWithNorm(_))
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    KMeansModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
 }
+
+object KMeansModel extends Loader[KMeansModel] {
+  override def load(sc: SparkContext, path: String): KMeansModel = {
+    KMeansModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private case class Cluster(id: Int, point: Vector)
+
+  private object Cluster {
+    def apply(r: Row): Cluster = {
+      Cluster(r.getInt(0), r.getAs[Vector](1))
+    }
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
+        Cluster(id, point)
+      }.toDF()
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): KMeansModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val k = (metadata \ "k").extract[Int]
+      val centroids = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[Cluster](centroids.schema)
+      val localCentroids = centroids.map(Cluster.apply).collect()
+      assert(k == localCentroids.size)
+      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
+    }
+  }
+}
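A note on what SaveLoadV1_0.save above writes, plus a hedged inspection sketch (again not part of the commit): save produces a one-line JSON metadata file and a Parquet file of (id, point) rows, and load re-sorts by id because a Parquet round trip gives no ordering guarantee; that is why the Cluster case class carries an index at all. The sketch assumes Loader.metadataPath(path) and Loader.dataPath(path) resolve to <path>/metadata and <path>/data, as they do for other MLlib 1.3 model savers.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object InspectSavedKMeansModel {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "InspectSavedKMeansModel")
    val sqlContext = new SQLContext(sc)
    val path = "/tmp/kmeans-model" // assumed to hold a previously saved model

    // <path>/metadata: a single JSON line, e.g.
    // {"class":"org.apache.spark.mllib.clustering.KMeansModel","version":"1.0","k":3}
    sc.textFile(path + "/metadata").collect().foreach(println)

    // <path>/data: Parquet rows of (id: Int, point: Vector), one row per cluster center.
    sqlContext.parquetFile(path + "/data").collect().foreach(println)

    sc.stop()
  }
}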
mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,9 +21,10 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
       assert(predicts(0) != predicts(3))
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(true, false).foreach { case selector =>
+      val model = KMeansSuite.createModel(10, 3, selector)
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = KMeansModel.load(sc, path)
+        KMeansSuite.checkEqual(model, sameModel)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
 }
+
+object KMeansSuite extends FunSuite {
+  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
+    val singlePoint = isSparse match {
+      case true =>
+        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
+      case _ =>
+        Vectors.dense(Array.fill[Double](dim)(0.0))
+    }
+    new KMeansModel(Array.fill[Vector](k)(singlePoint))
+  }
+
+  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
+    assert(a.k === b.k)
+    a.clusterCenters.zip(b.clusterCenters).foreach {
+      case (ca: SparseVector, cb: SparseVector) =>
+        assert(ca === cb)
+      case (ca: DenseVector, cb: DenseVector) =>
+        assert(ca === cb)
+      case _ =>
+        throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
+    }
+  }
+}
 
 class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {