Skip to content

Commit

Permalink
[SPARK-13334][ML] ML KMeansModel / BisectingKMeansModel / QuantileDiscretizer should set parent
Browse files Browse the repository at this point in the history

ML `KMeansModel` / `BisectingKMeansModel` / `QuantileDiscretizer` should set parent.

cc mengxr

Author: Yanbo Liang <[email protected]>

Closes apache#11214 from yanboliang/spark-13334.
  • Loading branch information
yanboliang authored and MLnick committed Feb 22, 2016
1 parent e298ac9 commit 40e6d40
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class BisectingKMeans @Since("2.0.0") (
.setSeed($(seed))
val parentModel = bkm.run(rdd)
val model = new BisectingKMeansModel(uid, parentModel)
copyValues(model)
copyValues(model.setParent(this))
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = new KMeansModel(uid, parentModel)
copyValues(model)
copyValues(model.setParent(this))
}

@Since("1.5.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ final class QuantileDiscretizer(override val uid: String)
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer)
copyValues(bucketizer.setParent(this))
}

override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,6 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters.size === k)
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}

test("read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
.setNumBuckets(numBucket).setSeed(1)
val result = discretizer.fit(df).transform(df)
val model = discretizer.fit(df)
assert(model.hasParent)
val result = model.transform(df)

val transformedFeatures = result.select("result").collect()
.map { case Row(transformedFeature: Double) => transformedFeature }
Expand Down

0 comments on commit 40e6d40

Please sign in to comment.