diff --git a/streamingpro-core/src/main/java/streaming/core/datasource/MLSQLSource.scala b/streamingpro-core/src/main/java/streaming/core/datasource/MLSQLSource.scala index 84a8afccc..b73b2f517 100644 --- a/streamingpro-core/src/main/java/streaming/core/datasource/MLSQLSource.scala +++ b/streamingpro-core/src/main/java/streaming/core/datasource/MLSQLSource.scala @@ -22,8 +22,8 @@ import _root_.streaming.dsl.MLSQLExecuteContext import org.apache.spark.sql._ /** - * 2018-12-20 WilliamZhu(allwefantasy@gmail.com) - */ + * 2018-12-20 WilliamZhu(allwefantasy@gmail.com) + */ trait MLSQLDataSource { def dbSplitter = { @@ -44,6 +44,10 @@ trait MLSQLSource extends MLSQLDataSource with MLSQLSourceInfo { def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame } +trait MLSQLSourceConfig { + def skipDynamicEvaluation = false +} + trait RewriteableSource { def rewrite(df: DataFrame, config: DataSourceConfig, diff --git a/streamingpro-core/src/main/java/streaming/dsl/mmlib/SQLAlg.scala b/streamingpro-core/src/main/java/streaming/dsl/mmlib/SQLAlg.scala index a453d7a79..8f5faeaad 100644 --- a/streamingpro-core/src/main/java/streaming/dsl/mmlib/SQLAlg.scala +++ b/streamingpro-core/src/main/java/streaming/dsl/mmlib/SQLAlg.scala @@ -52,6 +52,8 @@ trait SQLAlg extends Serializable { def skipOriginalDFName: Boolean = true def skipResultDFName: Boolean = true + def skipDynamicEvaluation: Boolean = false + def modelType: ModelType = UndefinedType def doc: Doc = Doc(TextDoc, "") diff --git a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala index 4a2cb0d63..0f9ecb8cd 100644 --- a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala +++ b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala @@ -68,12 +68,7 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt } override def parse(ctx: SqlContext): Unit = { - val LoadStatement(_, format, path, _option, tableName) = analyze(ctx) - - val option = _option.map { case (k, v) => - val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap) - (k, newV) - } + val LoadStatement(_, format, path, option, tableName) = analyze(ctx) def isStream = { scriptSQLExecListener.env().contains("streamName") @@ -82,7 +77,7 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt if (isStream) { scriptSQLExecListener.addEnv("stream", "true") } - new LoadPRocessing(scriptSQLExecListener, option, path, tableName, format).parse + new LoadProcessing(scriptSQLExecListener, option, path, tableName, format).parse scriptSQLExecListener.setLastSelectTable(tableName) @@ -91,8 +86,8 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt case class LoadStatement(raw: String, format: String, path: String, option: Map[String, String] = Map[String, String](), tableName: String) -class LoadPRocessing(scriptSQLExecListener: ScriptSQLExecListener, - option: Map[String, String], +class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener, + _option: Map[String, String], var path: String, tableName: String, format: String @@ -100,6 +95,20 @@ class LoadPRocessing(scriptSQLExecListener: ScriptSQLExecListener, def parse = { var table: DataFrame = null val sparkSession = scriptSQLExecListener.sparkSession + var option = _option + val tempDS = DataSourceRegistry.fetch(format, option) + + if (tempDS.isDefined ) { + // DataSource who is not MLSQLSourceConfig or if it's MLSQLSourceConfig then skipDynamicEvaluation is false + // should evaluate the v with dynamic expression + if (!tempDS.isInstanceOf[MLSQLSourceConfig] || !tempDS.asInstanceOf[MLSQLSourceConfig].skipDynamicEvaluation) { + option = _option.map { case (k, v) => + val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap) + (k, newV) + } + } + } + val reader = scriptSQLExecListener.sparkSession.read reader.options(option) path = TemplateMerge.merge(path, scriptSQLExecListener.env().toMap) @@ -113,8 +122,8 @@ class LoadPRocessing(scriptSQLExecListener: ScriptSQLExecListener, var sourceInfo: Option[SourceInfo] = None DataSourceRegistry.fetch(format, option).map { datasource => - table = datasource.asInstanceOf[ {def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame}]. - load(reader, dsConf) + val ds = datasource.asInstanceOf[ {def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame}] + table = ds.load(reader, dsConf) // extract source info if the datasource is MLSQLSourceInfo if (datasource.isInstanceOf[MLSQLSourceInfo]) { diff --git a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/TrainAdaptor.scala b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/TrainAdaptor.scala index f5d8031d5..1c83f5e94 100644 --- a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/TrainAdaptor.scala +++ b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/TrainAdaptor.scala @@ -70,15 +70,18 @@ class TrainAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdap override def parse(ctx: SqlContext): Unit = { val TrainStatement(_, tableName, format, _path, _options, asTableName) = analyze(ctx) + val sqlAlg = MLMapping.findAlg(format) + var path = _path + var options = _options.map { case (k, v) => - val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap) + val newV = if(sqlAlg.skipDynamicEvaluation) v else Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap) (k, newV) } + val owner = options.get("owner") val df = scriptSQLExecListener.sparkSession.table(tableName) - val sqlAlg = MLMapping.findAlg(format) - + if (!sqlAlg.skipPathPrefix) { path = withPathPrefix(scriptSQLExecListener.pathPrefix(owner), path) } diff --git a/streamingpro-mlsql/src/main/java/tech/mlsql/datasource/impl/MLSQLRest.scala b/streamingpro-mlsql/src/main/java/tech/mlsql/datasource/impl/MLSQLRest.scala index d2dd05a2f..a08373b0c 100644 --- a/streamingpro-mlsql/src/main/java/tech/mlsql/datasource/impl/MLSQLRest.scala +++ b/streamingpro-mlsql/src/main/java/tech/mlsql/datasource/impl/MLSQLRest.scala @@ -3,25 +3,30 @@ package tech.mlsql.datasource.impl import java.net.URLEncoder import java.nio.charset.Charset +import com.jayway.jsonpath.JsonPath import org.apache.http.client.fluent.{Form, Request} import org.apache.http.entity.ContentType import org.apache.http.entity.mime.{HttpMultipartMode, MultipartEntityBuilder} import org.apache.http.util.EntityUtils import org.apache.spark.ml.param.Param import org.apache.spark.sql.mlsql.session.MLSQLException -import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} -import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row, SparkSession, functions => F} import streaming.core.datasource._ import streaming.dsl.ScriptSQLExec import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams} import tech.mlsql.common.form._ import tech.mlsql.common.utils.distribute.socket.server.JavaUtils import tech.mlsql.dsl.adaptor.DslTool -import tech.mlsql.tool.HDFSOperatorV2 +import tech.mlsql.tool.{HDFSOperatorV2, Templates2} import scala.collection.mutable.ArrayBuffer -class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink with MLSQLSourceInfo with MLSQLRegistry with DslTool with WowParams { +class MLSQLRest(override val uid: String) extends MLSQLSource + with MLSQLSink + with MLSQLSourceInfo + with MLSQLSourceConfig + with MLSQLRegistry with DslTool with WowParams { def this() = this(BaseParams.randomUID()) @@ -32,6 +37,11 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit * `config.connect-timeout`="10s" * and `config.method`="GET" * and `header.content-type`="application/json" + * + * and `config.page.next`="http://mlsql.tech/api?cursor=:{:page}" + * and `config.page.value`="$.path" -- json path + * and `config.page.limit`="100" + * * and `body`=''' * { * "query":"b" @@ -53,29 +63,59 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit * select * from table_3 as output; **/ override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = { - val params = config.config + + (config.config.get("config.page.next"), config.config.get("config.page.value")) match { + case (Some(urlTemplate), Some(jsonPath)) => + val maxSize = config.config.getOrElse("config.page.limit", "1").toInt + var count = 0 + var firstDf = _http(config.path, config.config, config.df.get.sparkSession) + + while (count < maxSize) { + count += 1 + val content = firstDf.select(F.col("content").cast(StringType), F.col("status")).head.getString(0) + val pageValue = JsonPath.read[String](content, jsonPath) + if (pageValue == null || pageValue.isEmpty) { + count = maxSize + } else { + val newUrl = Templates2.dynamicEvaluateExpression(urlTemplate, Map("page" -> pageValue)) + firstDf = firstDf.union(_http(newUrl, config.config, config.df.get.sparkSession)) + } + + } + firstDf + + case (None, None) => + _http(config.path, config.config, config.df.get.sparkSession) + } + + + } + + override def skipDynamicEvaluation = true + + + private def _http(url: String, params: Map[String, String], session: SparkSession): DataFrame = { val httpMethod = params.getOrElse(configMethod.name, "get").toLowerCase val request = httpMethod match { case "get" => val paramsBuf = ArrayBuffer[(String, String)]() params.filter(_._1.startsWith("form.")).foreach { case (k, v) => - paramsBuf.append((k.stripPrefix("form."), URLEncoder.encode(v, "utf-8"))) + paramsBuf.append((k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(URLEncoder.encode(v, "utf-8"),ScriptSQLExec.context().execListener.env().toMap))) } val finalUrl = if (paramsBuf.length > 0) { val urlParam = paramsBuf.map { case (k, v) => s"${k}=${v}" }.mkString("&") - if (config.path.contains("?")) { - config.path + urlParam + if (url.contains("?")) { + url + "&" + urlParam } else { - config.path + "?" + urlParam + url + "?" + urlParam } - } else config.path - + } else url Request.Get(finalUrl) - case "post" => Request.Post(config.path) - case "put" => Request.Put(config.path) + case "post" => Request.Post(url) + case "put" => Request.Put(url) case v => throw new MLSQLException(s"HTTP method ${v} is not support yet") } @@ -104,7 +144,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit case ("post", "application/x-www-form-urlencoded") => val form = Form.form() params.filter(_._1.startsWith("form.")).foreach { case (k, v) => - form.add(k.stripPrefix("form."), v) + form.add(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap)) } request.bodyForm(form.build(), Charset.forName("utf-8")).execute() @@ -115,7 +155,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit val httpResponse = response.returnResponse() val status = httpResponse.getStatusLine.getStatusCode val content = EntityUtils.toByteArray(httpResponse.getEntity) - val session = config.df.get.sparkSession + session.createDataFrame(session.sparkContext.makeRDD(Seq(Row.fromSeq(Seq(content, status)))) , StructType(fields = Seq( StructField("content", BinaryType), StructField("status", IntegerType) @@ -143,7 +183,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit val paramsBuf = ArrayBuffer[(String, String)]() params.filter(_._1.startsWith("form.")).foreach { case (k, v) => - paramsBuf.append((k.stripPrefix("form."), URLEncoder.encode(v, "utf-8"))) + paramsBuf.append((k.stripPrefix("form."),Templates2.dynamicEvaluateExpression(URLEncoder.encode(v, "utf-8"),ScriptSQLExec.context().execListener.env().toMap))) } val finalUrl = if (paramsBuf.length > 0) { @@ -186,7 +226,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit case ("post", "application/x-www-form-urlencoded") => val form = Form.form() params.filter(_._1.startsWith("form.")).foreach { case (k, v) => - form.add(k.stripPrefix("form."), v) + form.add(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap)) } request.bodyForm(form.build(), Charset.forName("utf-8")).execute() @@ -207,7 +247,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit params.filter(_._1.startsWith("form.")). filter(v => v._1 != "form.file-path" && v._1 != "form.file-name").foreach { case (k, v) => - entity.addTextBody(k.stripPrefix("form."), v) + entity.addTextBody(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap)) } request.body(entity.build()).execute() case (_, v) =>