[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula

Preview:

```
> summary(m)
            features coefficients
1        (Intercept)    1.6765001
2       Sepal_Length    0.3498801
3 Species.versicolor   -0.9833885
4  Species.virginica   -1.0075104

```

Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit

cc mengxr

Author: Eric Liang <[email protected]>

Closes apache#7771 from ericl/summary and squashes the following commits:

ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass
ericl authored and mengxr committed Jul 30, 2015
1 parent be7be6d commit e7905a9
Showing 9 changed files with 108 additions and 17 deletions.
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
```diff
@@ -12,7 +12,8 @@ export("print.jobj")
 
 # MLlib integration
 exportMethods("glm",
-              "predict")
+              "predict",
+              "summary")
 
 # Job group lifecycle management methods
 export("setJobGroup",
```
26 changes: 26 additions & 0 deletions R/pkg/R/mllib.R
```diff
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
           })
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#'         summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object) {
+            features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                    "getModelFeatures", object@model)
+            weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelWeights", object@model)
+            coefficients <- as.matrix(unlist(weights))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
+            return(list(coefficients = coefficients))
+          })
```
11 changes: 11 additions & 0 deletions R/pkg/inst/tests/test_mllib.R
```diff
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
   rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })
+
+test_that("summary coefficients match with native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+  coefs <- as.vector(stats$coefficients)
+  rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  expect_true(all(
+    as.character(stats$features) ==
+    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})
```
12 changes: 5 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
```diff
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
     val outputAttrNames: Option[Array[String]] = inputAttr match {
       case nominal: NominalAttribute =>
         if (nominal.values.isDefined) {
-          nominal.values.map(_.map(v => inputColName + is + v))
+          nominal.values
         } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
         } else {
           None
         }
       case binary: BinaryAttribute =>
         if (binary.values.isDefined) {
-          binary.values.map(_.map(v => inputColName + is + v))
+          binary.values
         } else {
-          Some(Array.tabulate(2)(i => inputColName + is + i))
+          Some(Array.tabulate(2)(_.toString))
         }
       case _: NumericAttribute =>
         throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
 
   override def transform(dataset: DataFrame): DataFrame = {
     // schema transformation
-    val is = "_is_"
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
     val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
         math.max(m0, m1)
       }
     ).toInt + 1
-    val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+    val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
     val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
     val outputAttrs: Array[Attribute] =
       filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
```
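
The attribute-naming change above is easiest to see in isolation. A minimal plain-Scala sketch (the column and category names are illustrative, not from the patch):

```scala
// Before the patch, output attribute names embedded the input column name
// via the "_is_" infix; after it, the raw category value (or index) is used,
// leaving any prefixing to downstream consumers such as VectorAssembler.
val inputColName = "size"              // illustrative input column
val values = Array("small", "medium")  // illustrative categories

val oldNames = values.map(v => inputColName + "_is_" + v)
// oldNames: Array("size_is_small", "size_is_medium")

val newNames = values
// newNames: Array("small", "medium")

// Fallback when only the number of categories is known:
val byIndex = Array.tabulate(2)(_.toString)  // Array("0", "1")
```
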
12 changes: 11 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.parsing.combinator.RegexParsers
 
@@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
     // TODO(ekl) add support for feature interactions
     val encoderStages = ArrayBuffer[PipelineStage]()
     val tempColumns = ArrayBuffer[String]()
+    val takenNames = mutable.Set(dataset.columns: _*)
     val encodedTerms = resolvedFormula.terms.map { term =>
       dataset.schema(term) match {
         case column if column.dataType == StringType =>
           val indexCol = term + "_idx_" + uid
-          val encodedCol = term + "_onehot_" + uid
+          val encodedCol = {
+            var tmp = term
+            while (takenNames.contains(tmp)) {
+              tmp += "_"
+            }
+            tmp
+          }
+          takenNames.add(indexCol)
+          takenNames.add(encodedCol)
           encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
           encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
           tempColumns += indexCol
```
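
The collision-avoidance loop introduced above can be exercised on its own. A self-contained sketch (the `freshName` helper is ours, not part of the patch):

```scala
import scala.collection.mutable

// Append "_" to the candidate until it no longer clashes with a taken name,
// then reserve it -- the same strategy as the encodedCol block above.
def freshName(base: String, takenNames: mutable.Set[String]): String = {
  var tmp = base
  while (takenNames.contains(tmp)) {
    tmp += "_"
  }
  takenNames.add(tmp)
  tmp
}

val taken = mutable.Set("id", "a", "b")  // e.g. the dataset's column names
assert(freshName("a", taken) == "a_")    // "a" is taken, so one "_" is appended
assert(freshName("a", taken) == "a__")   // "a_" is now reserved as well
```
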
27 changes: 25 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
```diff
@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.api.r
 
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.sql.DataFrame
 
@@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
     val pipeline = new Pipeline().setStages(Array(formula, estimator))
     pipeline.fit(df)
   }
+
+  def getModelWeights(model: PipelineModel): Array[Double] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        Array(m.intercept) ++ m.weights.toArray
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No weights available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
+
+  def getModelFeatures(model: PipelineModel): Array[String] = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+      case _: LogisticRegressionModel =>
+        throw new UnsupportedOperationException(
+          "No features names available for LogisticRegressionModel")  // SPARK-9492
+    }
+  }
 }
```
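
On the R side, `summary()` pairs these two arrays positionally, so both methods must emit the intercept first. A sketch of that contract, reusing the coefficients from the preview at the top (values copied from the preview, not computed here):

```scala
// Feature names and weights as returned by getModelFeatures/getModelWeights
// for the preview model; the positional zip is what summary() relies on.
val features = Array("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")
val weights = Array(1.6765001, 0.3498801, -0.9833885, -1.0075104)

features.zip(weights).foreach { case (name, estimate) =>
  println(f"$name%-20s $estimate%12.7f")
}
```
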
8 changes: 6 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
```diff
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
 
@@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
 
     val model = new LinearRegressionModel(uid, weights, intercept)
     val trainingSummary = new LinearRegressionTrainingSummary(
-      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      model.transform(dataset),
       $(predictionCol),
       $(labelCol),
+      $(featuresCol),
       Array(0D))
     return copyValues(model.setSummary(trainingSummary))
   }
@@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
 
     val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
     val trainingSummary = new LinearRegressionTrainingSummary(
-      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      model.transform(dataset),
       $(predictionCol),
       $(labelCol),
+      $(featuresCol),
       objectiveHistory)
     model.setSummary(trainingSummary)
   }
@@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
     predictions: DataFrame,
     predictionCol: String,
     labelCol: String,
+    val featuresCol: String,
     val objectiveHistory: Array[Double])
   extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
 
```
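
The reason the training summary now keeps the full `transform()` output (and carries the new `featuresCol` field) is that `getModelFeatures` reads ML attribute metadata off the features column's schema. A small sketch of that round trip (the attribute name is illustrative):

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}

// Attach a named attribute to a features column's StructField metadata, then
// recover it the same way getModelFeatures does via fromStructField.
val group = new AttributeGroup("features",
  Array[Attribute](NumericAttribute.defaultAttr.withName("Sepal_Length")))
val field = group.toStructField()
val names = AttributeGroup.fromStructField(field).attributes.get.map(_.name.get)
// names: Array("Sepal_Length")
```
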
8 changes: 4 additions & 4 deletions mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
```diff
@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
   }
 
   test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     val output = encoder.transform(df)
     val group = AttributeGroup.fromStructField(output.schema("encoded"))
     assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
   }
 }
```
18 changes: 18 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(result.schema.toString == resultSchema.toString)
     assert(result.collect() === expected.collect())
   }
+
+  test("attribute generation") {
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array(
+        new BinaryAttribute(Some("a__bar"), Some(1)),
+        new BinaryAttribute(Some("a__foo"), Some(2)),
+        new NumericAttribute(Some("b"), Some(3))))
+    assert(attrs === expectedAttrs)
+  }
 }
```
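
The double underscore in the expected names here (`a__bar`, and `Species__versicolor` in the R test) falls out of the pieces above; a sketch of our reading of the composition (not code from the patch, and the VectorAssembler prefixing rule is an assumption about its attribute naming):

```scala
// "a" collides with the input column, so RFormula's loop renames the encoded
// column to "a_"; OneHotEncoder now names the attribute by the raw value
// ("bar"); VectorAssembler joins column and attribute names with "_".
val encodedCol = "a_"
val attrName = "bar"
assert(encodedCol + "_" + attrName == "a__bar")
```
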
