[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml

For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema public instead of private to ml.  This would be nice to include in Spark 1.3.
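For illustration, a minimal sketch of a user-defined stage that this change enables. The class name MyLengthTransformer and the text/textLength column names are hypothetical, not part of this commit; the sketch assumes the Spark 1.3-era pipeline API shown in the diff below, where transform and transformSchema both take an explicit ParamMap.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical user-defined stage (not part of this commit): appends a
// "textLength" column computed from a "text" column.
class MyLengthTransformer extends Transformer {

  private val textLength = udf { (text: String) => text.length.toDouble }

  // With transformSchema public, a stage outside org.apache.spark.ml can
  // override it to declare its output schema from the input schema.
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    require(schema.fieldNames.contains("text"), "Input column text is missing.")
    require(!schema.fieldNames.contains("textLength"),
      "Output column textLength already exists.")
    StructType(schema.fields :+ StructField("textLength", DoubleType, nullable = false))
  }

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap) // validate before transforming
    dataset.withColumn("textLength", textLength(dataset("text")))
  }
}
```

Because transformSchema is public, such a stage can also be schema-checked without running it, e.g. `new Pipeline().setStages(Array(new MyLengthTransformer)).transformSchema(df.schema, ParamMap.empty)` in a test (illustrative usage, not from this commit).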

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes apache#4682 from jkbradley/SPARK-5902 and squashes the following commits:

6f02357 [Joseph K. Bradley] Made transformSchema public
0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well
fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
jkbradley authored and mengxr committed Feb 19, 2015
1 parent 8ca3418 commit a5fed34
Showing 5 changed files with 20 additions and 12 deletions.
16 changes: 12 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer

import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {

/**
* :: DeveloperApi ::
*
* Derives the output schema from the input schema and parameters.
* The schema describes the columns and types of the data.
*
* @param schema Input schema to this stage
* @param paramMap Parameters passed to this stage
* @return Output schema from this stage
*/
-private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+@DeveloperApi
+def transformSchema(schema: StructType, paramMap: ParamMap): StructType

/**
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
new PipelineModel(this, map, transformers.toArray)
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
model
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
Expand Down Expand Up @@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
}

@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
}

@@ -188,7 +188,7 @@ class ALSModel private[ml] (
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
}

-override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
model
}

-override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
cvModel
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transformSchema(schema, paramMap)
}
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
bestModel.transform(dataset, paramMap)
}

-private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transformSchema(schema, paramMap)
}
}
