forked from byzer-org/byzer-lang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request byzer-org#901 from allwefantasy/ISSUE-898
Issue 898
- Loading branch information
Showing
4 changed files
with
141 additions
and
2 deletions.
There are no files selected for viewing
71 changes: 71 additions & 0 deletions
71
streamingpro-mlsql/src/main/java/streaming/core/datasource/DataSourceRepository.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package streaming.core.datasource | ||
|
||
import java.net.URLEncoder | ||
|
||
import net.csdn.ServiceFramwork | ||
import net.csdn.common.path.Url | ||
import net.csdn.modules.http.RestRequest | ||
import net.csdn.modules.transport.HttpTransportService | ||
import net.sf.json.{JSONArray, JSONObject} | ||
import streaming.dsl.ScriptSQLExec | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
/** | ||
* 2019-01-14 WilliamZhu([email protected]) | ||
*/ | ||
class DataSourceRepository(url: String) { | ||
def httpClient = ServiceFramwork.injector.getInstance[HttpTransportService](classOf[HttpTransportService]) | ||
|
||
//"http://respository.datasource.mlsql.tech" | ||
def getOrDefaultUrl = { | ||
if (url == null || url.isEmpty) { | ||
val context = ScriptSQLExec.contextGetOrForTest() | ||
require(context.userDefinedParam.contains("__datasource_repository_url__"), "context.__datasource_repository_url__ should be configure if you want use connect DataSourceRepository") | ||
context.userDefinedParam.get("__datasource_repository_url__") | ||
} else { | ||
url | ||
} | ||
} | ||
|
||
def listCommand = { | ||
val res = httpClient.http(new Url(s"${getOrDefaultUrl}/jar/manager/source/mapper"), "{}", RestRequest.Method.POST) | ||
|
||
JSONObject.fromObject(res.getContent).asScala.flatMap { kv => | ||
kv._2.asInstanceOf[JSONArray].asScala.map { _item => | ||
val item = _item.asInstanceOf[JSONObject] | ||
item.put("name", kv._1) | ||
val temp = new JSONArray() | ||
temp.add(item) | ||
val versionList = versionCommand(temp) | ||
val versionArray = new JSONArray | ||
versionList.foreach { v => versionArray.add(v) } | ||
item.put("versions", versionArray) | ||
item.toString | ||
} | ||
}.toSeq | ||
} | ||
|
||
def versionCommand(items: JSONArray) = { | ||
val request = new JSONArray | ||
items.asScala.map { _item => | ||
val item = _item.asInstanceOf[JSONObject] | ||
val json = new JSONObject() | ||
json.put("jarname", item.getString("groupid") + "/" + item.getString("artifactId")) | ||
request.add(json) | ||
} | ||
|
||
val res = httpClient.http(new Url(s"${getOrDefaultUrl}/jar/manager/versions"), request.toString(), RestRequest.Method.POST) | ||
JSONArray.fromObject(res.getContent).get(0).asInstanceOf[JSONObject].asScala.map { kv => | ||
val version = kv._1.asInstanceOf[String].split("/").last | ||
version | ||
}.toSeq | ||
} | ||
|
||
def addCommand(format: String, groupid: String, artifactId: String, version: String) = { | ||
val url = s"http://central.maven.org/maven2/${groupid.replaceAll("\\.", "/")}/${artifactId}/${version}" | ||
// fileName format e.g es, mongodb | ||
s"${getOrDefaultUrl}/jar/manager/http?fileName=${URLEncoder.encode(format, "utf-8")}&url=${URLEncoder.encode(url, "utf-8")}" | ||
} | ||
|
||
} |
68 changes: 68 additions & 0 deletions
68
streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/SQLDataSourceExt.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package streaming.dsl.mmlib.algs | ||
|
||
import org.apache.spark.ml.param.Param | ||
import org.apache.spark.sql.expressions.UserDefinedFunction | ||
import org.apache.spark.sql.{DataFrame, SparkSession} | ||
import streaming.core.datasource.DataSourceRepository | ||
import streaming.dsl.mmlib.SQLAlg | ||
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams} | ||
import streaming.log.{Logging, WowLog} | ||
|
||
/** | ||
* 2019-01-14 WilliamZhu([email protected]) | ||
*/ | ||
class SQLDataSourceExt(override val uid: String) extends SQLAlg with WowParams with Logging with WowLog { | ||
|
||
|
||
override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = { | ||
|
||
params.get(command.name).map { item => | ||
set(command, item) | ||
item | ||
}.getOrElse { | ||
set(command, "list") | ||
} | ||
|
||
params.get(repository.name).map { item => | ||
set(repository, item) | ||
item | ||
} | ||
|
||
val dataSourceRepository = new DataSourceRepository($(repository)) | ||
|
||
val spark = df.sparkSession | ||
import spark.implicits._ | ||
$(command) match { | ||
case "list" => | ||
spark.read.json(spark.createDataset(dataSourceRepository.listCommand)) | ||
case "add" => | ||
val Array(dsFormat, groupid, artifactId, version) = path.split("/") | ||
val url = dataSourceRepository.addCommand(dsFormat, groupid, artifactId, version) | ||
val logMsg = format(s"Datasource is loading jar from ${url}") | ||
logInfo(logMsg) | ||
spark.sparkContext.addJar(url) | ||
Seq[Seq[String]](Seq(logMsg)).toDF("desc") | ||
} | ||
} | ||
|
||
|
||
override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = { | ||
train(df, path, params) | ||
} | ||
|
||
override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = { | ||
null | ||
} | ||
|
||
override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = { | ||
null | ||
} | ||
|
||
final val command: Param[String] = new Param[String](this, "command", "list|version|add", isValid = (m: String) => { | ||
m == "list" || m == "version" || m == "add" | ||
}) | ||
|
||
final val repository: Param[String] = new Param[String](this, "repository", "repository url") | ||
|
||
def this() = this(BaseParams.randomUID()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters