Skip to content

Commit

Permalink
Rest paging support
Browse files Browse the repository at this point in the history
add stop condition

add config skipDynamicEvaluation for datasource/alg

update
  • Loading branch information
allwefantasy committed Nov 22, 2021
1 parent 49ddc53 commit 87bce9e
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import _root_.streaming.dsl.MLSQLExecuteContext
import org.apache.spark.sql._

/**
* 2018-12-20 WilliamZhu([email protected])
*/
* 2018-12-20 WilliamZhu([email protected])
*/

trait MLSQLDataSource {
def dbSplitter = {
Expand All @@ -44,6 +44,10 @@ trait MLSQLSource extends MLSQLDataSource with MLSQLSourceInfo {
def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame
}

/**
 * Optional mix-in that lets a data source tune how the load pipeline treats its options.
 */
trait MLSQLSourceConfig {
  /**
   * Whether the loader should leave this data source's option values untouched instead of
   * running them through dynamic expression evaluation before the source sees them.
   *
   * @return `false` by default, i.e. option values are evaluated by the loader
   */
  def skipDynamicEvaluation: Boolean = false
}

trait RewriteableSource {
def rewrite(df: DataFrame,
config: DataSourceConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ trait SQLAlg extends Serializable {
def skipOriginalDFName: Boolean = true
def skipResultDFName: Boolean = true

// When true, TrainAdaptor hands the statement's option values to the algorithm as written
// instead of passing each one through Templates2.dynamicEvaluateExpression first.
def skipDynamicEvaluation: Boolean = false

def modelType: ModelType = UndefinedType

def doc: Doc = Doc(TextDoc, "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,7 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt
}

override def parse(ctx: SqlContext): Unit = {
val LoadStatement(_, format, path, _option, tableName) = analyze(ctx)

val option = _option.map { case (k, v) =>
val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap)
(k, newV)
}
val LoadStatement(_, format, path, option, tableName) = analyze(ctx)

def isStream = {
scriptSQLExecListener.env().contains("streamName")
Expand All @@ -82,7 +77,7 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt
if (isStream) {
scriptSQLExecListener.addEnv("stream", "true")
}
new LoadPRocessing(scriptSQLExecListener, option, path, tableName, format).parse
new LoadProcessing(scriptSQLExecListener, option, path, tableName, format).parse

scriptSQLExecListener.setLastSelectTable(tableName)

Expand All @@ -91,15 +86,29 @@ class LoadAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt

/**
 * Parsed representation of a `load` statement, as produced by `analyze` and consumed by `parse`.
 *
 * @param raw       presumably the original statement text -- TODO confirm against `analyze`
 * @param format    data source format name used to look up the data source
 * @param path      path part of the statement
 * @param option    key/value options attached to the statement; empty when none are given
 * @param tableName name under which the loaded table is registered
 */
case class LoadStatement(
  raw: String,
  format: String,
  path: String,
  option: Map[String, String] = Map.empty[String, String],
  tableName: String
)

class LoadPRocessing(scriptSQLExecListener: ScriptSQLExecListener,
option: Map[String, String],
class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
_option: Map[String, String],
var path: String,
tableName: String,
format: String
) extends DslTool {
def parse = {
var table: DataFrame = null
val sparkSession = scriptSQLExecListener.sparkSession
var option = _option
val tempDS = DataSourceRegistry.fetch(format, option)

// A data source that mixes in MLSQLSourceConfig may opt out of dynamic evaluation of its
// option values. The flag must be read from the data source held inside the Option: testing
// the Option itself against MLSQLSourceConfig can never succeed, which would silently turn
// the opt-out into a no-op. With no registered data source nothing is evaluated, as before.
val shouldEvaluateOptions = tempDS.exists {
  case sourceConfig: MLSQLSourceConfig => !sourceConfig.skipDynamicEvaluation
  case _ => true
}

if (shouldEvaluateOptions) {
  option = _option.map { case (k, v) =>
    val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap)
    (k, newV)
  }
}

val reader = scriptSQLExecListener.sparkSession.read
reader.options(option)
path = TemplateMerge.merge(path, scriptSQLExecListener.env().toMap)
Expand All @@ -113,8 +122,8 @@ class LoadPRocessing(scriptSQLExecListener: ScriptSQLExecListener,
var sourceInfo: Option[SourceInfo] = None

DataSourceRegistry.fetch(format, option).map { datasource =>
table = datasource.asInstanceOf[ {def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame}].
load(reader, dsConf)
val ds = datasource.asInstanceOf[ {def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame}]
table = ds.load(reader, dsConf)

// extract source info if the datasource is MLSQLSourceInfo
if (datasource.isInstanceOf[MLSQLSourceInfo]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,18 @@ class TrainAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdap

override def parse(ctx: SqlContext): Unit = {
val TrainStatement(_, tableName, format, _path, _options, asTableName) = analyze(ctx)
val sqlAlg = MLMapping.findAlg(format)

var path = _path

var options = _options.map { case (k, v) =>
val newV = Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap)
val newV = if(sqlAlg.skipDynamicEvaluation) v else Templates2.dynamicEvaluateExpression(v, ScriptSQLExec.context().execListener.env().toMap)
(k, newV)
}

val owner = options.get("owner")
val df = scriptSQLExecListener.sparkSession.table(tableName)
val sqlAlg = MLMapping.findAlg(format)


if (!sqlAlg.skipPathPrefix) {
path = withPathPrefix(scriptSQLExecListener.pathPrefix(owner), path)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@ package tech.mlsql.datasource.impl
import java.net.URLEncoder
import java.nio.charset.Charset

import com.jayway.jsonpath.JsonPath
import org.apache.http.client.fluent.{Form, Request}
import org.apache.http.entity.ContentType
import org.apache.http.entity.mime.{HttpMultipartMode, MultipartEntityBuilder}
import org.apache.http.util.EntityUtils
import org.apache.spark.ml.param.Param
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row, SparkSession, functions => F}
import streaming.core.datasource._
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import tech.mlsql.common.form._
import tech.mlsql.common.utils.distribute.socket.server.JavaUtils
import tech.mlsql.dsl.adaptor.DslTool
import tech.mlsql.tool.HDFSOperatorV2
import tech.mlsql.tool.{HDFSOperatorV2, Templates2}

import scala.collection.mutable.ArrayBuffer

class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink with MLSQLSourceInfo with MLSQLRegistry with DslTool with WowParams {
class MLSQLRest(override val uid: String) extends MLSQLSource
with MLSQLSink
with MLSQLSourceInfo
with MLSQLSourceConfig
with MLSQLRegistry with DslTool with WowParams {


def this() = this(BaseParams.randomUID())
Expand All @@ -32,6 +37,11 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit
* `config.connect-timeout`="10s"
* and `config.method`="GET"
* and `header.content-type`="application/json"
*
* and `config.page.next`="http://mlsql.tech/api?cursor=:{:page}"
* and `config.page.value`="$.path" -- json path
* and `config.page.limit`="100"
*
* and `body`='''
* {
* "query":"b"
Expand All @@ -53,29 +63,59 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit
* select * from table_3 as output;
**/
override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = {
  // Issues the HTTP request described by `config` and returns the response as a DataFrame
  // with the columns (content, status) produced by `_http`.
  //
  // Paging: when both `config.page.next` (URL template containing the `page` variable) and
  // `config.page.value` (JSON path into the response body) are set, follow-up requests are
  // issued and their rows appended, until the JSON path yields null/empty or
  // `config.page.limit` follow-up requests (default "1") have been made.
  //
  // Throws MLSQLException when only one of the two paging options is configured.
  val params = config.config
  // Evaluated on first use only, so a misconfiguration is reported before touching `config.df`.
  lazy val session = config.df.get.sparkSession

  (params.get("config.page.next"), params.get("config.page.value")) match {
    case (Some(urlTemplate), Some(jsonPath)) =>
      val maxSize = params.getOrElse("config.page.limit", "1").toInt
      var lastPage = _http(config.path, params, session)
      var result = lastPage
      var count = 0
      var hasNext = true

      while (hasNext && count < maxSize) {
        count += 1
        // The cursor must come from the most recent response. Reading the head of the
        // accumulated union would always return the first page and re-request the same URL.
        val content = lastPage.select(F.col("content").cast(StringType)).head.getString(0)
        // NOTE(review): JsonPath.read may throw when the path is absent from the body;
        // confirm whether the target APIs omit the cursor on their last page.
        val pageValue = JsonPath.read[String](content, jsonPath)
        if (pageValue == null || pageValue.isEmpty) {
          hasNext = false
        } else {
          val newUrl = Templates2.dynamicEvaluateExpression(urlTemplate, Map("page" -> pageValue))
          lastPage = _http(newUrl, params, session)
          result = result.union(lastPage)
        }
      }
      result

    case (None, None) =>
      _http(config.path, params, session)

    case _ =>
      // Previously this fell through the match and surfaced as an opaque MatchError.
      throw new MLSQLException(
        "config.page.next and config.page.value must be configured together to enable paging")
  }
}

/**
 * The REST source opts out of loader-side dynamic evaluation of its option values; it
 * evaluates the values it needs itself when building the request.
 */
override def skipDynamicEvaluation: Boolean = true


private def _http(url: String, params: Map[String, String], session: SparkSession): DataFrame = {
val httpMethod = params.getOrElse(configMethod.name, "get").toLowerCase
val request = httpMethod match {
case "get" =>

val paramsBuf = ArrayBuffer[(String, String)]()
params.filter(_._1.startsWith("form.")).foreach { case (k, v) =>
paramsBuf.append((k.stripPrefix("form."), URLEncoder.encode(v, "utf-8")))
paramsBuf.append((k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(URLEncoder.encode(v, "utf-8"),ScriptSQLExec.context().execListener.env().toMap)))
}

val finalUrl = if (paramsBuf.length > 0) {
val urlParam = paramsBuf.map { case (k, v) => s"${k}=${v}" }.mkString("&")
if (config.path.contains("?")) {
config.path + urlParam
if (url.contains("?")) {
url + "&" + urlParam
} else {
config.path + "?" + urlParam
url + "?" + urlParam
}
} else config.path

} else url
Request.Get(finalUrl)

case "post" => Request.Post(config.path)
case "put" => Request.Put(config.path)
case "post" => Request.Post(url)
case "put" => Request.Put(url)
case v => throw new MLSQLException(s"HTTP method ${v} is not support yet")
}

Expand Down Expand Up @@ -104,7 +144,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit
case ("post", "application/x-www-form-urlencoded") =>
val form = Form.form()
params.filter(_._1.startsWith("form.")).foreach { case (k, v) =>
form.add(k.stripPrefix("form."), v)
form.add(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap))
}
request.bodyForm(form.build(), Charset.forName("utf-8")).execute()

Expand All @@ -115,7 +155,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit
val httpResponse = response.returnResponse()
val status = httpResponse.getStatusLine.getStatusCode
val content = EntityUtils.toByteArray(httpResponse.getEntity)
val session = config.df.get.sparkSession

session.createDataFrame(session.sparkContext.makeRDD(Seq(Row.fromSeq(Seq(content, status))))
, StructType(fields = Seq(
StructField("content", BinaryType), StructField("status", IntegerType)
Expand Down Expand Up @@ -143,7 +183,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit

val paramsBuf = ArrayBuffer[(String, String)]()
params.filter(_._1.startsWith("form.")).foreach { case (k, v) =>
paramsBuf.append((k.stripPrefix("form."), URLEncoder.encode(v, "utf-8")))
paramsBuf.append((k.stripPrefix("form."),Templates2.dynamicEvaluateExpression(URLEncoder.encode(v, "utf-8"),ScriptSQLExec.context().execListener.env().toMap)))
}

val finalUrl = if (paramsBuf.length > 0) {
Expand Down Expand Up @@ -186,7 +226,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit
case ("post", "application/x-www-form-urlencoded") =>
val form = Form.form()
params.filter(_._1.startsWith("form.")).foreach { case (k, v) =>
form.add(k.stripPrefix("form."), v)
form.add(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap))
}
request.bodyForm(form.build(), Charset.forName("utf-8")).execute()

Expand All @@ -207,7 +247,7 @@ class MLSQLRest(override val uid: String) extends MLSQLSource with MLSQLSink wit

params.filter(_._1.startsWith("form.")).
filter(v => v._1 != "form.file-path" && v._1 != "form.file-name").foreach { case (k, v) =>
entity.addTextBody(k.stripPrefix("form."), v)
entity.addTextBody(k.stripPrefix("form."), Templates2.dynamicEvaluateExpression(v,ScriptSQLExec.context().execListener.env().toMap))
}
request.body(entity.build()).execute()
case (_, v) =>
Expand Down

0 comments on commit 87bce9e

Please sign in to comment.