Skip to content

Commit

Permalink
[SPARK-1406] Mllib pmml model export
Browse files Browse the repository at this point in the history
See PDF attached to the JIRA issue 1406.

The contribution is my original work and I license the work to the project under the project's open source license.

Author: Vincenzo Selvaggio <[email protected]>
Author: Xiangrui Meng <[email protected]>
Author: selvinsource <[email protected]>

Closes apache#3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits:

852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file
085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max Fixed scala style
30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit
7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression
cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression
25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model
dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model
66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible
a0a55f7 [Vincenzo Selvaggio] Merge pull request apache#2 from mengxr/SPARK-1406
3c22f79 [Xiangrui Meng] more code style
e2313df [Vincenzo Selvaggio] Merge pull request apache#1 from mengxr/SPARK-1406
472d757 [Xiangrui Meng] fix code style
1676e15 [Vincenzo Selvaggio] fixed scala issue
e2ffae8 [Vincenzo Selvaggio] fixed scala style
b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context
7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models
7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface Restructured code in a new package mllib.pmml Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso
d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories
03bc3a5 [Vincenzo Selvaggio] added logistic regression
da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export
82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness Adjusted unit test to deal with this change
3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines
c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel
e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml) removed copyright
ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object reordered the imports
df8a89e [Vincenzo Selvaggio] added pmml version to pmml model changed the copyright to spark
a1b4dc3 [Vincenzo Selvaggio] updated imports
834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines
349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format
c3ef9b8 [Vincenzo Selvaggio] set it to private
6357b98 [Vincenzo Selvaggio] set it to private
e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object
aba5ee1 [Vincenzo Selvaggio] fixed cluster export
cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests
f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406
07a29bf [selvinsource] Update LICENSE
8841439 [Vincenzo Selvaggio] adjust scala style in order to compile
1433b11 [Vincenzo Selvaggio] complete suite tests
8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation
9bc494f [Vincenzo Selvaggio] added scala suite tests added saveLocalFile to ModelExport trait
226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML)
a0e3679 [Vincenzo Selvaggio] export and pmml export traits kmeans test implementation
  • Loading branch information
selvinsource authored and mengxr committed Apr 30, 2015
1 parent 4459514 commit 254e050
Show file tree
Hide file tree
Showing 18 changed files with 774 additions and 6 deletions.
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ BSD-style licenses
The following components are provided under a BSD-style license. See project link for details.

(BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model)
(BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
Expand Down
15 changes: 15 additions & 0 deletions mllib/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
<version>1.1.15</version>
<exclusions>
<exclusion>
<groupId>com.sun.xml.fastinfoset</groupId>
<artifactId>FastInfoset</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.istack</groupId>
<artifactId>istack-commons-runtime</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<profiles>
<profile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
Expand All @@ -46,7 +47,7 @@ class LogisticRegressionModel (
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {
with Saveable with PMMLExportable {

if (numClasses == 2) {
require(weights.size == numFeatures,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
Expand All @@ -36,7 +37,7 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {
with Saveable with PMMLExportable {

private var threshold: Option[Double] = Some(0.0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
Expand All @@ -34,7 +35,8 @@ import org.apache.spark.sql.Row
/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
class KMeansModel (
val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {

/** A Java-friendly constructor that takes an Iterable of Vectors. */
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml

import java.io.{File, OutputStream, StringWriter}
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil

import org.apache.spark.SparkContext
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory

/**
* Export model to the PMML format
* Predictive Model Markup Language (PMML) is an XML-based file format
* developed by the Data Mining Group (www.dmg.org).
*/
trait PMMLExportable {

/**
* Export the model to the stream result in PMML format
*/
private def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
}

/**
* Export the model to a local file in PMML format
*/
def toPMML(localPath: String): Unit = {
toPMML(new StreamResult(new File(localPath)))
}

/**
* Export the model to a directory on a distributed file system in PMML format
*/
def toPMML(sc: SparkContext, path: String): Unit = {
val pmml = toPMML()
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
}

/**
* Export the model to the OutputStream in PMML format
*/
def toPMML(outputStream: OutputStream): Unit = {
toPMML(new StreamResult(outputStream))
}

/**
* Export the model to a String in PMML format
*/
def toPMML(): String = {
val writer = new StringWriter
toPMML(new StreamResult(writer))
writer.toString
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
threshold: Double)
extends PMMLModelExport {

populateBinaryClassificationPMML()

/**
* Export the input LogisticRegressionModel or SVMModel to PMML format.
*/
private def populateBinaryClassificationPMML(): Unit = {
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
var interceptNO = threshold
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
if (threshold <= 0) {
interceptNO = Double.MinValue
} else if (threshold >= 1) {
interceptNO = Double.MaxValue
} else {
interceptNO = -math.log(1 / threshold - 1)
}
}
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.CLASSIFICATION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withNormalizationMethod(normalizationMethod)
.withRegressionTables(regressionTableYES, regressionTableNO)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// add target field
val targetField = FieldName.create("target")
dataDictionary
.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for GeneralizedLinearModel abstract class
*/
private[mllib] class GeneralizedLinearPMMLModelExport(
model: GeneralizedLinearModel,
description: String)
extends PMMLModelExport {

populateGeneralizedLinearPMML(model)

/**
* Export the input GeneralizedLinearModel model to PMML format.
*/
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.REGRESSION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withRegressionTables(regressionTable)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// for completeness add target field
val targetField = FieldName.create("target")
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
}
}
Loading

0 comments on commit 254e050

Please sign in to comment.