Skip to content

Commit

Permalink
support batch prediction and register as udfs for logistic regression…
Browse files Browse the repository at this point in the history
… and linear regression

add it test for udf register and batch prediction
  • Loading branch information
ckeys committed Aug 24, 2021
1 parent cf87629 commit b22578d
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@ as model_result;

select name,value from model_result where name="status" as result;
-- make sure status of all models are success.
!assert result ''':value=="success"''' "all model status should be success";
!assert result ''':value=="success"''' "all model status should be success";
predict data1 as LogisticRegression.`/tmp/model`;
register LogisticRegression.`/tmp/model` as lr_predict;
select lr_predict(features) as predict_label, label from data1 as output;
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ as data1;

-- select * from data1 as output1;
-- -- use RandomForest
train data1 as LogisticRegression.`/tmp/model` where
train data1 as LinearRegression.`/tmp/model` where

-- -- once set true,every time you run this script, MLSQL will generate new directory for you model
keepVersion="true"
Expand All @@ -40,4 +40,7 @@ as model_result;

select name,value from model_result where name="status" as result;
-- make sure status of all models are success.
!assert result ''':value=="success"''' "all model status should be success";
!assert result ''':value=="success"''' "all model status should be success";
predict data1 as LinearRegression.`/tmp/model`;
register LinearRegression.`/tmp/model` as lr_predict;
select lr_predict(features) as predict_label, label from data1 as output;
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,25 @@ class SQLLinearRegressionExt(override val uid: String) extends SQLAlg with Mllib


override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
val model = LinearRegressionModel.load(path)
model
val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
val model = LinearRegressionModel.load(bestModelPath(0))
ArrayBuffer(model)
}

override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[LinearRegressionModel])

val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[ArrayBuffer[LinearRegressionModel]].head)
val f = (vec: Vector) => {
val result = model.value.getClass.getMethod("predict", classOf[Vector]).invoke(model.value, vec)
result
}
MLSQLUtils.createUserDefinedFunction(f, DoubleType, Some(Seq(VectorType)))
}

override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[LinearRegressionModel]].head
model.transform(df)
}

override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
val vtable = MLSQLTable(
Option(DB_DEFAULT.MLSQL_SYSTEM.toString),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ class SQLLogisticRegression(override val uid: String) extends SQLAlg with MllibF
Seq("uid", model.uid),
Seq("numFeatures", model.numFeatures.toString),
Seq("numClasses", model.numClasses.toString),
Seq("binarySummary", model.binarySummary.toString()),
Seq("intercept", model.intercept.toString()),
Seq("coefficients", model.coefficients.toString())
) ++ modelParams
Expand All @@ -148,18 +147,18 @@ class SQLLogisticRegression(override val uid: String) extends SQLAlg with MllibF
override def modelType: ModelType = AlgType

override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
val model = LogisticRegressionModel.load(path)
model
val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
val model = LogisticRegressionModel.load(bestModelPath(0))
ArrayBuffer(model)
}

override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[LogisticRegressionModel])
predict_classification(sparkSession, _model, name)
}

val f = (vec: Vector) => {
val result = model.value.getClass.getMethod("predict", classOf[Vector]).invoke(model.value, vec)
result
}
MLSQLUtils.createUserDefinedFunction(f, DoubleType, Some(Seq(VectorType)))
override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[LogisticRegressionModel]].head
model.transform(df)
}

override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
Expand Down

0 comments on commit b22578d

Please sign in to comment.