Skip to content

Commit

Permalink
Merge pull request byzer-org#901 from allwefantasy/ISSUE-898
Browse files Browse the repository at this point in the history
Issue 898
  • Loading branch information
allwefantasy authored Jan 14, 2019
2 parents fb400cb + 4022456 commit 1323a6e
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package streaming.core.datasource

import java.net.URLEncoder

import net.csdn.ServiceFramwork
import net.csdn.common.path.Url
import net.csdn.modules.http.RestRequest
import net.csdn.modules.transport.HttpTransportService
import net.sf.json.{JSONArray, JSONObject}
import streaming.dsl.ScriptSQLExec

import scala.collection.JavaConverters._

/**
 * Client for a remote datasource-jar repository service. Lists available
 * datasource connectors, resolves their published versions, and builds
 * download URLs so the jars can be added to a running Spark session.
 *
 * 2019-01-14 WilliamZhu([email protected])
 *
 * @param url base url of the repository service; may be null/empty, in which
 *            case the url is taken from the script execution context.
 */
class DataSourceRepository(url: String) {

  // HTTP client resolved lazily from the service framework's injector on each call.
  def httpClient = ServiceFramwork.injector.getInstance[HttpTransportService](classOf[HttpTransportService])

  //"http://respository.datasource.mlsql.tech"
  /**
   * Resolve the repository base url: the constructor argument when present,
   * otherwise the `__datasource_repository_url__` entry configured in the
   * script execution context.
   *
   * @throws IllegalArgumentException if no url was given and the context
   *                                  entry is missing (via `require`).
   */
  def getOrDefaultUrl: String = {
    if (url == null || url.isEmpty) {
      val context = ScriptSQLExec.contextGetOrForTest()
      require(context.userDefinedParam.contains("__datasource_repository_url__"), "context.__datasource_repository_url__ should be configure if you want use connect DataSourceRepository")
      // BUG FIX: `Map.get` returns Option[String]; the original returned it
      // directly, widening this method's result to Any and interpolating
      // "Some(url)" into request urls. Use apply — the require above
      // guarantees the key exists.
      context.userDefinedParam("__datasource_repository_url__")
    } else {
      url
    }
  }

  /**
   * Fetch the datasource mapper from the repository and return one JSON
   * string per datasource, each enriched with its `name` (the mapper key)
   * and a `versions` array resolved via [[versionCommand]].
   */
  def listCommand: Seq[String] = {
    val res = httpClient.http(new Url(s"${getOrDefaultUrl}/jar/manager/source/mapper"), "{}", RestRequest.Method.POST)

    JSONObject.fromObject(res.getContent).asScala.flatMap { kv =>
      kv._2.asInstanceOf[JSONArray].asScala.map { _item =>
        val item = _item.asInstanceOf[JSONObject]
        item.put("name", kv._1)
        // Wrap the single item so versionCommand's batch API can be reused.
        val temp = new JSONArray()
        temp.add(item)
        val versionList = versionCommand(temp)
        val versionArray = new JSONArray
        versionList.foreach { v => versionArray.add(v) }
        item.put("versions", versionArray)
        item.toString
      }
    }.toSeq
  }

  /**
   * Query the repository for the published versions of each item's
   * `groupid/artifactId` coordinate and return the version strings
   * (the last path segment of each returned key).
   */
  def versionCommand(items: JSONArray): Seq[String] = {
    val request = new JSONArray
    // foreach, not map: we only build up `request` by side effect.
    items.asScala.foreach { _item =>
      val item = _item.asInstanceOf[JSONObject]
      val json = new JSONObject()
      json.put("jarname", item.getString("groupid") + "/" + item.getString("artifactId"))
      request.add(json)
    }

    val res = httpClient.http(new Url(s"${getOrDefaultUrl}/jar/manager/versions"), request.toString(), RestRequest.Method.POST)
    JSONArray.fromObject(res.getContent).get(0).asInstanceOf[JSONObject].asScala.map { kv =>
      val version = kv._1.asInstanceOf[String].split("/").last
      version
    }.toSeq
  }

