Skip to content

Commit

Permalink
[PSCluster][fix] fix parameter server start failure when `streaming.ps.local.enable` is set
Browse files Browse the repository at this point in the history
  • Loading branch information
cfmcgrady committed Dec 18, 2018
1 parent 92a2fb3 commit b6709b6
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.apache.spark.sql.mlsql.session.{SessionIdentifier, SessionManager}
import org.apache.spark.sql.{SQLContext, SparkSession}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

/**
* Created by allwefantasy on 30/3/2017.
Expand Down Expand Up @@ -64,7 +65,8 @@ class SparkRuntime(_params: JMap[Any, Any]) extends StreamingRuntime with Platfo
conf.setAppName(MLSQLConf.MLSQL_NAME.readFrom(configReader))

// Returns true when Spark runs in local mode ("local" or "local[...]").
// The master is read straight from the SparkConf rather than the MLSQL
// config reader, so values injected via --conf spark.master are honored.
def isLocalMaster(conf: SparkConf): Boolean =
conf.get("spark.master", "") match {
  case "local" => true
  case other => other.startsWith("local[")
}

Expand Down Expand Up @@ -172,6 +174,7 @@ class SparkRuntime(_params: JMap[Any, Any]) extends StreamingRuntime with Platfo
if (MLSQLConf.MLSQL_DISABLE_SPARK_LOG.readFrom(configReader)) {
WowLoggerFilter.redirectSparkInfoLogs()
}
show(params.asScala.map(kv => (kv._1.toString, kv._2.toString)).toMap)
ss
}

Expand Down Expand Up @@ -242,6 +245,23 @@ class SparkRuntime(_params: JMap[Any, Any]) extends StreamingRuntime with Platfo

override def startHttpServer: Unit = {}

/**
 * Logs the startup configuration as an aligned ASCII table.
 *
 * @param conf key/value pairs gathered from the runtime parameters.
 * Fixes: guards against an empty map (the original called `.max` on an
 * empty iterable, which throws UnsupportedOperationException), replaces
 * the deprecated procedure syntax with an explicit `: Unit =`, and uses
 * `foreach` directly instead of building an intermediate collection
 * with `map` purely for its side effects.
 */
private def show(conf: Map[String, String]): Unit = {
  logInfo("mlsql server start with configuration!")
  if (conf.nonEmpty) {
    // Column widths are the longest key/value so every row lines up.
    val keyWidth = conf.keys.map(_.length).max
    val valueWidth = conf.values.map(_.length).max
    // 3 extra chars account for the two outer pipes and the separator.
    val border = "-" * (keyWidth + valueWidth + 3)
    logInfo(border)
    conf.foreach { case (key, value) =>
      logInfo(s"|${key.padTo(keyWidth, ' ')}|${value.padTo(valueWidth, ' ')}|")
    }
    logInfo(border)
  }
}

}

object SparkRuntime {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,17 @@ class RestController extends ApplicationController {

@At(path = Array("/debug/executor/ping"), types = Array(GET, POST))
def pingExecuotrs = {
  // Debug endpoint that pings the parameter-server executors.
  // In local mode the PS runs inside the driver process, so the ping goes
  // to the local scheduler backend's endpoint; otherwise it is routed
  // through the PS driver RPC endpoint.
  // (Method name typo "pingExecuotrs" is kept: it is part of the public
  // controller surface and renaming it could break callers.)
  runtime match {
    case sparkRuntime: SparkRuntime =>
      val endpoint = if (sparkRuntime.sparkSession.sparkContext.isLocal) {
        sparkRuntime.localSchedulerBackend.localEndpoint
      } else {
        sparkRuntime.psDriverBackend.psDriverRpcEndpointRef
      }
      endpoint.ask(Message.Ping)
    case _ =>
      // Fixed error-message grammar: "unsupport" -> "unsupported".
      throw new RuntimeException(s"unsupported runtime ${runtime.getClass} !")
  }
  render("{}")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
val localExecutorHostname = "localhost"

// Handles one-way (fire-and-forget) RPC messages delivered to this endpoint.
// Only Pong acknowledgements are logged; every other message is dropped.
override def receive: PartialFunction[Any, Unit] = {
case Message.Pong(id) =>
logInfo(s"received message ${Message.Pong} from executor ${id}!")
// NOTE(review): unmatched messages are silently ignored — presumably
// intentional for a debug endpoint, but worth confirming.
case _ =>

}
Expand All @@ -41,6 +43,11 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
HDFSOperator.copyToLocalFile(destPath, modelPath, true)
context.reply(true)
}
case Message.Ping =>
logInfo(s"received message ${Message.Ping}")
val response = Message.Pong("localhost")
self.send(response)
context.reply(response)
}
}

