forked from byzer-org/byzer-lang
Commit
add run-test.sh to dev directory

add #!/usr/bin/env bash to dev/run-test.sh; remove redundant import in Ray
1 parent 9ec165e · commit cf994d6
Showing 2 changed files with 9 additions and 92 deletions.
dev/run-test.sh
@@ -0,0 +1,6 @@
#!/usr/bin/env bash
## example
## ./dev/run-test.sh

mvn clean install -DskipTests
mvn test -pl streamingpro-it
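Note: running ./dev/run-test.sh from the repository root first installs every module with tests skipped (the integration-test module needs the freshly built artifacts of its siblings), then executes only the streamingpro-it module's tests. If that module wires its suites through the scalatest-maven-plugin, a single suite could presumably be selected with mvn test -pl streamingpro-it -DwildcardSuites=<suite name>; that flag assumes the plugin and is not taken from this repository.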
Ray.scala (tech.mlsql.ets)
@@ -1,16 +1,14 @@
package tech.mlsql.ets

import java.net.{InetAddress, ServerSocket}

import java.util
import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.ml.param.Param
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession, SparkUtils}
import org.apache.spark.{MLSQLSparkUtils, SparkEnv, SparkInstanceService, TaskContext, WowRowEncoder}
import streaming.core.datasource.util.MLSQLJobCollect
import org.apache.spark.{TaskContext, WowRowEncoder}
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.Functions

@@ -19,21 +17,15 @@ import tech.mlsql.arrow.python.PythonWorkerFactory
import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
import tech.mlsql.arrow.python.ispark.SparkContextImp
import tech.mlsql.arrow.python.runner._
import tech.mlsql.common.utils.base.TryTool
import tech.mlsql.common.utils.distribute.socket.server.{ReportHostAndPort, SocketServerInExecutor}
import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros
import tech.mlsql.common.utils.net.NetTool
import tech.mlsql.common.utils.network.NetUtils
import tech.mlsql.common.utils.serder.json.JSONTool
import tech.mlsql.ets.ray.{CollectServerInDriver, DataServer}
import tech.mlsql.log.WriteLog
import tech.mlsql.ets.ray.DataServer
import tech.mlsql.schema.parser.SparkSimpleSchemaParser
import tech.mlsql.session.SetSession
import tech.mlsql.tool.MasterSlaveInSpark
import tech.mlsql.version.VersionCompatibility

import scala.collection.mutable.ArrayBuffer

/**
 * 24/12/2019 WilliamZhu([email protected])
 */

@@ -68,87 +60,6 @@ class Ray(override val uid: String) extends SQLAlg with VersionCompatibility wit
    newdf
  }


  private def computeSplits(session: SparkSession, df: DataFrame) = {
    var targetLen = df.rdd.partitions.length
    var sort = false
    val context = ScriptSQLExec.context()

    MLSQLSparkUtils.isFileTypeTable(df) match {
      case true =>
        targetLen = 1
        sort = true
      case false =>
        TryTool.tryOrElse {
          val resource = new SparkInstanceService(session).resources
          val jobInfo = new MLSQLJobCollect(session, context.owner)
          val leftResource = resource.totalCores - jobInfo.resourceSummary(null).activeTasks
          logInfo(s"RayMode: Resource:[${leftResource}(${resource.totalCores}-${jobInfo.resourceSummary(null).activeTasks})] TargetLen:[${targetLen}]")
          if (leftResource / 2 <= targetLen) {
            targetLen = Math.max(Math.floor(leftResource / 2) - 1, 1).toInt
          }
        } {
          logWarning(format("Warning: Fail to detect instance resource. Setup 4 data server for Python."))
          if (targetLen > 4) targetLen = 4
        }
    }

    (targetLen, sort)
  }
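Aside (not part of this commit): the sizing rule in the removed computeSplits is easier to see as a pure function. A minimal standalone sketch, with illustrative names that do not exist in the codebase:

```scala
// Sketch of the removed sizing rule. File-backed tables get a single sorted
// split; otherwise the split count is capped at half of the idle cores minus
// one (but at least 1); if resource detection fails, fall back to at most 4.
object SplitSizing {
  def computeSplits(defaultPartitions: Int,
                    isFileTypeTable: Boolean,
                    idleCores: Option[Int]): (Int, Boolean) = {
    if (isFileTypeTable) (1, true)
    else idleCores match {
      case Some(left) if left / 2 <= defaultPartitions =>
        (math.max(left / 2 - 1, 1), false)
      case Some(_) =>
        (defaultPartitions, false) // plenty of idle cores: keep partitioning
      case None =>
        (math.min(defaultPartitions, 4), false) // resource probe failed
    }
  }
}
```

For example, SplitSizing.computeSplits(8, isFileTypeTable = false, idleCores = Some(6)) yields (2, false): six idle cores support at most 6 / 2 - 1 = 2 data servers.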

  private def buildDataSocketServers(session: SparkSession, df: DataFrame, tempdf: DataFrame, _owner: String): CollectServerInDriver[String] = {
    val refs = new AtomicReference[ArrayBuffer[ReportHostAndPort]]()
    refs.set(ArrayBuffer[ReportHostAndPort]())
    val stopFlag = new AtomicReference[String]()
    stopFlag.set("false")

    val tempSocketServerInDriver = new CollectServerInDriver(refs, stopFlag)

    val thread = new Thread("temp-data-server-in-spark") {
      override def run(): Unit = {

        val dataSchema = df.schema
        val tempSocketServerHost = tempSocketServerInDriver._host
        val tempSocketServerPort = tempSocketServerInDriver._port
        val timezoneID = session.sessionState.conf.sessionLocalTimeZone
        val owner = _owner
        tempdf.rdd.mapPartitions { iter =>

          val host: String = if (SparkEnv.get == null || MLSQLSparkUtils.blockManager == null || MLSQLSparkUtils.blockManager.blockManagerId == null) {
            WriteLog.write(List("Ray: Cannot get MLSQLSparkUtils.rpcEnv().address, using NetTool.localHostName()").iterator,
              Map("PY_EXECUTE_USER" -> owner))
            NetTool.localHostName()
          } else if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.mlsql.deploy.on.k8s", false)) {
            InetAddress.getLocalHost.getHostAddress
          }
          else MLSQLSparkUtils.blockManager.blockManagerId.host

          val socketRunner = new SparkSocketRunner("serveToStreamWithArrow", host, timezoneID)
          val commonTaskContext = new SparkContextImp(TaskContext.get(), null)
          val convert = WowRowEncoder.fromRow(dataSchema)
          val newIter = iter.map { irow =>
            convert(irow)
          }
          val Array(_server, _host, _port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 1000, commonTaskContext)

          // send server info back
          SocketServerInExecutor.reportHostAndPort(tempSocketServerHost,
            tempSocketServerPort,
            ReportHostAndPort(_host.toString, _port.toString.toInt))

          while (_server != null && !_server.asInstanceOf[ServerSocket].isClosed) {
            Thread.sleep(1 * 1000)
          }
          List[String]().iterator
        }.count()
        logInfo("Exit all data server")
      }
    }
    thread.setDaemon(true)
    thread.start()
    tempSocketServerInDriver
  }
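The removed buildDataSocketServers implemented a report-back handshake: the driver starts a lightweight collector (CollectServerInDriver), each partition task brings up an Arrow data server, reports its host:port to the collector, and then blocks until its server socket is closed. Below is a self-contained sketch of that handshake using plain java.net sockets; it illustrates the pattern only and does not use the MLSQL utility classes:

```scala
import java.io.{BufferedReader, InputStreamReader, PrintWriter}
import java.net.{InetAddress, ServerSocket, Socket}
import java.util.concurrent.ConcurrentLinkedQueue

object EndpointCollector {
  def main(args: Array[String]): Unit = {
    val reported = new ConcurrentLinkedQueue[String]() // "host:port" strings
    val collector = new ServerSocket(0) // driver-side collector on a free port

    // Driver side: accept one "host:port" line per data server.
    val acceptor = new Thread("collector") {
      override def run(): Unit = {
        while (!collector.isClosed) {
          try {
            val s = collector.accept()
            val in = new BufferedReader(new InputStreamReader(s.getInputStream))
            reported.add(in.readLine())
            s.close()
          } catch { case _: java.io.IOException => /* collector closed */ }
        }
      }
    }
    acceptor.setDaemon(true)
    acceptor.start()

    // Task side (would run inside mapPartitions): start a data server and
    // report its endpoint back to the driver.
    val dataServer = new ServerSocket(0)
    val sock = new Socket(InetAddress.getLocalHost, collector.getLocalPort)
    val out = new PrintWriter(sock.getOutputStream, true)
    out.println(s"${InetAddress.getLocalHost.getHostAddress}:${dataServer.getLocalPort}")
    sock.close()

    Thread.sleep(500) // give the acceptor a moment to record the report
    println(s"collected endpoints: $reported")
    dataServer.close()
    collector.close()
  }
}
```

In the real method the reported endpoints accumulate in an AtomicReference[ArrayBuffer[ReportHostAndPort]] on the driver, and each task keeps its server alive (the Thread.sleep polling loop) so an external reader can connect and drain the Arrow stream.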

  private def schemaFromStr(schemaStr: String) = {
    val targetSchema = schemaStr.trim match {
      case item if item.startsWith("{") =>