forked from byzer-org/byzer-lang
Commit ebcd355
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 parent 8d499d8

Showing 7 changed files with 271 additions and 28 deletions.
132 changes: 132 additions & 0 deletions
...amingpro-commons/src/main/java/streaming/dsl/mmlib/algs/python/BasicCondaEnvManager.scala
@@ -0,0 +1,132 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package streaming.dsl.mmlib.algs.python

import java.io.File
import java.nio.charset.Charset
import java.util.UUID

import org.apache.commons.io.FileUtils
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse
import streaming.common.HDFSOperator
import streaming.common.shell.ShellCommand
import streaming.log.Logging

object BasicCondaEnvManager {
  val condaHomeKey = "MLFLOW_CONDA_HOME"
}

class BasicCondaEnvManager(options: Map[String, String]) extends Logging {

  def validateCondaExec = {
    val condaPath = getCondaBinExecutable("conda")
    try {
      ShellCommand.execCmd(s"${condaPath} --help")
    } catch {
      case e: Exception =>
        logError(s"Could not find Conda executable at ${condaPath}.", e)
        throw new RuntimeException(
          s"""
             |Could not find Conda executable at ${condaPath}.
             |Ensure Conda is installed as per the instructions
             |at https://conda.io/docs/user-guide/install/index.html. You can
             |also configure MLSQL to look for a specific Conda executable
             |by setting the MLFLOW_CONDA_HOME environment variable to the path of the Conda executable.
          """.stripMargin)
    }
    condaPath
  }

  def getOrCreateCondaEnv(condaEnvPath: Option[String]) = {
    val condaPath = validateCondaExec
    val stdout = ShellCommand.execCmd(s"${condaPath} env list --json")
    implicit val formats = DefaultFormats
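    // `conda env list --json` prints e.g. {"envs": ["/opt/conda", "/opt/conda/envs/mlflow-xxx", ...]};
    // keep only the basename of each path so we can match by environment name.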
    val envNames = (parse(stdout) \ "envs").extract[List[String]].map(_.split("/").last).toSet
    val projectEnvName = getCondaEnvName(condaEnvPath)
    if (!envNames.contains(projectEnvName)) {
      logInfo(s"=== Creating conda environment $projectEnvName ===")
      condaEnvPath match {
        case Some(path) =>
          val tempFile = "/tmp/" + UUID.randomUUID() + ".yaml"
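          // conda can only read the environment spec from a file, so materialize
          // the yaml content into a temp file and clean it up afterwards.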
          try {
            FileUtils.write(new File(tempFile), getCondaYamlContent(condaEnvPath), Charset.forName("utf-8"))
            ShellCommand.execCmd(s"${condaPath} env create -n $projectEnvName --file $tempFile")
          } finally {
            FileUtils.deleteQuietly(new File(tempFile))
          }

        case None =>
          ShellCommand.execCmd(s"${condaPath} create -n $projectEnvName python")
      }
    }

    projectEnvName
  }

  def removeEnv(condaEnvPath: Option[String]) = {
    val condaPath = validateCondaExec
    val projectEnvName = getCondaEnvName(condaEnvPath)
    ShellCommand.execCmd(s"${condaPath} env remove --name ${projectEnvName}")
  }

  def sha1(str: String) = {
    val md = java.security.MessageDigest.getInstance("SHA-1")
    md.digest(str.getBytes).map("%02x".format(_)).mkString
  }

  def getCondaEnvName(condaEnvPath: Option[String]) = {
    val condaEnvContents = condaEnvPath match {
      case Some(cep) =>
        // We should read from the local filesystem, but for now we read from HDFS.
        // scala.io.Source.fromFile(new File(cep)).getLines().mkString("\n")
        HDFSOperator.readFile(cep)
      case None => ""
    }
    // Name the env after the SHA-1 of its yaml content, so identical specs map to the same env.
    s"mlflow-${sha1(condaEnvContents)}"
  }

  def getCondaYamlContent(condaEnvPath: Option[String]) = {
    val condaEnvContents = condaEnvPath match {
      case Some(cep) =>
        // We should read from the local filesystem, but for now we read from HDFS.
        // scala.io.Source.fromFile(new File(cep)).getLines().mkString("\n")
        HDFSOperator.readFile(cep)
      case None => ""
    }
    condaEnvContents
  }

  def getCondaBinExecutable(executableName: String) = {
    val condaHome = options.get(BasicCondaEnvManager.condaHomeKey) match {
      case Some(home) => home
      case None => System.getenv(BasicCondaEnvManager.condaHomeKey)
    }
    if (condaHome != null) {
      s"${condaHome}/bin/${executableName}"
    } else executableName
  }
}
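A quick sketch of how this manager is intended to be driven end to end (the conda home and yaml path below are hypothetical, not part of this commit):

// Hypothetical wiring: point the manager at a conda install and an env spec on HDFS.
val manager = new BasicCondaEnvManager(Map(BasicCondaEnvManager.condaHomeKey -> "/opt/conda"))
// Creates the env mlflow-<sha1-of-yaml-content> if it does not exist yet.
val envName = manager.getOrCreateCondaEnv(Some("/tmp/envs/conda.yaml"))
// ... run python workloads inside envName ...
// Tear the environment down when it is no longer needed.
manager.removeEnv(Some("/tmp/envs/conda.yaml"))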
78 changes: 78 additions & 0 deletions
streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLPythonEnvExt.scala
@@ -0,0 +1,78 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.ps.cluster.Message
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.common.HDFSOperator
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}

/**
  * 2019-01-16 WilliamZhu([email protected])
  */
class SQLPythonEnvExt(override val uid: String) extends SQLAlg with WowParams {

  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val spark = df.sparkSession

    params.get(command.name).map { s =>
      set(command, s)
      s
    }.getOrElse {
      throw new MLSQLException(s"${command.name} is required")
    }

    params.get(condaYamlFilePath.name).map { s =>
      set(condaYamlFilePath, s)
    }.getOrElse {
      params.get(condaFile.name).map { s =>
        // The conda spec arrives as a one-column table; persist it to HDFS as conda.yaml.
        val condaContent = spark.table(s).head().getString(0)
        val baseFile = path + "/__mlsql_temp_dir__/conda"
        val fileName = "conda.yaml"
        HDFSOperator.saveFile(baseFile, fileName, Seq(("", condaContent)).iterator)
        set(condaYamlFilePath, baseFile + "/" + fileName)
      }.getOrElse {
        throw new MLSQLException(s"${condaFile.name} or ${condaYamlFilePath.name} is required")
      }
    }

    val wowCommand = $(command) match {
      case "create" => Message.AddEnvCommand
      case "remove" => Message.RemoveEnvCommand
    }

    val remoteCommand = Message.CreateOrRemovePythonCondaEnv($(condaYamlFilePath), params, wowCommand)

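    // Local mode talks to the local scheduler endpoint directly; cluster mode goes
    // through the PS driver endpoint so each executor can set up (or remove) the env.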
    val executorNum = if (spark.sparkContext.isLocal) {
      val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].localSchedulerBackend
      psDriverBackend.localEndpoint.askSync[Integer](remoteCommand)
    } else {
      val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].psDriverBackend
      psDriverBackend.psDriverRpcEndpointRef.askSync[Integer](remoteCommand)
    }
    import spark.implicits._
    Seq[Seq[Int]](Seq(executorNum)).toDF("success_executor_num")
  }

  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    train(df, path, params)
  }

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = throw new RuntimeException("register is not supported")

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = throw new RuntimeException("register is not supported")

  final val command: Param[String] = new Param[String](this, "command", "", isValid = (s: String) => {
    s == "create" || s == "remove"
  })

  final val condaYamlFilePath: Param[String] = new Param[String](this, "condaYamlFilePath", "")
  final val condaFile: Param[String] = new Param[String](this, "condaFile", "")
}