Skip to content

Commit

Permalink
Merge pull request byzer-org#897 from allwefantasy/ISSUE-896
Browse files Browse the repository at this point in the history
PythonAlg supports python in MLSQL
  • Loading branch information
allwefantasy authored Jan 13, 2019
2 parents cf1144a + fe2fe89 commit b63b448
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ class SQLPythonAlg(override val uid: String) extends SQLAlg with Functions with

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
pythonCheckRequirements(df)
// Copy scripts/entryPoint/condaFile from the raw params into typed Params.
autoConfigureAutoCreateProjectParams(params)
// When inline scripts are supplied, materialize them as an MLproject under
// `path` and redirect the python runner at the generated project. The
// generated project reads its data locally, hence enableDataLocal=true.
val newParams = if (get(scripts).isDefined) {
val autoCreateMLproject = new AutoCreateMLproject($(scripts), $(condaFile), $(entryPoint))
val projectPath = autoCreateMLproject.saveProject(df.sparkSession, path)
params ++ Map(
"enableDataLocal" -> "true",
"pythonScriptPath" -> projectPath,
"pythonDescPath" -> projectPath)
} else params
new PythonTrain().train(df, path, newParams)
}

override def load(sparkSession: SparkSession, _path: String, params: Map[String, String]): Any = {
Expand Down Expand Up @@ -155,4 +165,5 @@ object SQLPythonAlg extends Logging with WowLog {
else None

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@ import org.apache.spark.ml.param.Param
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.common.HDFSOperator
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.param.BaseParams
import streaming.dsl.mmlib.algs.python.PythonTrain
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.dsl.mmlib.algs.python.{AutoCreateMLproject, PythonTrain}


/**
* 2019-01-08 WilliamZhu([email protected])
*/
class SQLPythonParallelExt(override val uid: String) extends SQLAlg with Functions with BaseParams {
class SQLPythonParallelExt(override val uid: String) extends SQLAlg with Functions with WowParams {
def this() = this(BaseParams.randomUID())

private def validateParams(params: Map[String, String]) = {
Expand Down Expand Up @@ -43,51 +42,23 @@ class SQLPythonParallelExt(override val uid: String) extends SQLAlg with Functio
}
}

def projectName = "mlsql-python-project"

/*
We will automatically create project for user according the configuration
*/
private def saveProject(sparkSession: SparkSession, path: String) = {
val projectPath = path + s"/${projectName}"
$(scripts).split(",").foreach { script =>
val content = sparkSession.table(script).head().getString(0)
HDFSOperator.saveFile(projectPath, script + ".py", Seq(("", content)).iterator)
}
HDFSOperator.saveFile(projectPath, "MLproject", Seq(("", MLprojectTemplate)).iterator)
val condaContent = sparkSession.table($(condaFile)).head().getString(0)
HDFSOperator.saveFile(projectPath, "conda.yaml", Seq(("", condaContent)).iterator)
projectPath
}


private def MLprojectTemplate = {
s"""
|name: mlsql-python
|
|conda_env: conda.yaml
|
|entry_points:
| main:
| train:
| command: "python ${$(entryPoint)}.py"
| batchPredict:
| command: "python ${$(entryPoint)}.py"
""".stripMargin
}

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
pythonCheckRequirements(df)
val mlsqlContext = ScriptSQLExec.contextGetOrForTest()

validateParams(params)

val projectPath = saveProject(df.sparkSession, path)
val autoCreateMLproject = new AutoCreateMLproject($(scripts), $(condaFile), $(entryPoint))

val projectPath = autoCreateMLproject.saveProject(df.sparkSession, path)

var newParams = params

newParams += ("enableDataLocal" -> ($(feedMode) == "file").toString)
newParams += ("pythonScriptPath" -> projectPath)
newParams += ("pythonDescPath" -> projectPath)

val pt = new PythonTrain()
pt.train_per_partition(df, path, newParams)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package streaming.dsl.mmlib.algs.param

import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.sql.mlsql.session.MLSQLException

/**
* Created by allwefantasy on 28/9/2018.
Expand Down Expand Up @@ -54,4 +55,39 @@ trait SQLPythonAlgParams extends BaseParams {
final val fitParam: Param[String] = new Param[String](this, "fitParam",
"fitParam is dynamic params. e.g. fitParam.0.moduleName,fitParam.1.moduleName`")

// Comma-separated table names; each table's first row / first column holds
// the python source saved as `<name>.py` by AutoCreateMLproject.saveProject.
final val scripts: Param[String] = new Param[String](this, "scripts",
"comma-separated table names whose first row holds python source; each is saved as <name>.py")

// NOTE(review): not referenced by the visible auto-create logic — confirm usage before documenting.
final val projectPath: Param[String] = new Param(this, "projectPath",
"")

// Script name (without the .py suffix) used as the generated MLproject's
// train/batchPredict command: `python <entryPoint>.py`.
final val entryPoint: Param[String] = new Param(this, "entryPoint",
"script name (without .py) invoked as the MLproject train/batchPredict command")

// Table whose first row holds the conda.yaml environment specification
// written into the generated project.
final val condaFile: Param[String] = new Param(this, "condaFile",
"table whose first row holds the conda.yaml environment spec for the generated project")

/**
 * Copies the auto-create-project settings (scripts / entryPoint / condaFile)
 * from the raw key-value params into this algorithm's typed Params.
 * Keys that are absent are simply left unset; no validation happens here.
 */
def autoConfigureAutoCreateProjectParams(params: Map[String, String]): Unit = {
// `foreach` replaces the original map{set(..)}.getOrElse{} pattern, which
// discarded its value and widened the inferred return type to Any.
params.get(scripts.name).foreach(set(scripts, _))
params.get(entryPoint.name).foreach(set(entryPoint, _))
params.get(condaFile.name).foreach(set(condaFile, _))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,40 @@ class CondaEnvManager(options: Map[String, String]) extends Logging with WowLog
}
}

// Generates an MLflow-style "MLproject" directory from MLSQL script tables:
// each table named in `scripts` is saved as <name>.py, `condaFile` becomes
// conda.yaml, and an MLproject manifest wires `entryPoint` as the command.
class AutoCreateMLproject(scripts: String, condaFile: String, entryPoint: String) {

// Fixed sub-directory name the project is written into under the target path.
def projectName = "mlsql-python-project"

/*
We will automatically create project for user according the configuration
*/
// Writes every project file under `path`/projectName and returns that directory.
// Each table named in `scripts`/`condaFile` is expected to hold the full file
// content in column 0 of its first row (head() throws on an empty table).
def saveProject(sparkSession: SparkSession, path: String) = {
val projectPath = path + s"/${projectName}"
scripts.split(",").foreach { script =>
val content = sparkSession.table(script).head().getString(0)
HDFSOperator.saveFile(projectPath, script + ".py", Seq(("", content)).iterator)
}
HDFSOperator.saveFile(projectPath, "MLproject", Seq(("", MLprojectTemplate)).iterator)
val condaContent = sparkSession.table(condaFile).head().getString(0)
HDFSOperator.saveFile(projectPath, "conda.yaml", Seq(("", condaContent)).iterator)
projectPath
}

// MLproject manifest template; both train and batchPredict entry points
// invoke the configured script via `python <entryPoint>.py`.
private def MLprojectTemplate = {
s"""
|name: mlsql-python
|
|conda_env: conda.yaml
|
|entry_points:
| main:
| train:
| command: "python ${entryPoint}.py"
| batchPredict:
| command: "python ${entryPoint}.py"
""".stripMargin
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,23 @@ class PythonMLSpec2 extends BasicSparkOperation with SpecFunctions with BasicMLS
}
}

// End-to-end check of the PythonAlg auto-create-project path: the inline
// python script (ScriptCode._j2) writes "jack" into the model output, which
// must round-trip through the generated MLproject.
"SQLPythonAlg auto create project" should "work fine" in {
withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>
// Run the MLSQL script end-to-end.
implicit val spark = runtime.sparkSession
mockServer
val sq = createSSEL(spark, "")
// Train with scripts/entryPoint/condaFile params (see ScriptCode._j2).
ScriptSQLExec.parse(ScriptCode._j2, sq)

// Explicit assertion instead of the original unused `var table = ....get`,
// which could only fail by throwing on None.
assert(sq.getLastSelectTable().isDefined)
val res = spark.sql(s"select * from output").collect()
assert(res.length == 1)
assert(res.head.getAs[String](0).contains("jack"))

}
}

"SQLPythonParallelExt " should "work fine" in {
withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>
//执行sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,65 @@ object ScriptCode {
|
""".stripMargin


// End-to-end MLSQL script exercising PythonAlg's auto-create-project path:
// registers an inline python script and a conda spec as script tables, trains
// with scripts/entryPoint/condaFile params (triggering MLproject generation),
// then loads the model output the python process wrote ("jack" in result.txt).
val _j2 =
"""
|set python1='''
|import os
|import warnings
|import sys
|
|import mlsql
|
|if __name__ == "__main__":
| warnings.filterwarnings("ignore")
|
| tempDataLocalPath = mlsql.internal_system_param["tempDataLocalPath"]
|
| isp = mlsql.params()["internalSystemParam"]
| tempModelLocalPath = isp["tempModelLocalPath"]
| if not os.path.exists(tempModelLocalPath):
| os.makedirs(tempModelLocalPath)
| with open(tempModelLocalPath + "/result.txt", "w") as f:
| f.write("jack")
|''';
|
|set dependencies='''
|name: tutorial
|dependencies:
| - python=3.6
| - pip:
| - numpy==1.14.3
| - kafka-python==1.4.3
| - pyspark==2.3.2
| - pandas==0.22.0
| - scikit-learn==0.19.1
| - scipy==1.1.0
|''';
|
|set modelPath="/tmp/jack2";
|
|set data='''
|{"jack":1}
|''';
|
|load jsonStr.`data` as testData;
|load script.`python1` as python1;
|load script.`dependencies` as dependencies;
|
|-- train sklearn model
|run testData as PythonAlg.`${modelPath}`
|where scripts="python1"
|and entryPoint="python1"
|and condaFile="dependencies"
|and fitParam.0.abc="test"
|;
|
|load text.`${modelPath}/model/0` as output; -- 查看目标文件
|
|
""".stripMargin

val train =
"""
|load csv.`${projectPath}/wine-quality.csv`
Expand All @@ -93,6 +152,7 @@ object ScriptCode {
| and keepVersion="true"
| and enableDataLocal="true"
| and dataLocalFormat="csv"
| and fitParam.0.abc="example"
| ${kv}
|-- and systemParam.envs='''{"MLFLOW_CONDA_HOME":"/anaconda3"}'''
| ;
Expand Down

0 comments on commit b63b448

Please sign in to comment.