Commit cf994d6
remove useless function in ET Ray
add run-test.sh to dev directory

add #!/usr/bin/env bash to dev/run-test.sh; remove redundant import in ray
allwefantasy committed Aug 20, 2021
1 parent 9ec165e commit cf994d6
Showing 2 changed files with 9 additions and 92 deletions.
6 changes: 6 additions & 0 deletions dev/run-test.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+## example
+## ./dev/run-test.sh
+
+mvn clean install -DskipTests
+mvn test -pl streamingpro-it
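
Note: the script first runs a full mvn clean install with tests skipped, so every module's artifact is published to the local repository; the second command then runs tests for the streamingpro-it integration-test module only (-pl limits the Maven reactor to that project, which resolves its siblings from the artifacts installed in step one).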
95 changes: 3 additions & 92 deletions streamingpro-mlsql/src/main/java/tech/mlsql/ets/Ray.scala
@@ -1,16 +1,14 @@
 package tech.mlsql.ets
 
-import java.net.{InetAddress, ServerSocket}
 import java.util
 import java.util.concurrent.atomic.AtomicReference
 
 import org.apache.spark.ml.param.Param
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.mlsql.session.MLSQLException
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SparkSession, SparkUtils}
-import org.apache.spark.{MLSQLSparkUtils, SparkEnv, SparkInstanceService, TaskContext, WowRowEncoder}
-import streaming.core.datasource.util.MLSQLJobCollect
+import org.apache.spark.{TaskContext, WowRowEncoder}
 import streaming.dsl.ScriptSQLExec
 import streaming.dsl.mmlib._
 import streaming.dsl.mmlib.algs.Functions
@@ -19,21 +17,15 @@ import tech.mlsql.arrow.python.PythonWorkerFactory
 import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
 import tech.mlsql.arrow.python.ispark.SparkContextImp
 import tech.mlsql.arrow.python.runner._
-import tech.mlsql.common.utils.base.TryTool
-import tech.mlsql.common.utils.distribute.socket.server.{ReportHostAndPort, SocketServerInExecutor}
 import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros
-import tech.mlsql.common.utils.net.NetTool
 import tech.mlsql.common.utils.network.NetUtils
 import tech.mlsql.common.utils.serder.json.JSONTool
-import tech.mlsql.ets.ray.{CollectServerInDriver, DataServer}
-import tech.mlsql.log.WriteLog
+import tech.mlsql.ets.ray.DataServer
 import tech.mlsql.schema.parser.SparkSimpleSchemaParser
 import tech.mlsql.session.SetSession
 import tech.mlsql.tool.MasterSlaveInSpark
 import tech.mlsql.version.VersionCompatibility
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
  * 24/12/2019 WilliamZhu([email protected])
  */
@@ -68,87 +60,6 @@ class Ray(override val uid: String) extends SQLAlg with VersionCompatibility wit
     newdf
   }
 
-
-  private def computeSplits(session: SparkSession, df: DataFrame) = {
-    var targetLen = df.rdd.partitions.length
-    var sort = false
-    val context = ScriptSQLExec.context()
-
-    MLSQLSparkUtils.isFileTypeTable(df) match {
-      case true =>
-        targetLen = 1
-        sort = true
-      case false =>
-        TryTool.tryOrElse {
-          val resource = new SparkInstanceService(session).resources
-          val jobInfo = new MLSQLJobCollect(session, context.owner)
-          val leftResource = resource.totalCores - jobInfo.resourceSummary(null).activeTasks
-          logInfo(s"RayMode: Resource:[${leftResource}(${resource.totalCores}-${jobInfo.resourceSummary(null).activeTasks})] TargetLen:[${targetLen}]")
-          if (leftResource / 2 <= targetLen) {
-            targetLen = Math.max(Math.floor(leftResource / 2) - 1, 1).toInt
-          }
-        } {
-          logWarning(format("Warning: Fail to detect instance resource. Setup 4 data server for Python."))
-          if (targetLen > 4) targetLen = 4
-        }
-    }
-
-    (targetLen, sort)
-  }
-
-  private def buildDataSocketServers(session: SparkSession, df: DataFrame, tempdf: DataFrame, _owner: String): CollectServerInDriver[String] = {
-    val refs = new AtomicReference[ArrayBuffer[ReportHostAndPort]]()
-    refs.set(ArrayBuffer[ReportHostAndPort]())
-    val stopFlag = new AtomicReference[String]()
-    stopFlag.set("false")
-
-    val tempSocketServerInDriver = new CollectServerInDriver(refs, stopFlag)
-
-    val thread = new Thread("temp-data-server-in-spark") {
-      override def run(): Unit = {
-
-        val dataSchema = df.schema
-        val tempSocketServerHost = tempSocketServerInDriver._host
-        val tempSocketServerPort = tempSocketServerInDriver._port
-        val timezoneID = session.sessionState.conf.sessionLocalTimeZone
-        val owner = _owner
-        tempdf.rdd.mapPartitions { iter =>
-
-          val host: String = if (SparkEnv.get == null || MLSQLSparkUtils.blockManager == null || MLSQLSparkUtils.blockManager.blockManagerId == null) {
-            WriteLog.write(List("Ray: Cannot get MLSQLSparkUtils.rpcEnv().address, using NetTool.localHostName()").iterator,
-              Map("PY_EXECUTE_USER" -> owner))
-            NetTool.localHostName()
-          } else if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.mlsql.deploy.on.k8s", false)) {
-            InetAddress.getLocalHost.getHostAddress
-          }
-          else MLSQLSparkUtils.blockManager.blockManagerId.host
-
-          val socketRunner = new SparkSocketRunner("serveToStreamWithArrow", host, timezoneID)
-          val commonTaskContext = new SparkContextImp(TaskContext.get(), null)
-          val convert = WowRowEncoder.fromRow(dataSchema)
-          val newIter = iter.map { irow =>
-            convert(irow)
-          }
-          val Array(_server, _host, _port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 1000, commonTaskContext)
-
-          // send server info back
-          SocketServerInExecutor.reportHostAndPort(tempSocketServerHost,
-            tempSocketServerPort,
-            ReportHostAndPort(_host.toString, _port.toString.toInt))
-
-          while (_server != null && !_server.asInstanceOf[ServerSocket].isClosed) {
-            Thread.sleep(1 * 1000)
-          }
-          List[String]().iterator
-        }.count()
-        logInfo("Exit all data server")
-      }
-    }
-    thread.setDaemon(true)
-    thread.start()
-    tempSocketServerInDriver
-  }
-
   private def schemaFromStr(schemaStr: String) = {
     val targetSchema = schemaStr.trim match {
       case item if item.startsWith("{") =>
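
For context, the deleted computeSplits decided how many Arrow data servers to launch for the Python side: file-type tables get a single sorted split; otherwise it reserves roughly half of the idle cores (total cores minus active tasks), keeping at least one split, and falls back to a cap of 4 servers when resource detection fails. Below is a minimal, self-contained sketch of that sizing heuristic; freeCores and partitions are hypothetical stand-ins for the values the original code read from SparkInstanceService and MLSQLJobCollect.

// Hypothetical sketch of the sizing heuristic from the removed computeSplits.
object SplitHeuristic {
  // freeCores: total cluster cores minus currently active tasks.
  // partitions: df.rdd.partitions.length for the input DataFrame.
  def targetSplits(freeCores: Long, partitions: Int): Int = {
    // Reserve roughly half the free cores, and always keep at least one split.
    if (freeCores / 2 <= partitions) Math.max(Math.floor(freeCores / 2) - 1, 1).toInt
    else partitions
  }

  // Fallback used when resource detection throws: cap the data servers at 4.
  def fallback(partitions: Int): Int = Math.min(partitions, 4)

  def main(args: Array[String]): Unit = {
    println(targetSplits(freeCores = 16, partitions = 32)) // prints 7
    println(fallback(partitions = 10)) // prints 4
  }
}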
