Skip to content

Commit

Permalink
[SPARK-11259][ML] Params.validateParams() should be called automatically
Browse files Browse the repository at this point in the history
See JIRA: https://issues.apache.org/jira/browse/SPARK-11259

Author: Yanbo Liang <[email protected]>

Closes apache#9224 from yanboliang/spark-11259.
  • Loading branch information
yanboliang authored and jkbradley committed Jan 4, 2016
1 parent 0171b71 commit ba5f818
Show file tree
Hide file tree
Showing 30 changed files with 63 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
Expand Down Expand Up @@ -296,6 +297,7 @@ class PipelineModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}

Expand Down
1 change: 1 addition & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
validateParams()
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
Expand Down
1 change: 1 addition & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
protected def validateInputType(inputType: DataType): Unit = {}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ final class Binarizer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)

val inputFields = schema.fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
Expand Down Expand Up @@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class HashingTF(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
Expand Down
1 change: 1 addition & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val outputColName = $(outputCol)

Expand Down
2 changes: 2 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down Expand Up @@ -130,6 +131,7 @@ class PCAModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String)
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R

// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
validateParams()
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
Expand Down Expand Up @@ -178,6 +179,7 @@ class RFormulaModel private[feature](
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(withFeatures)) {
Expand Down Expand Up @@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
}

Expand Down Expand Up @@ -288,6 +291,7 @@ private class VectorAttributeRewriter(
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(
schema.fields.filter(_.name != vectorCol) ++
schema.fields.filter(_.name == vectorCol))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down Expand Up @@ -143,6 +144,7 @@ class StandardScalerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
Expand Down Expand Up @@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String)
final def getLabels: Array[String] = $(labels)

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
val dataType = new VectorUDT
Expand Down Expand Up @@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val dataType = new VectorUDT
require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)

if (schema.fieldNames.contains($(outputCol))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
Expand Down Expand Up @@ -213,6 +214,7 @@ class ALSModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
protected[ml] def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
validateParams()
if (fitting) {
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
if (hasWeightCol) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
$(estimator).transformSchema(schema)
}

Expand Down Expand Up @@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] (

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
bestModel.transformSchema(schema)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St

@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
$(estimator).transformSchema(schema)
}

Expand Down Expand Up @@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] (

@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
bestModel.transformSchema(schema)
}

Expand Down
23 changes: 22 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
}

test("pipeline validateParams") {
val df = sqlContext.createDataFrame(
Seq(
(1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
(3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
(4, Vectors.dense(0.0, 0.0, 5.0), 4.0))
).toDF("id", "features", "label")

intercept[IllegalArgumentException] {
val scaler = new MinMaxScaler()
.setInputCol("features")
.setOutputCol("features_scaled")
.setMin(10)
.setMax(0)
val pipeline = new Pipeline().setStages(Array(scaler))
pipeline.fit(df)
}
}
}


Expand Down

0 comments on commit ba5f818

Please sign in to comment.