From b22578de82c8cf2a600f4f715cadc621d8e75477 Mon Sep 17 00:00:00 2001
From: andie huang
Date: Tue, 24 Aug 2021 11:06:58 +0800
Subject: [PATCH] support batch prediction and register as udfs for logistic
 regression and linear regression

add it test for udf register and batch prediction
---
 ...ang_01_train_logistic_regression_model.mlsql |  5 ++++-
 ...huang_02_train_linear_regression_model.mlsql |  7 +++++--
 .../dsl/mmlib/algs/SQLLinearRegressionExt.scala | 13 +++++++++----
 .../dsl/mmlib/algs/SQLLogisticRegression.scala  | 17 ++++++++---------
 4 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/streamingpro-it/src/test/resources/sql/simple/andie_huang_01_train_logistic_regression_model.mlsql b/streamingpro-it/src/test/resources/sql/simple/andie_huang_01_train_logistic_regression_model.mlsql
index 8d80108b0..5a5e23e57 100644
--- a/streamingpro-it/src/test/resources/sql/simple/andie_huang_01_train_logistic_regression_model.mlsql
+++ b/streamingpro-it/src/test/resources/sql/simple/andie_huang_01_train_logistic_regression_model.mlsql
@@ -41,4 +41,7 @@ as model_result;
 select name,value from model_result where name="status" as result;
 
 -- make sure status of all models are success.
-!assert result ''':value=="success"''' "all model status should be success";
\ No newline at end of file
+!assert result ''':value=="success"''' "all model status should be success";
+predict data1 as LogisticRegression.`/tmp/model`;
+register LogisticRegression.`/tmp/model` as lr_predict;
+select lr_predict(features) as predict_label, label from data1 as output;
diff --git a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql
index 9894bf9ca..a1282bc57 100644
--- a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql
+++ b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql
@@ -19,7 +19,7 @@ as data1;
 -- select * from data1 as output1;
 --
 -- use RandomForest
-train data1 as LogisticRegression.`/tmp/model` where
+train data1 as LinearRegression.`/tmp/model` where
 --
 -- once set true,every time you run this script, MLSQL will generate new directory for you model
 keepVersion="true"
@@ -40,4 +40,7 @@ as model_result;
 select name,value from model_result where name="status" as result;
 
 -- make sure status of all models are success.
-!assert result ''':value=="success"''' "all model status should be success";
\ No newline at end of file
+!assert result ''':value=="success"''' "all model status should be success";
+predict data1 as LinearRegression.`/tmp/model`;
+register LinearRegression.`/tmp/model` as lr_predict;
+select lr_predict(features) as predict_label, label from data1 as output;
\ No newline at end of file
diff --git a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLinearRegressionExt.scala b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLinearRegressionExt.scala
index 2bb26437d..37f775f2b 100644
--- a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLinearRegressionExt.scala
+++ b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLinearRegressionExt.scala
@@ -144,13 +144,13 @@ class SQLLinearRegressionExt(override val uid: String) extends SQLAlg with Mllib
 
 
   override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
-    val model = LinearRegressionModel.load(path)
-    model
+    val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
+    val model = LinearRegressionModel.load(bestModelPath(0))
+    ArrayBuffer(model)
   }
 
   override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
-    val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[LinearRegressionModel])
-
+    val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[ArrayBuffer[LinearRegressionModel]].head)
     val f = (vec: Vector) => {
       val result = model.value.getClass.getMethod("predict", classOf[Vector]).invoke(model.value, vec)
       result
@@ -158,6 +158,11 @@ class SQLLinearRegressionExt(override val uid: String) extends SQLAlg with Mllib
     MLSQLUtils.createUserDefinedFunction(f, DoubleType, Some(Seq(VectorType)))
   }
 
+  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
+    val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[LinearRegressionModel]].head
+    model.transform(df)
+  }
+
   override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
     val vtable = MLSQLTable(
       Option(DB_DEFAULT.MLSQL_SYSTEM.toString),
diff --git a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLogisticRegression.scala b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLogisticRegression.scala
index aaefed60e..6d2ebd3de 100644
--- a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLogisticRegression.scala
+++ b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLLogisticRegression.scala
@@ -135,7 +135,6 @@ class SQLLogisticRegression(override val uid: String) extends SQLAlg with MllibF
       Seq("uid", model.uid),
       Seq("numFeatures", model.numFeatures.toString),
       Seq("numClasses", model.numClasses.toString),
-      Seq("binarySummary", model.binarySummary.toString()),
       Seq("intercept", model.intercept.toString()),
       Seq("coefficients", model.coefficients.toString())
     ) ++ modelParams
@@ -148,18 +147,18 @@ class SQLLogisticRegression(override val uid: String) extends SQLAlg with MllibF
   override def modelType: ModelType = AlgType
 
   override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
-    val model = LogisticRegressionModel.load(path)
-    model
+    val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
+    val model = LogisticRegressionModel.load(bestModelPath(0))
+    ArrayBuffer(model)
   }
 
   override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
-    val model = sparkSession.sparkContext.broadcast(_model.asInstanceOf[LogisticRegressionModel])
+    predict_classification(sparkSession, _model, name)
+  }
 
-    val f = (vec: Vector) => {
-      val result = model.value.getClass.getMethod("predict", classOf[Vector]).invoke(model.value, vec)
-      result
-    }
-    MLSQLUtils.createUserDefinedFunction(f, DoubleType, Some(Seq(VectorType)))
+  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
+    val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[LogisticRegressionModel]].head
+    model.transform(df)
   }
 
   override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {