Commit
support pythonenv distribute
allwefantasy committed Jan 16, 2019
1 parent 8d499d8 commit ebcd355
Showing 7 changed files with 271 additions and 28 deletions.
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package streaming.dsl.mmlib.algs.python

import java.io.File
import java.nio.charset.Charset
import java.util.UUID

import org.apache.commons.io.FileUtils
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse
import streaming.common.HDFSOperator
import streaming.common.shell.ShellCommand
import streaming.log.Logging


object BasicCondaEnvManager {
val condaHomeKey = "MLFLOW_CONDA_HOME"
}

class BasicCondaEnvManager(options: Map[String, String]) extends Logging {

def validateCondaExec = {
val condaPath = getCondaBinExecutable("conda")
try {
ShellCommand.execCmd(s"${condaPath} --help")
} catch {
case e: Exception =>
logError(s"Could not find Conda executable at ${condaPath}.", e)
throw new RuntimeException(
s"""
|Could not find Conda executable at ${condaPath}.
|Ensure Conda is installed as per the instructions
|at https://conda.io/docs/user-guide/install/index.html. You can
|also configure MLSQL to look for a specific Conda executable
|by setting the MLFLOW_CONDA_HOME environment variable to the path of the Conda executable.
""".stripMargin)
}
condaPath
}

def getOrCreateCondaEnv(condaEnvPath: Option[String]) = {

val condaPath = validateCondaExec
val stdout = ShellCommand.execCmd(s"${condaPath} env list --json")
implicit val formats = DefaultFormats
val envNames = (parse(stdout) \ "envs").extract[List[String]].map(_.split("/").last).toSet
val projectEnvName = getCondaEnvName(condaEnvPath)
if (!envNames.contains(projectEnvName)) {
logInfo(s"=== Creating conda environment $projectEnvName ===")
condaEnvPath match {
case Some(path) =>
val tempFile = "/tmp/" + UUID.randomUUID() + ".yaml"
try {
FileUtils.write(new File(tempFile), getCondaYamlContent(condaEnvPath), Charset.forName("utf-8"))
ShellCommand.execCmd(s"${condaPath} env create -n $projectEnvName --file $tempFile")
} finally {
FileUtils.deleteQuietly(new File(tempFile))
}

case None =>
ShellCommand.execCmd(s"${condaPath} create -n $projectEnvName python")
}
}

projectEnvName
}

def removeEnv(condaEnvPath: Option[String]) = {
val condaPath = validateCondaExec
val projectEnvName = getCondaEnvName(condaEnvPath)
ShellCommand.execCmd(s"${condaPath} env remove --name ${projectEnvName}")
}

def sha1(str: String) = {
val md = java.security.MessageDigest.getInstance("SHA-1")
val ha = md.digest(str.getBytes).map("%02x".format(_)).mkString
ha
}

def getCondaEnvName(condaEnvPath: Option[String]) = {
val condaEnvContents = condaEnvPath match {
case Some(cep) =>
// We should read from the local filesystem, but for now we read from HDFS.
// scala.io.Source.fromFile(new File(cep)).getLines().mkString("\n")
HDFSOperator.readFile(cep)
case None => ""
}
s"mlflow-${sha1(condaEnvContents)}"
}

def getCondaYamlContent(condaEnvPath: Option[String]) = {
val condaEnvContents = condaEnvPath match {
case Some(cep) =>
// We should read from the local filesystem, but for now we read from HDFS.
// scala.io.Source.fromFile(new File(cep)).getLines().mkString("\n")
HDFSOperator.readFile(cep)
case None => ""
}
condaEnvContents
}

def getCondaBinExecutable(executableName: String) = {
val condaHome = options.get(BasicCondaEnvManager.condaHomeKey) match {
case Some(home) => home
case None => System.getenv(BasicCondaEnvManager.condaHomeKey)
}
if (condaHome != null) {
s"${condaHome}/bin/${executableName}"
} else executableName
}
}
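
A minimal sketch of how the new manager is expected to be driven (assumptions: an MLSQL runtime where HDFSOperator can read the yaml; the conda home and the conda.yaml path below are hypothetical):

import streaming.dsl.mmlib.algs.python.BasicCondaEnvManager

object CondaEnvManagerExample {
  def main(args: Array[String]): Unit = {
    // Point the manager at a conda installation; if the option is absent it falls
    // back to the MLFLOW_CONDA_HOME environment variable, then to plain "conda".
    val manager = new BasicCondaEnvManager(Map(BasicCondaEnvManager.condaHomeKey -> "/opt/miniconda3"))

    // The env name is "mlflow-" + sha1(conda.yaml content), so the same yaml
    // always resolves to the same environment and repeated creation is a no-op.
    val envName = manager.getOrCreateCondaEnv(Option("/user/mlsql/envs/conda.yaml"))
    println(s"conda env ready: $envName")

    // Remove it again once the python job is done.
    manager.removeEnv(Option("/user/mlsql/envs/conda.yaml"))
  }
}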




@@ -0,0 +1,78 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.ps.cluster.Message
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.common.HDFSOperator
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}

/**
* 2019-01-16 WilliamZhu([email protected])
*/
class SQLPythonEnvExt(override val uid: String) extends SQLAlg with WowParams {

def this() = this(BaseParams.randomUID())

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val spark = df.sparkSession

params.get(command.name).map { s =>
set(command, s)
s
}.getOrElse {
throw new MLSQLException(s"${command.name} is required")
}

params.get(condaYamlFilePath.name).map { s =>
set(condaYamlFilePath, s)
}.getOrElse {
params.get(condaFile.name).map { s =>
val condaContent = spark.table(s).head().getString(0)
val baseFile = path + "/__mlsql_temp_dir__/conda"
val fileName = "conda.yaml"
HDFSOperator.saveFile(baseFile, fileName, Seq(("", condaContent)).iterator)
set(condaYamlFilePath, baseFile + "/" + fileName)
}.getOrElse {
throw new MLSQLException(s"${condaFile.name} or ${condaYamlFilePath.name} is required")
}

}

val wowCommand = $(command) match {
case "create" => Message.AddEnvCommand
case "remove" => Message.RemoveEnvCommand
}

val remoteCommand = Message.CreateOrRemovePythonCondaEnv($(condaYamlFilePath), params, wowCommand)

val executorNum = if (spark.sparkContext.isLocal) {
val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].localSchedulerBackend
psDriverBackend.localEndpoint.askSync[Integer](remoteCommand)
} else {
val psDriverBackend = PlatformManager.getRuntime.asInstanceOf[SparkRuntime].psDriverBackend
psDriverBackend.psDriverRpcEndpointRef.askSync[Integer](remoteCommand)
}
import spark.implicits._
Seq[Seq[Int]](Seq(executorNum)).toDF("success_executor_num")
}


override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
train(df, path, params)
}

override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = throw new RuntimeException("register is not supported")

override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = throw new RuntimeException("register is not supported")

final val command: Param[String] = new Param[String](this, "command", "action to perform on the conda env, either create or remove", isValid = (s: String) => {
s == "create" || s == "remove"
})

final val condaYamlFilePath: Param[String] = new Param[String](this, "condaYamlFilePath", "HDFS path of the conda.yaml describing the python env")
final val condaFile: Param[String] = new Param[String](this, "condaFile", "name of a registered table whose first row/first column holds the conda.yaml content")
}
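
A minimal, hypothetical sketch of driving this ET from code (it assumes a running MLSQL SparkRuntime so the PS driver/local endpoints exist, and the conda.yaml path is a placeholder); instead of condaYamlFilePath one can also pass condaFile, the name of a registered table whose first row/column holds the conda.yaml text:

import org.apache.spark.sql.SparkSession
import streaming.dsl.mmlib.algs.SQLPythonEnvExt

object PythonEnvExtExample {
  def createEnv(spark: SparkSession): Unit = {
    // train() only uses the DataFrame to reach the SparkSession, so any frame works.
    val df = spark.emptyDataFrame
    val result = new SQLPythonEnvExt().train(df, "/tmp/pythonenv", Map(
      "command" -> "create", // or "remove" to tear the env down on every executor
      "condaYamlFilePath" -> "/user/mlsql/envs/conda.yaml"
    ))
    // Single row whose success_executor_num column reports how many executors
    // acknowledged the command.
    result.show()
  }
}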
@@ -80,7 +80,7 @@ class MLProject(val projectDir: String, project: Settings, options: Map[String,
}

def entryPointCommandWithConda(commandType: String) = {
val condaEnvManager = new CondaEnvManager(options)
val condaEnvManager = new BasicCondaEnvManager(options)
val condaEnvName = condaEnvManager.getOrCreateCondaEnv(Option(projectDir + s"/${MLProject.DEFAULT_CONDA_ENV_NAME}"))
val entryPointCommandWithConda = commandWithConda(
condaEnvManager.getCondaBinExecutable("activate"),
@@ -91,7 +91,7 @@
}

def condaEnvCommand = {
val condaEnvManager = new CondaEnvManager(options)
val condaEnvManager = new BasicCondaEnvManager(options)
val condaEnvName = condaEnvManager.getOrCreateCondaEnv(Option(projectDir + s"/${MLProject.DEFAULT_CONDA_ENV_NAME}"))
val command = s"source ${condaEnvManager.getCondaBinExecutable("activate")} ${condaEnvName}"
logInfo(format(s"=== generate command '${command}' for ${projectDir} === "))
@@ -198,9 +198,9 @@ class CondaEnvManager(options: Map[String, String]) extends Logging with WowLog
}

def getCondaBinExecutable(executableName: String) = {
val condaHome = options.get(CondaEnvManager.condaHomeKey) match {
val condaHome = options.get(BasicCondaEnvManager.condaHomeKey) match {
case Some(home) => home
case None => System.getenv(CondaEnvManager.condaHomeKey)
case None => System.getenv(BasicCondaEnvManager.condaHomeKey)
}
if (condaHome != null) {
s"${condaHome}/bin/${executableName}"
@@ -19,7 +19,6 @@
package org.apache.spark.ps.cluster

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisteredExecutor

/**
* Created by allwefantasy on 30/1/2018.
@@ -39,10 +38,17 @@ object Message {
cores: Int,
logUrls: Map[String, String])

case class TensorFlowModelClean(modelPath: String)

case class CopyModelToLocal(modelPath: String, destPath: String)

case class CreateOrRemovePythonCondaEnv(condaYamlFile: String, options: Map[String, String], command: EnvCommand)

sealed abstract class EnvCommand

case object AddEnvCommand extends EnvCommand

case object RemoveEnvCommand extends EnvCommand

case object Ping

case class Pong(executorId: String)
@@ -20,11 +20,12 @@ package org.apache.spark.ps.cluster

import java.util.Locale

import org.apache.spark.internal.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.ThreadUtils
import streaming.common.HDFSOperator
import streaming.dsl.mmlib.algs.python.BasicCondaEnvManager

import scala.util.{Failure, Success}

@@ -70,6 +71,16 @@ class PSExecutorBackend(env: SparkEnv, override val rpcEnv: RpcEnv, psDriverUrl:
HDFSOperator.copyToLocalFile(destPath, modelPath, true)
context.reply(true)
}
case Message.CreateOrRemovePythonCondaEnv(condaYamlFile, options, command) => {
val condaEnvManager = new BasicCondaEnvManager(options)
command match {
case Message.AddEnvCommand =>
condaEnvManager.getOrCreateCondaEnv(Option(condaYamlFile))
case Message.RemoveEnvCommand =>
condaEnvManager.removeEnv(Option(condaYamlFile))
}
context.reply(true)
}
case Message.Ping =>
logInfo(s"received message ${Message.Ping}")
context.reply(Message.Pong(psExecutorId))
@@ -21,14 +21,15 @@ package org.apache.spark.ps.local
import java.io.File
import java.net.URL

import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.ps.cluster.Message
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.StopExecutor
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import streaming.common.HDFSOperator
import streaming.dsl.mmlib.algs.python.BasicCondaEnvManager


private case class TensorFlowModelClean(modelPath: String)
@@ -61,6 +62,16 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
HDFSOperator.copyToLocalFile(destPath, modelPath, true)
context.reply(true)
}
case Message.CreateOrRemovePythonCondaEnv(condaYamlFile, options, command) => {
val condaEnvManager = new BasicCondaEnvManager(options)
command match {
case Message.AddEnvCommand =>
condaEnvManager.getOrCreateCondaEnv(Option(condaYamlFile))
case Message.RemoveEnvCommand =>
condaEnvManager.removeEnv(Option(condaYamlFile))
}
context.reply(1)
}
case Message.Ping =>
logInfo(s"received message ${Message.Ping}")
val response = Message.Pong("localhost")
@@ -107,10 +118,6 @@ class LocalPSSchedulerBackend(sparkContext: SparkContext)
stop(SparkAppHandle.State.FINISHED)
}

def cleanTensorFlowModel(modelPath: String) = {
localEndpoint.askSync[Boolean](TensorFlowModelClean(modelPath))
}

private def stop(finalState: SparkAppHandle.State): Unit = {
localEndpoint.ask(StopExecutor)
try {
