Allow users to define a custom plugin that is executed before the user runs load, save, and !hdfs/!fs (byzer-org#1802)
* support hooks before load/save/fs execute

* rename rewrite_[number] to a more specific name

* pathPrefix fix in LoadAdaptor & SaveAdaptor

* pathPrefix fix in LoadAdaptor & SaveAdaptor

Co-authored-by: jiachuan.zhu <[email protected]>
allwefantasy and chncaesar authored Aug 5, 2022
1 parent 8f0ebe5 commit 844c0cb
Showing 7 changed files with 241 additions and 31 deletions.
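
The core of the change is a set of rewrite hooks that run before a load/save/!fs statement touches any data: plugins can rewrite the load options and source info, the save sink config, and the filesystem config. As a minimal sketch, a load hook implementing the RewritableSourceConfig trait introduced below might look like this (the class name and the path-prefix policy are illustrative, not part of this commit):

    import streaming.core.datasource.{DataSourceConfig, RewritableSourceConfig, SourceInfo}
    import streaming.dsl.MLSQLExecuteContext

    // Hypothetical plugin: prefix every load path with a per-user directory.
    class PathPrefixPlugin extends RewritableSourceConfig {
      override def rewrite_conf(config: DataSourceConfig, format: String,
                                context: MLSQLExecuteContext): DataSourceConfig = {
        // DataSourceConfig is a case class here, so copy() rewrites only the path.
        config.copy(path = s"/data/${context.owner}/${config.path}")
      }

      // This sketch passes the extracted SourceInfo through unchanged.
      override def rewrite_source(sourceInfo: SourceInfo, format: String,
                                  context: MLSQLExecuteContext): SourceInfo = sourceInfo
    }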
@@ -19,6 +19,7 @@
package streaming.core.datasource

import _root_.streaming.dsl.MLSQLExecuteContext
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql._

/**
@@ -48,17 +49,37 @@ trait MLSQLSourceConfig {
def skipDynamicEvaluation = false
}

trait RewriteableSource {
trait RewritableSource {
def rewrite(df: DataFrame,
config: DataSourceConfig,
sourceInfo: Option[SourceInfo],
context: MLSQLExecuteContext): DataFrame
}

trait RewritableSourceConfig {
def rewrite_conf(config: DataSourceConfig, format: String,
context: MLSQLExecuteContext): DataSourceConfig

def rewrite_source(sourceInfo: SourceInfo, format: String,
context: MLSQLExecuteContext): SourceInfo
}

trait MLSQLSink extends MLSQLDataSource {
def save(writer: DataFrameWriter[Row], config: DataSinkConfig): Any
}

trait RewritableSinkConfig {
def rewrite(config: DataSinkConfig, format: String,
context: MLSQLExecuteContext): DataSinkConfig
}

case class FSConfig(conf: Configuration, path: String, params: Map[String, String])

trait RewritableFSConfig {
def rewrite(config: FSConfig,
context: MLSQLExecuteContext): FSConfig
}

trait MLSQLDirectSource extends MLSQLDataSource {
def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame
}
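
The !hdfs/!fs hook follows the same pattern via RewritableFSConfig: a plugin can swap the Hadoop Configuration, the target path, or the params before the shell command runs. A hedged sketch, with a made-up class name and sandboxing rule:

    import streaming.core.datasource.{FSConfig, RewritableFSConfig}
    import streaming.dsl.MLSQLExecuteContext

    // Hypothetical !fs hook: confine every filesystem command to a sandbox tree.
    class SandboxFSPlugin extends RewritableFSConfig {
      override def rewrite(config: FSConfig, context: MLSQLExecuteContext): FSConfig = {
        // FSConfig is the case class defined above, so copy() keeps conf and params.
        config.copy(path = "/sandbox" + config.path)
      }
    }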
@@ -88,7 +88,7 @@ case class LoadStatement(raw: String, format: String, path: String, option: Map[

class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
_option: Map[String, String],
var path: String,
var _path: String,
tableName: String,
format: String
) extends DslTool {
@@ -98,7 +98,7 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
var option = _option
val tempDS = DataSourceRegistry.fetch(format, option)

if (tempDS.isDefined ) {
if (tempDS.isDefined) {
// DataSource who is not MLSQLSourceConfig or if it's MLSQLSourceConfig then skipDynamicEvaluation is false
// should evaluate the v with dynamic expression
if (tempDS.isInstanceOf[MLSQLSourceConfig] && !tempDS.asInstanceOf[MLSQLSourceConfig].skipDynamicEvaluation) {
@@ -110,15 +110,22 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
}

val reader = scriptSQLExecListener.sparkSession.read
reader.options(option)
path = TemplateMerge.merge(path, scriptSQLExecListener.env().toMap)
val tempPath = TemplateMerge.merge(_path, scriptSQLExecListener.env().toMap)

def emptyDataFrame = {
import sparkSession.implicits._
Seq.empty[String].toDF("name")
}

val dsConf = DataSourceConfig(cleanStr(path), option, Option(emptyDataFrame))
val dsConf = optionsRewrite(
AppRuntimeStore.LOAD_BEFORE_CONFIG_KEY,
DataSourceConfig(cleanStr(tempPath), option, Option(emptyDataFrame)),
format,
ScriptSQLExec.context())

val path = dsConf.path

reader.options(dsConf.config)
var sourceInfo: Option[SourceInfo] = None

DataSourceRegistry.fetch(format, option).map { datasource =>
@@ -128,7 +135,11 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
// extract source info if the datasource is MLSQLSourceInfo
if (datasource.isInstanceOf[MLSQLSourceInfo]) {
val authConf = DataAuthConfig(dsConf.path, dsConf.config)
sourceInfo = Option(datasource.asInstanceOf[MLSQLSourceInfo].sourceInfo(authConf))
sourceInfo = Option(sourceInfoRewrite(
AppRuntimeStore.LOAD_BEFORE_CONFIG_KEY,
datasource.asInstanceOf[MLSQLSourceInfo].sourceInfo(authConf),
format,
ScriptSQLExec.context()))
}
if (datasource.isInstanceOf[DatasourceAuth]) {
datasource.asInstanceOf[DatasourceAuth].auth(dsConf.path, dsConf.config)
@@ -151,7 +162,7 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
}

table = customRewrite(AppRuntimeStore.LOAD_BEFORE_KEY, table, dsConf, sourceInfo, ScriptSQLExec.context())
// In order to control the access of columns, we should rewrite the final sql (conver * to specify column names)
// In order to control the access of columns, we should rewrite the final sql (convert * to specify column names)
table = authRewrite(table, dsConf, sourceInfo, ScriptSQLExec.context())
// finally use the build-in or third-party plugins to rewrite the table.
table = customRewrite(AppRuntimeStore.LOAD_AFTER_KEY, table, dsConf, sourceInfo, ScriptSQLExec.context())
@@ -220,7 +231,7 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
}


path = TemplateMerge.merge(path, scriptSQLExecListener.env().toMap)
//path = TemplateMerge.merge(path, scriptSQLExecListener.env().toMap)
}

table.createOrReplaceTempView(tableName)
@@ -237,14 +248,44 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
if (rewrite && implClass != "") {
val instance = Class.forName(implClass)
instance.newInstance()
.asInstanceOf[RewriteableSource]
.asInstanceOf[RewritableSource]
.rewrite(df, config, sourceInfo, context)

} else {
df
}
}

def optionsRewrite(orderKey: String,
config: DataSourceConfig,
format: String,
context: MLSQLExecuteContext) = {
AppRuntimeStore.store.getLoadSave(orderKey) match {
case Some(item) =>
item.customClassItems.classNames.map { className =>
val instance = Class.forName(className).newInstance().asInstanceOf[RewritableSourceConfig]
instance.rewrite_conf(config, format, context)
}.headOption.getOrElse(config)
case None =>
config
}
}

def sourceInfoRewrite(orderKey: String,
sourceInfo: SourceInfo,
format: String,
context: MLSQLExecuteContext) = {
AppRuntimeStore.store.getLoadSave(orderKey) match {
case Some(item) =>
item.customClassItems.classNames.map { className =>
val instance = Class.forName(className).newInstance().asInstanceOf[RewritableSourceConfig]
instance.rewrite_source(sourceInfo, format, context)
}.headOption.getOrElse(sourceInfo)
case None =>
sourceInfo
}
}

def customRewrite(orderKey: String, df: DataFrame,
config: DataSourceConfig,
sourceInfo: Option[SourceInfo],
@@ -256,7 +297,7 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
item.customClassItems.classNames.foreach { className =>
val instance = Class.forName(className)
newDF = instance.newInstance()
.asInstanceOf[RewriteableSource]
.asInstanceOf[RewritableSource]
.rewrite(df, config, sourceInfo, context)
}
case None =>
@@ -18,22 +18,22 @@

package tech.mlsql.dsl.adaptor

import java.util.UUID

import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SaveMode}
import streaming.core.datasource.{DataSinkConfig, DataSourceRegistry}
import streaming.core.datasource.{DataSinkConfig, DataSourceRegistry, RewritableSinkConfig}
import streaming.core.stream.MLSQLStreamManager
import streaming.dsl.parser.DSLSQLParser._
import streaming.dsl.template.TemplateMerge
import streaming.dsl.{ScriptSQLExec, ScriptSQLExecListener}
import streaming.dsl.{MLSQLExecuteContext, ScriptSQLExec, ScriptSQLExecListener}
import tech.mlsql.job.{JobManager, MLSQLJobType}
import tech.mlsql.runtime.AppRuntimeStore

import java.util.UUID
import scala.collection.mutable.ArrayBuffer

/**
* Created by allwefantasy on 27/8/2017.
*/
* Created by allwefantasy on 27/8/2017.
*/
class SaveAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdaptor {

def evaluate(value: String) = {
@@ -84,18 +84,30 @@ class SaveAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt
override def parse(ctx: SqlContext): Unit = {


val SaveStatement(_, tableName, format, path, option, _mode, partitionByCol) = analyze(ctx)
val owner = option.get("owner")
val mode = SaveMode.valueOf(_mode)
val SaveStatement(_, tableName, format, _path, _option, _mode, partitionByCol) = analyze(ctx)


val context = ScriptSQLExec.context()

var oldDF: DataFrame = scriptSQLExecListener.sparkSession.table(tableName)
val spark = oldDF.sparkSession
import spark.implicits._


val mode = SaveMode.valueOf(_mode)

val dsc = optionsRewrite(AppRuntimeStore.SAVE_BEFORE_CONFIG_KEY, DataSinkConfig(_path, _option,
mode, Option(oldDF)), format, context)

val option = dsc.config
val owner = option.get("owner")
val path = dsc.path

def isStream = {
MLSQLStreamManager.isStream
}

val spark = oldDF.sparkSession
import spark.implicits._
val context = ScriptSQLExec.context()

var job = JobManager.getJobInfo(context.groupId)


@@ -136,7 +148,7 @@ class SaveAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt
if (path == "-" || path.isEmpty) {
writer.format(option.getOrElse("implClass", format)).save()
} else {
writer.format(option.getOrElse("implClass", format)).save(resourceRealPath(context.execListener, owner, path))
writer.format(option.getOrElse("implClass", format)).save( resourceRealPath(context.execListener, owner, path) )
}
}

@@ -162,6 +174,23 @@ class SaveAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdapt
outputTable.createOrReplaceTempView(tempTable)
scriptSQLExecListener.setLastSelectTable(tempTable)
}

def optionsRewrite(orderKey: String,
config: DataSinkConfig,
format: String,
context: MLSQLExecuteContext):DataSinkConfig = {
AppRuntimeStore.store.getLoadSave(orderKey) match {
case Some(item) =>
item.customClassItems.classNames.map { className =>
val instance = Class.forName(className).newInstance().asInstanceOf[RewritableSinkConfig]
instance.rewrite(config, format, context)
}.headOption.getOrElse(config)
case None =>
config
}
}


}

case class SaveStatement(raw: String, inputTableName: String, format: String, path: String, option: Map[String, String] = Map(), mode: String, partitionByCol: List[String])
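
The save-side hook has the same shape. A sketch of a RewritableSinkConfig implementation, assuming DataSinkConfig is a case class like its load counterpart (plugin name and policy are made up):

    import org.apache.spark.sql.SaveMode
    import streaming.core.datasource.{DataSinkConfig, RewritableSinkConfig}
    import streaming.dsl.MLSQLExecuteContext

    // Hypothetical save hook: refuse silent overwrites outside /tmp.
    class SafeSavePlugin extends RewritableSinkConfig {
      override def rewrite(config: DataSinkConfig, format: String,
                           context: MLSQLExecuteContext): DataSinkConfig = {
        if (config.mode == SaveMode.Overwrite && !config.path.startsWith("/tmp"))
          config.copy(mode = SaveMode.ErrorIfExists)
        else config
      }
    }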
@@ -146,8 +146,13 @@ trait LoadSaveRuntimeStore {
object AppRuntimeStore {
private val _store = new InMemoryStore()
val store = new AppRuntimeStore(_store)
val LOAD_BEFORE_CONFIG_KEY = "load_before_config_key"
val SAVE_BEFORE_CONFIG_KEY = "save_before_config_key"

val LOAD_BEFORE_KEY = "load_before_key"
val LOAD_AFTER_KEY = "load_after_key"

val FS_BEFORE_CONFIG_KEY = "fs_before_config_key"
}

class Jack extends CustomController {
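A hook only fires once its class name is registered under one of these keys. A hedged wiring sketch — it assumes LoadSaveRuntimeStore exposes a registerLoadSave counterpart to the getLoadSave lookup used by optionsRewrite above, so the exact call may differ:

    import tech.mlsql.runtime.AppRuntimeStore

    // Register the hypothetical plugins from the earlier sketches under the new keys.
    AppRuntimeStore.store.registerLoadSave(
      AppRuntimeStore.LOAD_BEFORE_CONFIG_KEY, "tech.mlsql.plugin.PathPrefixPlugin")
    AppRuntimeStore.store.registerLoadSave(
      AppRuntimeStore.FS_BEFORE_CONFIG_KEY, "tech.mlsql.plugin.SandboxFSPlugin")

Note that the rewrite helpers apply classNames.map(...).headOption, so when several classes are registered under one key, only the first one's result is used.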
28 changes: 23 additions & 5 deletions streamingpro-mlsql/src/main/java/tech/mlsql/ets/HDFSCommand.scala
@@ -4,27 +4,31 @@ import org.apache.hadoop.util.ToolRunner
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession}
import streaming.core.datasource.{FSConfig, RewritableFSConfig}
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.Functions
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.dsl.{MLSQLExecuteContext, ScriptSQLExec}
import tech.mlsql.common.utils.serder.json.JSONTool
import tech.mlsql.ets.hdfs.WowFsShell
import tech.mlsql.runtime.AppRuntimeStore

/**
* 2019-05-07 WilliamZhu([email protected])
*/
* 2019-05-07 WilliamZhu([email protected])
*/
class HDFSCommand(override val uid: String) extends SQLAlg with Functions with WowParams {
def this() = this(BaseParams.randomUID())

override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = train(df, path, params)

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val spark = df.sparkSession
val conf = df.sparkSession.sessionState.newHadoopConf()
val fsConf = configRewrite(AppRuntimeStore.FS_BEFORE_CONFIG_KEY,
FSConfig(df.sparkSession.sessionState.newHadoopConf(), path, params), ScriptSQLExec.context())
val args = JSONTool.parseJson[List[String]](params("parameters"))
conf.setQuietMode(false)
fsConf.conf.setQuietMode(false)
var output = ""
val fsShell = new WowFsShell(conf, path)
val fsShell = new WowFsShell(fsConf.conf, fsConf.path)
try {
ToolRunner.run(fsShell, args.toArray)
output = fsShell.getError
@@ -45,6 +49,20 @@ class HDFSCommand(override val uid: String) extends SQLAlg with Functions with W
}
}

def configRewrite(orderKey: String,
config: FSConfig,
context: MLSQLExecuteContext): FSConfig = {
AppRuntimeStore.store.getLoadSave(orderKey) match {
case Some(item) =>
item.customClassItems.classNames.map { className =>
val instance = Class.forName(className).newInstance().asInstanceOf[RewritableFSConfig]
instance.rewrite(config, context)
}.headOption.getOrElse(config)
case None =>
config
}
}


override def skipPathPrefix: Boolean = false

@@ -1,13 +1,13 @@
package tech.mlsql.plugin.load

import org.apache.spark.sql.DataFrame
import streaming.core.datasource.{DataSourceConfig, RewriteableSource, SourceInfo}
import streaming.core.datasource.{DataSourceConfig, RewritableSource, SourceInfo}
import streaming.dsl.MLSQLExecuteContext

/**
* 11/12/2019 WilliamZhu([email protected])
*/
class DefaultLoaderPlugin extends RewriteableSource {
class DefaultLoaderPlugin extends RewritableSource {
override def rewrite(df: DataFrame, config: DataSourceConfig, sourceInfo: Option[SourceInfo], context: MLSQLExecuteContext): DataFrame = {
val conf = config.config
var table = df