  /**
   * Build the repository url that proxies the Maven Central download of the
   * given artifact, tagged with `fileName` (the datasource format, e.g.
   * es, mongodb) so the server can name the fetched jar.
   * NOTE(review): central.maven.org over plain http has since been retired
   * in favor of https://repo1.maven.org — confirm the repository service
   * still resolves this host before relying on it.
   */
  def addCommand(format: String, groupid: String, artifactId: String, version: String): String = {
    val url = s"http://central.maven.org/maven2/${groupid.replaceAll("\\.", "/")}/${artifactId}/${version}"
    // fileName format e.g es, mongodb
    s"${getOrDefaultUrl}/jar/manager/http?fileName=${URLEncoder.encode(format, "utf-8")}&url=${URLEncoder.encode(url, "utf-8")}"
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.core.datasource.DataSourceRepository
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.log.{Logging, WowLog}

/**
 * ET that manages datasource connector jars via a [[DataSourceRepository]]:
 * `list` shows available datasources (with versions), `add` downloads a
 * connector jar and attaches it to the current Spark session.
 *
 * 2019-01-14 WilliamZhu([email protected])
 */
class SQLDataSourceExt(override val uid: String) extends SQLAlg with WowParams with Logging with WowLog {


  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    // `command` defaults to "list" when not supplied.
    params.get(command.name).map { item =>
      set(command, item)
      item
    }.getOrElse {
      set(command, "list")
    }

    // `repository` is optional: an empty string makes DataSourceRepository
    // fall back to the url configured in the script context.
    // BUG FIX: the original never set a value when the param was absent,
    // so $(repository) below threw NoSuchElementException.
    params.get(repository.name).map { item =>
      set(repository, item)
      item
    }.getOrElse {
      set(repository, "")
    }

    val dataSourceRepository = new DataSourceRepository($(repository))

    val spark = df.sparkSession
    import spark.implicits._
    $(command) match {
      case "list" =>
        spark.read.json(spark.createDataset(dataSourceRepository.listCommand))
      case "add" =>
        // path is expected as format/groupid/artifactId/version, e.g.
        // "es/org.elasticsearch/elasticsearch-spark-20_2.11/6.4.1".
        val Array(dsFormat, groupid, artifactId, version) = path.split("/")
        val url = dataSourceRepository.addCommand(dsFormat, groupid, artifactId, version)
        val logMsg = format(s"Datasource is loading jar from ${url}")
        logInfo(logMsg)
        spark.sparkContext.addJar(url)
        Seq[Seq[String]](Seq(logMsg)).toDF("desc")
      case other =>
        // BUG FIX: "version" passes the param validator but had no case,
        // surfacing as a bare MatchError; fail with an explicit message.
        throw new RuntimeException(s"command `${other}` is not supported in DataSourceExt")
    }
  }


  // Batch prediction just re-runs the command; there is no trained model.
  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    train(df, path, params)
  }

  // This ET produces no loadable model.
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    null
  }

  // No UDF-based prediction is supported.
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    null
  }

  // Which action to perform; validated against the supported command names.
  final val command: Param[String] = new Param[String](this, "command", "list|version|add", isValid = (m: String) => {
    m == "list" || m == "version" || m == "add"
  })

  // Optional repository url; empty means "use the context-configured url".
  final val repository: Param[String] = new Param[String](this, "repository", "repository url")

  def this() = this(BaseParams.randomUID())
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class PythonTrain extends Functions with Serializable {


if (!keepVersion) {
if (path.contains("..") || path == "/" || path.split("\\.").length < 3) {
if (path.contains("..") || path == "/" || path.split("/").length < 3) {
throw new MLSQLException("path should at least three layer")
}
HDFSOperator.deleteDir(SQLPythonFunc.getAlgModelPath(path, keepVersion))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:
# log level for this class is used to overwrite the root logger's log level, so that
# the user can have different defaults for the shell and regular Spark apps.
log4j.logger.org.apache.spark.repl.Main=WARN
log4j.logger.org.apache.spark=WARN
#log4j.logger.org.apache.spark=WARN

# Settings to quiet third party logs that are too verbose
log4j.logger.org.spark_project.jetty=WARN
Expand Down

0 comments on commit 1323a6e

Please sign in to comment.