Expand Down Expand Up @@ -99,4 +106,4 @@ class LocalPSSchedulerBackend(sparkContext: SparkContext)

// Process-wide registry for the local PS scheduler backend.
object LocalExecutorBackend {
// Set when the local backend starts; None until then.
// NOTE(review): mutable global state — not obviously synchronized; confirm
// it is only written during single-threaded startup.
var executorBackend: Option[LocalPSSchedulerBackend] = None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
val localExecutorHostname = "localhost"

// Handles one-way (fire-and-forget) RPC messages delivered to this endpoint.
// Only Pong acknowledgements are logged; every other message is dropped.
override def receive: PartialFunction[Any, Unit] = {
case Message.Pong(id) =>
logInfo(s"received message ${Message.Pong} from executor ${id}!")
// NOTE(review): unmatched messages are silently ignored — presumably
// intentional for a debug endpoint, but worth confirming.
case _ =>

}
Expand All @@ -41,6 +43,11 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
HDFSOperator.copyToLocalFile(destPath, modelPath, true)
context.reply(true)
}
case Message.Ping =>
logInfo(s"received message ${Message.Ping}")
val response = Message.Pong("localhost")
self.send(response)
context.reply(response)
}
}

Expand Down Expand Up @@ -99,4 +106,4 @@ class LocalPSSchedulerBackend(sparkContext: SparkContext)

// Process-wide registry for the local PS scheduler backend.
object LocalExecutorBackend {
// Set when the local backend starts; None until then.
// NOTE(review): mutable global state — not obviously synchronized; confirm
// it is only written during single-threaded startup.
var executorBackend: Option[LocalPSSchedulerBackend] = None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
val localExecutorHostname = "localhost"

// Handles one-way (fire-and-forget) RPC messages delivered to this endpoint.
// Only Pong acknowledgements are logged; every other message is dropped.
override def receive: PartialFunction[Any, Unit] = {
case Message.Pong(id) =>
logInfo(s"received message ${Message.Pong} from executor ${id}!")
// NOTE(review): unmatched messages are silently ignored — presumably
// intentional for a debug endpoint, but worth confirming.
case _ =>

}
Expand All @@ -41,6 +43,11 @@ class LocalPSEndpoint(override val rpcEnv: RpcEnv,
HDFSOperator.copyToLocalFile(destPath, modelPath, true)
context.reply(true)
}
case Message.Ping =>
logInfo(s"received message ${Message.Ping}")
val response = Message.Pong("localhost")
self.send(response)
context.reply(response)
}
}

Expand Down Expand Up @@ -99,4 +106,4 @@ class LocalPSSchedulerBackend(sparkContext: SparkContext)

// Process-wide registry for the local PS scheduler backend.
object LocalExecutorBackend {
// Set when the local backend starts; None until then.
// NOTE(review): mutable global state — not obviously synchronized; confirm
// it is only written during single-threaded startup.
var executorBackend: Option[LocalPSSchedulerBackend] = None
}
}

0 comments on commit b6709b6

Please sign in to comment.