From 85fc18d40a8541b436cb4ab031d47f60d1f90e62 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 13 Nov 2018 15:06:37 +0800 Subject: [PATCH 1/3] [fix] fix scala udf map type compile failed error. --- .../streaming/common/SourceCodeCompiler.scala | 15 +++ .../streaming/dsl/mmlib/algs/ScriptUDF.scala | 8 +- .../streaming/parser/SparkTypePaser.scala | 106 +++++++++++++----- .../java/streaming/udf/ScalaSourceUDF.scala | 27 +++-- .../catalyst/expressions/WowScalaUDF.scala | 2 + 5 files changed, 117 insertions(+), 41 deletions(-) diff --git a/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala b/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala index c72e992bc..6e383c14e 100644 --- a/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala +++ b/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala @@ -1,5 +1,6 @@ package streaming.common +import scala.reflect.runtime.universe._ import javax.tools._ import scala.collection.JavaConversions._ @@ -51,6 +52,20 @@ object SourceCodeCompiler extends Logging { } } + def getFunReturnType(fun: String): Type = { + println(fun) + import scala.tools.reflect.ToolBox + //val classLoader = scala.reflect.runtime.universe.getClass.getClassLoader + var classLoader = Thread.currentThread().getContextClassLoader + if (classLoader == null) { + classLoader = scala.reflect.runtime.universe.getClass.getClassLoader + } + val tb = runtimeMirror(classLoader).mkToolBox() + val tree = tb.parse(fun) + val defDef = tb.typecheck(tree).asInstanceOf[DefDef] + defDef.tpt.tpe + } + def compileScala(src: String): Class[_] = { import scala.reflect.runtime.universe import scala.tools.reflect.ToolBox diff --git a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala index 1c780504c..a6249828d 100644 --- a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala @@ -71,11 +71,11 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit } case _ => - if (params.contains(className.name)) { - ScalaSourceUDF(res, params(className.name), params.get(methodName.name)) - } else { +// if (params.contains(className.name)) { +// ScalaSourceUDF(res, params(className.name), params.get(methodName.name)) +// } else { ScalaSourceUDF(res, params.get(methodName.name)) - } +// } } (e: Seq[Expression]) => new WowScalaUDF(func, returnType, e).toScalaUDF } diff --git a/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala b/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala index e22966496..00fe44b1c 100644 --- a/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala +++ b/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala @@ -3,6 +3,8 @@ package streaming.parser import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.ScalaReflection + /** * Created by allwefantasy on 8/9/2018. 
*/ @@ -58,27 +60,31 @@ object SparkTypePaser { } //array(array(map(string,string))) - def toSparkType(dt: String): DataType = dt match { - case "boolean" => BooleanType - case "byte" => ByteType - case "short" => ShortType - case "integer" => IntegerType - case "date" => DateType - case "long" => LongType - case "float" => FloatType - case "double" => DoubleType - case "decimal" => DoubleType - case "binary" => BinaryType - case "string" => StringType - case c: String if c.startsWith("array") => - ArrayType(toSparkType(findInputInArrayBracket(c))) - case c: String if c.startsWith("map") => - //map(map(string,string),string) - val (key, value) = findKeyAndValue(findInputInArrayBracket(c)) - MapType(toSparkType(key), toSparkType(value)) - - case _ => throw new RuntimeException("dt is not found spark type") + def toSparkType(dt: String): DataType = { + println(dt) + + dt match { + case "boolean" => BooleanType + case "byte" => ByteType + case "short" => ShortType + case "integer" => IntegerType + case "date" => DateType + case "long" => LongType + case "float" => FloatType + case "double" => DoubleType + case "decimal" => DoubleType + case "binary" => BinaryType + case "string" => StringType + case c: String if c.startsWith("array") => + ArrayType(toSparkType(findInputInArrayBracket(c))) + case c: String if c.startsWith("map") => + //map(map(string,string),string) + val (key, value) = findKeyAndValue(findInputInArrayBracket(c)) + MapType(toSparkType(key), toSparkType(value)) + + case _ => throw new RuntimeException(s"$dt is not found spark type") + } } def cleanSparkSchema(wowStructType: WowStructType): StructType = { @@ -169,10 +175,60 @@ object SparkTypePaser { } def main(args: Array[String]): Unit = { - val res = toSparkSchema("st(field(name,string),field(name1,st(field(name2,array(string)))))", WowStructType(ArrayBuffer())) - println(cleanSparkSchema(res.asInstanceOf[WowStructType])) - // val wow = ArrayBuffer[String]() - // findFieldArray("field(name,string),field(name1,st(field(name2,array(string))))", wow) - // println(wow) + xxxxx() } + def xxxxx(): Unit = { + import scala.reflect.runtime.currentMirror + import scala.reflect.runtime.universe._ + import scala.tools.reflect.ToolBox + val tb = currentMirror.mkToolBox() + val code = + """ + | def xxx(i: Map[String, String]): Map[String, String] = { + | null + | } + """.stripMargin + val code2 = + """ + | def xxx(i: Int): Int = { + | null + | } + """.stripMargin + + val code3 = + """ + | def xxx(i: Array[String]): Array[String] = { + | null + | } + """.stripMargin + + val code4 = + """ + |def apply(m: String): Map[String, String] = { + | Map("a" -> "b") + |} + """.stripMargin + val code5 = + """ + |def apply(m: String) = { + | Map("a" -> "b") + |} + """.stripMargin + val tree = tb.parse(code5) + val zz = tb.typecheck(tree) + println("=================") + + val yy = zz.asInstanceOf[DefDef] + println("----") +// println(yy.tpt) + println(yy.tpt.tpe) + println(yy.tpt.symbol.asType.typeParams) + println() + val dt = ScalaReflection.schemaFor(yy.tpt.tpe) + println(dt) + val x = typeOf[String] + typeOf[Int] + x + } + } diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala index 1d6b30285..875384890 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala @@ -2,7 +2,7 @@ package streaming.udf import java.util.UUID -import 
org.apache.spark.sql.catalyst.JavaTypeInference +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.types._ import streaming.common.{ScriptCacheKey, SourceCodeCompiler} import streaming.dsl.ScriptSQLExec @@ -27,17 +27,18 @@ object ScalaSourceUDF extends Logging with WowLog { } - def apply(src: String, methodName: Option[String]): (AnyRef, DataType) = { - val (className, newfun) = wrapClass(src) - apply(newfun, className, methodName) - } +// def apply(src: String, methodName: Option[String]): (AnyRef, DataType) = { +// val (className, newfun) = src +// apply(src , methodName) +// } - def apply(src: String, className: String, methodName: Option[String]): (AnyRef, DataType) = { - val (argumentNum, returnType) = getFunctionReturnType(src, className, methodName) - (generateFunction(src, className, methodName, argumentNum), returnType) + def apply(function: String, methodName: Option[String]): (AnyRef, DataType) = { + val (argumentNum, returnType) = getFunctionReturnType(function, methodName) + (generateFunction(function, methodName, argumentNum), returnType) } - private def getFunctionReturnType(src: String, className: String, methodName: Option[String]): (Int, DataType) = { + private def getFunctionReturnType(function: String, methodName: Option[String]): (Int, DataType) = { + val (className, src) = wrapClass(function) val c = ScriptSQLExec.contextGetOrForTest() @@ -56,12 +57,14 @@ object ScalaSourceUDF extends Logging with WowLog { SourceCodeCompiler.execute(ScriptCacheKey(src, className)) }).asInstanceOf[Class[_]] val method = SourceCodeCompiler.getMethod(clazz, methodName.getOrElse("apply")) - val dataType: (DataType, Boolean) = JavaTypeInference.inferDataType(method.getReturnType) - (method.getParameterCount, dataType._1) + val tpe = SourceCodeCompiler.getFunReturnType(function) + val dataType = ScalaReflection.schemaFor(tpe).dataType + (method.getParameterCount, dataType) } - def generateFunction(src: String, className: String, methodName: Option[String], argumentNum: Int): AnyRef = { + def generateFunction(function: String, methodName: Option[String], argumentNum: Int): AnyRef = { + val (className, src) = wrapClass(function) val c = ScriptSQLExec.contextGetOrForTest() diff --git a/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala b/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala index 9b8b273d9..f3e210a24 100644 --- a/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala +++ b/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala @@ -34,3 +34,5 @@ case class WowScalaUDF(function: AnyRef, } } + +case class RuntimeComplieUDF() From 7898911b9b25d58715ce50f795673686b02a3c94 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Wed, 14 Nov 2018 15:13:20 +0800 Subject: [PATCH 2/3] [ScriptUDF][update] refactoring ScirptUDF module [ScirptUDF][add] scala udf refator [ScriptUDF] [update] refactoring scala udaf [ScriptUDF] [update] refactoring python udaf [ScriptUDF][update] refactoring RuntimeCompileUDF [ScriptUDF][update] refactoring RuntimeCompileUDF [Test][add] add ScripUDF unit test [ScriptUDF][add] add more unit test && RuntimeComplieUDF/RuntimeCompileUDAF automatic register [ScriptUDF][update] clear code --- .../streaming/common/SourceCodeCompiler.scala | 2 - .../java/streaming/jython/PythonInterp.java | 9 - 
.../streaming/dsl/mmlib/algs/ScriptUDF.scala | 72 ++--- .../streaming/parser/SparkTypePaser.scala | 61 +--- ...AF.scala => PythonRuntimCompileUDAF.scala} | 37 ++- .../udf/PythonRuntimeCompileUDF.scala | 104 +++++++ .../java/streaming/udf/PythonSourceUDF.scala | 278 ------------------ .../udf/RuntimeCompileScriptInterface.scala | 120 ++++++++ .../streaming/udf/RuntimeCompileUDAF.scala | 16 + ...ourceUDF.scala => RuntimeCompileUDF.scala} | 161 +++++----- ...DAF.scala => ScalaRuntimCompileUDAF.scala} | 73 +++-- .../udf/ScalaRuntimeCompileUDF.scala | 132 +++++++++ .../spark/streaming/BasicSparkOperation.scala | 4 +- .../streaming/core/BasicMLSQLConfig.scala | 10 + .../dsl/mmlib/algs/ScriptUDFSuite.scala | 261 ++++++++++++++++ .../catalyst/expressions/WowScalaUDF.scala | 2 - 16 files changed, 821 insertions(+), 521 deletions(-) rename streamingpro-mlsql/src/main/java/streaming/udf/{PythonSourceUDAF.scala => PythonRuntimCompileUDAF.scala} (78%) create mode 100644 streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDF.scala delete mode 100644 streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDF.scala create mode 100644 streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala create mode 100644 streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDAF.scala rename streamingpro-mlsql/src/main/java/streaming/udf/{ScalaSourceUDF.scala => RuntimeCompileUDF.scala} (67%) rename streamingpro-mlsql/src/main/java/streaming/udf/{ScalaSourceUDAF.scala => ScalaRuntimCompileUDAF.scala} (62%) create mode 100644 streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDF.scala create mode 100644 streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala diff --git a/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala b/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala index 6e383c14e..db7250e5f 100644 --- a/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala +++ b/streamingpro-commons/src/main/java/streaming/common/SourceCodeCompiler.scala @@ -53,9 +53,7 @@ object SourceCodeCompiler extends Logging { } def getFunReturnType(fun: String): Type = { - println(fun) import scala.tools.reflect.ToolBox - //val classLoader = scala.reflect.runtime.universe.getClass.getClassLoader var classLoader = Thread.currentThread().getContextClassLoader if (classLoader == null) { classLoader = scala.reflect.runtime.universe.getClass.getClassLoader diff --git a/streamingpro-jython/src/main/java/streaming/jython/PythonInterp.java b/streamingpro-jython/src/main/java/streaming/jython/PythonInterp.java index b16a2f5b1..8195fe30f 100644 --- a/streamingpro-jython/src/main/java/streaming/jython/PythonInterp.java +++ b/streamingpro-jython/src/main/java/streaming/jython/PythonInterp.java @@ -26,13 +26,4 @@ public static PyObject compilePython(String src, String methodName) { return pi.get(methodName); } - public static void main(String[] args) { - String source = "class A:\n" + - "\tdef a(self,k1,k2):\n" + - "\t return k1 + k2"; - PyObject ob = compilePython(source, "A"); - PyObject instance = ob.__call__(); - System.out.println(instance.__getattr__("a").__call__(new PyInteger(2), new PyInteger(3))); - } - } diff --git a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala index a6249828d..1159a24af 100644 --- 
a/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/dsl/mmlib/algs/ScriptUDF.scala @@ -1,14 +1,13 @@ package streaming.dsl.mmlib.algs import org.apache.spark.ml.param.Param -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF, WowScalaUDF} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.udf.UDFManager import streaming.dsl.mmlib._ import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams} -import streaming.udf.{PythonSourceUDAF, PythonSourceUDF, ScalaSourceUDAF, ScalaSourceUDF} +import streaming.udf._ /** * Created by allwefantasy on 27/8/2018. @@ -23,9 +22,6 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit emptyDataFrame()(df) } - /* - register ScalaScriptUDF.`scriptText` as udf1; - */ override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = { val res = params.get(code.name).getOrElse(sparkSession.table(path).head().getString(0)) @@ -60,40 +56,23 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit if (!dataType.isValid(l)) thr(dataType, l) set(dataType, l) } - - val udf = () => { - val (func, returnType) = $(lang) match { - case "python" => - if (params.contains(className.name)) { - PythonSourceUDF(res, params(className.name), params.get(methodName.name), params(dataType.name)) - } else { - PythonSourceUDF(res, params.get(methodName.name), params(dataType.name)) - } - - case _ => -// if (params.contains(className.name)) { -// ScalaSourceUDF(res, params(className.name), params.get(methodName.name)) -// } else { - ScalaSourceUDF(res, params.get(methodName.name)) -// } - } - (e: Seq[Expression]) => new WowScalaUDF(func, returnType, e).toScalaUDF - } - - val udaf = () => { - val func = $(lang) match { - case "python" => - PythonSourceUDAF(res, $(className)) - - case _ => - ScalaSourceUDAF(res, $(className)) - } - (e: Seq[Expression]) => ScalaUDAF(e, func) - } + val scriptCacheKey = ScriptUDFCacheKey( + res, "", $(className), $(udfType), $(methodName), $(dataType), $(lang) + ) $(udfType) match { - case "udaf" => udaf() - case _ => udf() + case "udaf" => + val udaf = RuntimeCompileScriptFactory.getUDAFCompilerBylang($(lang)) + if (!udaf.isDefined) { + throw new IllegalArgumentException() + } + (e: Seq[Expression]) => udaf.get.udaf(e, scriptCacheKey) + case _ => + val udf = RuntimeCompileScriptFactory.getUDFCompilerBylang($(lang)) + if (!udf.isDefined) { + throw new IllegalArgumentException() + } + (e: Seq[Expression]) => udf.get.udf(e, scriptCacheKey) } } @@ -396,9 +375,7 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit final val lang: Param[String] = new Param[String](this, "lang", - s"""Which type of language you want. [scala|python]""", (s: String) => { - s == "scala" || s == "python" - }) + s"""Which type of language you want. 
[scala|python]""") setDefault(lang, "scala") final val udfType: Param[String] = new Param[String](this, "udfType", @@ -409,10 +386,23 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit final val className: Param[String] = new Param[String](this, "className", s"""the className of you defined in code snippet.""") + setDefault(className, "") final val methodName: Param[String] = new Param[String](this, "methodName", s"""the methodName of you defined in code snippet. If the name is apply, this parameter is optional""") + setDefault(methodName, "apply") final val dataType: Param[String] = new Param[String](this, "dataType", s"""when you use python udf, you should define return type.""") + setDefault(dataType, "") } + +case class ScriptUDFCacheKey( + originalCode: String, + wrappedCode: String, + className: String, + udfType: String, + methodName: String, + dataType: String, + lang: String) + diff --git a/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala b/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala index 00fe44b1c..4f586ef30 100644 --- a/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala +++ b/streamingpro-mlsql/src/main/java/streaming/parser/SparkTypePaser.scala @@ -4,6 +4,7 @@ import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.ScalaReflection +import streaming.common.SourceCodeCompiler /** * Created by allwefantasy on 8/9/2018. @@ -61,8 +62,6 @@ object SparkTypePaser { //array(array(map(string,string))) def toSparkType(dt: String): DataType = { - println(dt) - dt match { case "boolean" => BooleanType case "byte" => ByteType @@ -173,62 +172,4 @@ object SparkTypePaser { } } - - def main(args: Array[String]): Unit = { - xxxxx() - } - def xxxxx(): Unit = { - import scala.reflect.runtime.currentMirror - import scala.reflect.runtime.universe._ - import scala.tools.reflect.ToolBox - val tb = currentMirror.mkToolBox() - val code = - """ - | def xxx(i: Map[String, String]): Map[String, String] = { - | null - | } - """.stripMargin - val code2 = - """ - | def xxx(i: Int): Int = { - | null - | } - """.stripMargin - - val code3 = - """ - | def xxx(i: Array[String]): Array[String] = { - | null - | } - """.stripMargin - - val code4 = - """ - |def apply(m: String): Map[String, String] = { - | Map("a" -> "b") - |} - """.stripMargin - val code5 = - """ - |def apply(m: String) = { - | Map("a" -> "b") - |} - """.stripMargin - val tree = tb.parse(code5) - val zz = tb.typecheck(tree) - println("=================") - - val yy = zz.asInstanceOf[DefDef] - println("----") -// println(yy.tpt) - println(yy.tpt.tpe) - println(yy.tpt.symbol.asType.typeParams) - println() - val dt = ScalaReflection.schemaFor(yy.tpt.tpe) - println(dt) - val x = typeOf[String] - typeOf[Int] - x - } - } diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDAF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala similarity index 78% rename from streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDAF.scala rename to streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala index e0dd8f11f..9b376a105 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDAF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala @@ -4,21 +4,31 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, 
UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, StructType} import org.python.core.{Py, PyObject} -import streaming.common.{ScriptCacheKey, SourceCodeCompiler} import streaming.dsl.ScriptSQLExec -import streaming.jython.JythonUtils -import streaming.log.{Logging, WowLog} - +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey +import streaming.jython.{JythonUtils, PythonInterp} /** - * Created by allwefantasy on 31/8/2018. - */ -object PythonSourceUDAF extends Logging with WowLog { - def apply(src: String, className: String): UserDefinedAggregateFunction = { - generateAggregateFunction(src, className) + * Created by fchen on 2018/11/15. + */ +object PythonRuntimCompileUDAF extends RuntimeCompileUDAF { + /** + * validate the source code + */ + override def check(sourceCode: String): Boolean = true + + /** + * compile the source code. + * + * @param scriptCacheKey + * @return + */ + override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + PythonInterp.compilePython(scriptCacheKey.originalCode, scriptCacheKey.className) } - private def generateAggregateFunction(src: String, className: String): UserDefinedAggregateFunction = { + override def generateFunction(scriptCacheKey: ScriptUDFCacheKey): UserDefinedAggregateFunction = { + new UserDefinedAggregateFunction with Serializable { val c = ScriptSQLExec.contextGetOrForTest() @@ -28,17 +38,16 @@ object PythonSourceUDAF extends Logging with WowLog { fn() } catch { case e: Exception => - logError(format_exception(e)) throw e } } @transient val objectUsingInDriver = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className, "python")).asInstanceOf[PyObject].__call__() + execute(scriptCacheKey).asInstanceOf[PyObject].__call__() }).asInstanceOf[PyObject] lazy val objectUsingInExecutor = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className, "python")).asInstanceOf[PyObject].__call__() + execute(scriptCacheKey).asInstanceOf[PyObject].__call__() }).asInstanceOf[PyObject] @@ -112,4 +121,6 @@ object PythonSourceUDAF extends Logging with WowLog { } } + + override def lang: String = "python" } diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDF.scala new file mode 100644 index 000000000..eae8d622a --- /dev/null +++ b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDF.scala @@ -0,0 +1,104 @@ +package streaming.udf + +import java.util.UUID + +import org.apache.spark.sql.types.DataType +import org.python.core.{PyFunction, PyMethod, PyObject, PyTableCode} +import streaming.dsl.ScriptSQLExec +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey +import streaming.jython.{JythonUtils, PythonInterp} +import streaming.log.Logging +import streaming.parser.SparkTypePaser + +/** + * Created by fchen on 2018/11/14. 
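+ * Compiles a Jython class at runtime (via PythonInterp) and exposes one of its methods as a Spark UDF.
+ * The return DataType is parsed from the user-supplied `dataType` string with SparkTypePaser, and the
+ * argument count is taken from the compiled PyFunction's co_argcount minus one (the instance argument).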
+ */ +object PythonRuntimeCompileUDF extends RuntimeCompileUDF with Logging { + + override def returnType(scriptCacheKey: ScriptUDFCacheKey): Option[DataType] = { + Option(SparkTypePaser.toSparkType(scriptCacheKey.dataType)) + } + + /** + * reture udf input argument number + */ + override def argumentNum(scriptCacheKey: ScriptUDFCacheKey): Int = { + + val po = execute(scriptCacheKey).asInstanceOf[PyObject] + val pi = po.__getattr__(scriptCacheKey.methodName).asInstanceOf[PyMethod] + pi.__func__.asInstanceOf[PyFunction].__code__.asInstanceOf[PyTableCode].co_argcount - 1 + } + + override def wrapCode(scriptCacheKey: ScriptUDFCacheKey): ScriptUDFCacheKey = { + if (scriptCacheKey.className.isEmpty) { + val (className, code) = wrapClass(scriptCacheKey.originalCode) + scriptCacheKey.copy(wrappedCode = code, className = className) + } else { + scriptCacheKey.copy(wrappedCode = scriptCacheKey.originalCode) + } + } + + /** + * validate the source code + */ + override def check(sourceCode: String): Boolean = { + true + } + + /** + * compile the source code. + * + * @param scriptCacheKey + * @return + */ + override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + PythonInterp.compilePython(scriptCacheKey.wrappedCode, scriptCacheKey.className) + } + + override def invokeFunctionFromInstance(scriptCacheKey: ScriptUDFCacheKey) + : (Seq[Object]) => AnyRef = { + + val c = ScriptSQLExec.contextGetOrForTest() + + val wrap = (fn: () => Any) => { + try { + ScriptSQLExec.setContextIfNotPresent(c) + fn() + } catch { + case e: Exception => + throw e + } + } + + // instance will call by spark executor, so we declare as lazy val + lazy val instance = wrap(() => { + execute(scriptCacheKey).asInstanceOf[PyObject].__call__() + }).asInstanceOf[PyObject] + + // the same with instance, method will call by spark executor too. + lazy val method = instance.__getattr__(scriptCacheKey.methodName) + + val invokeFunc: (Seq[Any]) => AnyRef = { + (args: Seq[Any]) => { + val argsArray = args.map(JythonUtils.toPy).toArray + JythonUtils.toJava(method.__call__(argsArray)) + } + } + invokeFunc + } + + override def lang: String = "python" + + private def wrapClass(function: String): WrappedType = { + + val temp = function.split("\n").map(f => s" $f").mkString("\n") + val className = s"StreamingProUDF_${UUID.randomUUID().toString.replaceAll("-", "")}" + val newfun = + s""" + |# -*- coding: utf-8 -*- + |class ${className}: + |${temp} + """.stripMargin + (className, newfun) + } +} diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDF.scala deleted file mode 100644 index f305526f3..000000000 --- a/streamingpro-mlsql/src/main/java/streaming/udf/PythonSourceUDF.scala +++ /dev/null @@ -1,278 +0,0 @@ -package streaming.udf - -import java.util.UUID - -import org.apache.spark.sql.types._ -import org.python.core._ -import streaming.common.{ScriptCacheKey, SourceCodeCompiler} -import streaming.dsl.ScriptSQLExec -import streaming.jython.JythonUtils -import streaming.log.{Logging, WowLog} -import streaming.parser.SparkTypePaser - -import scala.collection.mutable.ArrayBuffer - -/** - * Created by allwefantasy on 28/8/2018. 
- */ -object PythonSourceUDF extends Logging with WowLog { - - private def wrapClass(function: String) = { - val temp = function.split("\n").map(f => s" $f").mkString("\n") - val className = s"StreamingProUDF_${UUID.randomUUID().toString.replaceAll("-", "")}" - val newfun = - s""" - |# -*- coding: utf-8 -*- - |class ${className}: - |${temp} - """.stripMargin - (className, newfun) - - } - - def apply(src: String, methodName: Option[String], returnType: String): (AnyRef, DataType) = { - val (className, newfun) = wrapClass(src) - apply(newfun, className, methodName, returnType) - } - - - def apply(src: String, className: String, methodName: Option[String], returnType: String): (AnyRef, DataType) = { - val argumentNum = getParameterCount(src, className, methodName) - (generateFunction(src, className, methodName, argumentNum), SparkTypePaser.toSparkType(returnType)) - } - - - private def getParameterCount(src: String, classMethod: String, methodName: Option[String]): Int = { - - val c = ScriptSQLExec.contextGetOrForTest() - - val wrap = (fn: () => Any) => { - try { - ScriptSQLExec.setContextIfNotPresent(c) - fn() - } catch { - case e: Exception => - logError(format_exception(e)) - throw e - } - } - - val po = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, classMethod, "python")) - }) - val pi = po.asInstanceOf[PyObject].__getattr__(methodName.getOrElse("apply")).asInstanceOf[PyMethod] - pi.__func__.asInstanceOf[PyFunction].__code__.asInstanceOf[PyTableCode].co_argcount - 1 - } - - def generateFunction(src: String, className: String, methodName: Option[String], argumentNum: Int): AnyRef = { - val c = ScriptSQLExec.contextGetOrForTest() - - val wrap = (fn: () => Any) => { - try { - ScriptSQLExec.setContextIfNotPresent(c) - fn() - } catch { - case e: Exception => - logError(format_exception(e)) - throw e - } - } - - lazy val instance = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className, "python")).asInstanceOf[PyObject].__call__() - }).asInstanceOf[PyObject] - lazy val method = instance.__getattr__(methodName.getOrElse("apply")) - - argumentNum match { - case 0 => new Function0[Any] with Serializable { - override def apply(): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__()) - }) - } - } - case 1 => new Function1[Object, Any] with Serializable { - override def apply(v1: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(JythonUtils.toPy(v1))) - }) - - } - } - case 2 => new Function2[Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(JythonUtils.toPy(v1), JythonUtils.toPy(v2))) - }) - - } - } - case 3 => new Function3[Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3))) - }) - - } - } - case 4 => new Function4[Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4))) - }) - - } - } - case 5 => new Function5[Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), 
JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5)))) - }) - } - } - case 6 => new Function6[Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6)))) - }) - - } - } - case 7 => new Function7[Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7)))) - }) - } - } - case 8 => new Function8[Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8)))) - }) - } - } - case 9 => new Function9[Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9)))) - }) - - } - } - case 10 => new Function10[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10)))) - }) - } - } - case 11 => new Function11[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), JythonUtils.toPy(v11)))) - }) - - } - } - case 12 => new Function12[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object): Any = { - wrap(() => { - 
JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), JythonUtils.toPy(v11), JythonUtils.toPy(v12)))) - }) - } - } - case 13 => new Function13[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13)))) - }) - - } - } - case 14 => new Function14[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14)))) - }) - - } - } - case 15 => new Function15[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15)))) - }) - - } - } - case 16 => new Function16[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), JythonUtils.toPy(v16)))) - }) - } - } - case 17 => new Function17[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, 
Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), JythonUtils.toPy(v16), JythonUtils.toPy(v17)))) - }) - } - } - case 18 => new Function18[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), JythonUtils.toPy(v16), JythonUtils.toPy(v17), JythonUtils.toPy(v18)))) - }) - - } - } - case 19 => new Function19[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), JythonUtils.toPy(v16), JythonUtils.toPy(v17), JythonUtils.toPy(v18), JythonUtils.toPy(v19)))) - }) - - } - } - case 20 => new Function20[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), 
JythonUtils.toPy(v15), - JythonUtils.toPy(v16), JythonUtils.toPy(v17), JythonUtils.toPy(v18), JythonUtils.toPy(v19), JythonUtils.toPy(v20)))) - }) - - } - } - case 21 => new Function21[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object, v21: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), - JythonUtils.toPy(v16), JythonUtils.toPy(v17), JythonUtils.toPy(v18), JythonUtils.toPy(v19), JythonUtils.toPy(v20), JythonUtils.toPy(v21)))) - }) - - } - } - case 22 => new Function22[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { - override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object, v21: Object, v22: Object): Any = { - wrap(() => { - JythonUtils.toJava(method.__call__(Array(JythonUtils.toPy(v1), JythonUtils.toPy(v2), JythonUtils.toPy(v3), JythonUtils.toPy(v4), JythonUtils.toPy(v5), JythonUtils.toPy(v6), JythonUtils.toPy(v7), JythonUtils.toPy(v8), JythonUtils.toPy(v9), JythonUtils.toPy(v10), - JythonUtils.toPy(v11), JythonUtils.toPy(v12), JythonUtils.toPy(v13), JythonUtils.toPy(v14), JythonUtils.toPy(v15), - JythonUtils.toPy(v16), JythonUtils.toPy(v17), JythonUtils.toPy(v18), JythonUtils.toPy(v19), JythonUtils.toPy(v20), JythonUtils.toPy(v21), JythonUtils.toPy(v22)))) - }) - - } - } - case n => throw new Exception(s"UDF with $n arguments is not supported ") - } - } -} diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala new file mode 100644 index 000000000..75b16c336 --- /dev/null +++ b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala @@ -0,0 +1,120 @@ +package streaming.udf + +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + +import com.google.common.cache.{CacheBuilder, CacheLoader} +import com.google.common.reflect.ClassPath +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey +import streaming.log.Logging + +/** + * Created by fchen on 2018/11/13. 
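+ * Common contract for scripts that are compiled at runtime into UDFs or UDAFs: implementations
+ * supply check/compile/generateFunction plus a lang tag, while execute() caches compiled artifacts
+ * in a Guava cache (keyed by ScriptUDFCacheKey, up to 10000 entries) so repeated use of the same
+ * script skips recompilation.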
+ */ +trait RuntimeCompileScriptInterface[FunType] extends Logging { + + private val _scriptCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build( + new CacheLoader[ScriptUDFCacheKey, AnyRef]() { + override def load(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + val startTime = System.nanoTime() + val compiled = compile(scriptCacheKey) + + def timeMs: Double = (System.nanoTime() - startTime).toDouble / 1000000 + + logInfo(s"generate udf time: [ ${timeMs} ]ms.") + compiled + } + }) + + + /** + * validate the source code + */ + def check(sourceCode: String): Boolean + + /** + * how to compile the language source code with jvm. + * + * @param scriptCacheKey + * @return + */ + def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef + + /** + * generate udf or udaf + */ + def generateFunction(scriptCacheKey: ScriptUDFCacheKey): FunType + + def lang: String + + /** + * compile source code or get binary code for cache. + * @param scriptCacheKey + * @return + */ + def execute(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + _scriptCache.get(scriptCacheKey) + } +} + +object RuntimeCompileScriptFactory { + + private val _udfCache = HashMap[String, RuntimeCompileUDF]() + private val _udafCache = HashMap[String, RuntimeCompileUDAF]() + private val _loaded = new AtomicBoolean(false) + private val _lock = new Object() + + def getUDFCompilerBylang(lang: String): Option[RuntimeCompileUDF] = { + if (!_loaded.get()) { + loadAll() + } + _udfCache.get(lang) + } + + def getUDAFCompilerBylang(lang: String): Option[RuntimeCompileUDAF] = { + if (!_loaded.get()) { + loadAll() + } + _udafCache.get(lang) + } + + def registerUDF(lang: String, runtimeCompileUDF: RuntimeCompileUDF): Unit = { + _udfCache.put(lang, runtimeCompileUDF) + } + + def registerUDAF(lang: String, runtimeCompileUDAF: RuntimeCompileUDAF): Unit = { + _udafCache.put(lang, runtimeCompileUDAF) + } + + /** + * load all [[RuntimeCompileUDF]] and [[RuntimeCompileUDAF]] + */ + def loadAll(): Unit = { + _lock.synchronized { + ClassPath.from(this.getClass.getClassLoader) + .getTopLevelClasses("streaming.udf") + .map(_.load()) + .filter(n => { + n.getName.endsWith("RuntimeCompileUDF") && n.getName != "streaming.udf.RuntimeCompileUDF" + }).map(getInstance) + .foreach { + case udf: RuntimeCompileUDF => registerUDF(udf.lang, udf) + case udaf: RuntimeCompileUDAF => registerUDAF(udaf.lang, udaf) + } + _loaded.set(true) + } + } + + private def getInstance(clz: Class[_]): Any = { + import scala.reflect.runtime.universe + val runtimeMirror = universe.runtimeMirror(this.getClass.getClassLoader) + val module = runtimeMirror.staticModule(clz.getName) + val obj = runtimeMirror.reflectModule(module) + obj.instance + } + +} diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDAF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDAF.scala new file mode 100644 index 000000000..aae8c210e --- /dev/null +++ b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDAF.scala @@ -0,0 +1,16 @@ +package streaming.udf + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey + +/** + * Created by fchen on 2018/11/15. 
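+ * Runtime-compiled UDAF support: udaf() wraps the generated UserDefinedAggregateFunction
+ * in ScalaUDAF so it can be used directly as a Catalyst expression.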
+ */ +trait RuntimeCompileUDAF extends RuntimeCompileScriptInterface[UserDefinedAggregateFunction] { + + def udaf(e: Seq[Expression], scriptCacheKey: ScriptUDFCacheKey): ScalaUDAF = { + ScalaUDAF(e, generateFunction(scriptCacheKey)) + } +} diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala similarity index 67% rename from streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala rename to streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala index 875384890..dcd56af6b 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala @@ -1,45 +1,46 @@ package streaming.udf -import java.util.UUID - -import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.types._ -import streaming.common.{ScriptCacheKey, SourceCodeCompiler} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.types.DataType import streaming.dsl.ScriptSQLExec -import streaming.log.{Logging, WowLog} +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey /** - * Created by allwefantasy on 27/8/2018. - */ -object ScalaSourceUDF extends Logging with WowLog { - - private def wrapClass(function: String) = { - val className = s"StreamingProUDF_${UUID.randomUUID().toString.replaceAll("-", "")}" - val newfun = - s""" - |class ${className}{ - | - |${function} - | - |} - """.stripMargin - (className, newfun) - + * Created by fchen on 2018/11/15. + */ +trait RuntimeCompileUDF extends RuntimeCompileScriptInterface[AnyRef] { + + /** + * udf return DataType + */ + def returnType(scriptCacheKey: ScriptUDFCacheKey): Option[DataType] + + /** + * reture udf input argument number + */ + def argumentNum(scriptCacheKey: ScriptUDFCacheKey): Int + + /** + * wrap original source code. + * e.g. in [[ScalaRuntimCompileUDAF]], user pass function code, we should wrap code as a class. + * so the runtime compiler will compile source code as runtime instance. 
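+ * For instance, a bare `def apply(m: String) = m.length` would typically be wrapped into
+ * something like `class StreamingProUDF_<uuid> { def apply(m: String) = m.length }`
+ * (mirroring the existing wrapClass helpers) before being handed to the runtime compiler.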
+ */ + def wrapCode(scriptCacheKey: ScriptUDFCacheKey): ScriptUDFCacheKey + + def invokeFunctionFromInstance(scriptCacheKey: ScriptUDFCacheKey): (Seq[Object]) => AnyRef + + override def generateFunction(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + val runtimeFunction = invokeFunctionFromInstance(scriptCacheKey) + toPartialFunc(scriptCacheKey, runtimeFunction) } -// def apply(src: String, methodName: Option[String]): (AnyRef, DataType) = { -// val (className, newfun) = src -// apply(src , methodName) -// } - - def apply(function: String, methodName: Option[String]): (AnyRef, DataType) = { - val (argumentNum, returnType) = getFunctionReturnType(function, methodName) - (generateFunction(function, methodName, argumentNum), returnType) + def udf(exp: Seq[Expression], scriptCacheKey: ScriptUDFCacheKey): ScalaUDF = { + val newScript = wrapCode(scriptCacheKey) + ScalaUDF(generateFunction(newScript), returnType(newScript).get, exp) } - private def getFunctionReturnType(function: String, methodName: Option[String]): (Int, DataType) = { - val (className, src) = wrapClass(function) - + def toPartialFunc(scriptCacheKey: ScriptUDFCacheKey, + invokeFunction: (Seq[Object]) => AnyRef): AnyRef = { val c = ScriptSQLExec.contextGetOrForTest() val wrap = (fn: () => Any) => { @@ -48,61 +49,30 @@ object ScalaSourceUDF extends Logging with WowLog { fn() } catch { case e: Exception => - logError(format_exception(e)) throw e } } - val clazz = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className)) - }).asInstanceOf[Class[_]] - val method = SourceCodeCompiler.getMethod(clazz, methodName.getOrElse("apply")) - val tpe = SourceCodeCompiler.getFunReturnType(function) - val dataType = ScalaReflection.schemaFor(tpe).dataType - (method.getParameterCount, dataType) - } - - - def generateFunction(function: String, methodName: Option[String], argumentNum: Int): AnyRef = { - val (className, src) = wrapClass(function) - - val c = ScriptSQLExec.contextGetOrForTest() - - val wrap = (fn: () => Any) => { - try { - ScriptSQLExec.setContextIfNotPresent(c) - fn() - } catch { - case e: Exception => - logError(format_cause(e)) - throw e - } - } - - lazy val clazz = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className)) - }).asInstanceOf[Class[_]] - lazy val instance = SourceCodeCompiler.newInstance(clazz) - lazy val method = SourceCodeCompiler.getMethod(clazz, methodName.getOrElse("apply")) - argumentNum match { + argumentNum(scriptCacheKey) match { case 0 => new Function0[Any] with Serializable { override def apply(): Any = { wrap(() => { - method.invoke(instance) + invokeFunction }) } } case 1 => new Function1[Object, Any] with Serializable { override def apply(v1: Object): Any = { wrap(() => { - method.invoke(instance, v1) + invokeFunction(Seq(v1)) }) + } } case 2 => new Function2[Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2) + invokeFunction(Seq(v1, v2)) }) } @@ -110,14 +80,15 @@ object ScalaSourceUDF extends Logging with WowLog { case 3 => new Function3[Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3) + invokeFunction(Seq(v1, v2, v3)) }) + } } case 4 => new Function4[Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4) + invokeFunction(Seq(v1, v2, v3, 
v4)) }) } @@ -125,35 +96,36 @@ object ScalaSourceUDF extends Logging with WowLog { case 5 => new Function5[Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5) + invokeFunction(Seq(v1, v2, v3, v4, v5)) }) } } case 6 => new Function6[Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6)) }) + } } case 7 => new Function7[Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7)) }) } } case 8 => new Function8[Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8)) }) } } case 9 => new Function9[Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9)) }) } @@ -161,95 +133,106 @@ object ScalaSourceUDF extends Logging with WowLog { case 10 => new Function10[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10)) }) } } case 11 => new Function11[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) }) + } } case 12 => new Function12[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12)) }) } } case 13 => new Function13[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: 
Object, v12: Object, v13: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)) }) + } } case 14 => new Function14[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14)) }) + } } case 15 => new Function15[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15)) }) + } } case 16 => new Function16[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16)) }) } } case 17 => new Function17[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17)) }) } } case 18 => new Function18[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18)) }) + } } case 19 => new Function19[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: 
Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19)) }) + } } case 20 => new Function20[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20)) }) + } } case 21 => new Function21[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object, v21: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21)) }) + } } case 22 => new Function22[Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Object, Any] with Serializable { override def apply(v1: Object, v2: Object, v3: Object, v4: Object, v5: Object, v6: Object, v7: Object, v8: Object, v9: Object, v10: Object, v11: Object, v12: Object, v13: Object, v14: Object, v15: Object, v16: Object, v17: Object, v18: Object, v19: Object, v20: Object, v21: Object, v22: Object): Any = { wrap(() => { - method.invoke(instance, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) + invokeFunction(Seq(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22)) }) + } } case n => throw new Exception(s"UDF with $n arguments is not supported ") } } + + type WrappedType = (String, String) } diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDAF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala similarity index 62% rename from streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDAF.scala rename to streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala index 206234735..bd4670c23 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaSourceUDAF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala @@ -1,45 +1,66 @@ package streaming.udf +import scala.reflect.ClassTag + import 
org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, StructType} -import streaming.common.{ScriptCacheKey, SourceCodeCompiler} +import org.python.antlr.ast.ClassDef +import streaming.common.SourceCodeCompiler import streaming.dsl.ScriptSQLExec -import streaming.log.{Logging, WowLog} - -import scala.reflect.ClassTag - +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey + +/** + * Created by fchen on 2018/11/14. + */ +object ScalaRuntimCompileUDAF extends RuntimeCompileUDAF with ScalaCompileUtils { + /** + * validate the source code + */ + override def check(sourceCode: String): Boolean = { + val tree = tb.parse(sourceCode) + val typeCheckResult = tb.typecheck(tree) + val checkResult = typeCheckResult.isInstanceOf[ClassDef] + if (!checkResult) { + throw new IllegalArgumentException("scala udaf require a class define!") + } + checkResult + } -object ScalaSourceUDAF extends Logging with WowLog { - def apply(src: String, className: String): UserDefinedAggregateFunction = { - generateAggregateFunction(src, className) + /** + * compile the source code. + * + * @param scriptCacheKey + * @return + */ + override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + val tree = tb.parse(prepareScala(scriptCacheKey.originalCode, scriptCacheKey.className)) + tb.compile(tree).apply().asInstanceOf[Class[_]] } - private def generateAggregateFunction(src: String, className: String): UserDefinedAggregateFunction = { - new UserDefinedAggregateFunction with Serializable { + override def generateFunction(scriptCacheKey: ScriptUDFCacheKey): UserDefinedAggregateFunction = { + val c = ScriptSQLExec.contextGetOrForTest() - val c = ScriptSQLExec.contextGetOrForTest() - - val wrap = (fn: () => Any) => { - try { - ScriptSQLExec.setContextIfNotPresent(c) - fn() - } catch { - case e: Exception => - logError(format_cause(e)) - throw e - } + val wrap = (fn: () => Any) => { + try { + ScriptSQLExec.setContextIfNotPresent(c) + fn() + } catch { + case e: Exception => + throw e } + } + new UserDefinedAggregateFunction with Serializable { @transient val clazzUsingInDriver = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className)) + execute(scriptCacheKey) }).asInstanceOf[Class[_]] - @transient val instanceUsingInDriver = SourceCodeCompiler.newInstance(clazzUsingInDriver) + @transient val instanceUsingInDriver = newInstance(clazzUsingInDriver) lazy val clazzUsingInExecutor = wrap(() => { - SourceCodeCompiler.execute(ScriptCacheKey(src, className)) + execute(scriptCacheKey) }).asInstanceOf[Class[_]] - lazy val instanceUsingInExecutor = SourceCodeCompiler.newInstance(clazzUsingInExecutor) + lazy val instanceUsingInExecutor = newInstance(clazzUsingInExecutor) def invokeMethod[T: ClassTag](clazz: Class[_], instance: Any, method: String): T = { wrap(() => { @@ -100,4 +121,6 @@ object ScalaSourceUDAF extends Logging with WowLog { } } + + override def lang: String = "scala" } diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDF.scala new file mode 100644 index 000000000..22fe5b744 --- /dev/null +++ b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDF.scala @@ -0,0 +1,132 @@ +package streaming.udf + +import java.util.UUID + +import scala.reflect.runtime.universe._ +import scala.tools.reflect.ToolBox + +import org.apache.spark.sql.catalyst.ScalaReflection +import 
org.apache.spark.sql.types.DataType +import streaming.common.SourceCodeCompiler +import streaming.dsl.ScriptSQLExec +import streaming.dsl.mmlib.algs.ScriptUDFCacheKey +import streaming.log.Logging + +/** + * Created by fchen on 2018/11/14. + */ +object ScalaRuntimeCompileUDF extends RuntimeCompileUDF with ScalaCompileUtils with Logging { + + override def returnType(scriptCacheKey: ScriptUDFCacheKey): Option[DataType] = { + + getFunctionDef(scriptCacheKey) + .map(defDef => { + ScalaReflection.schemaFor(defDef.tpt.tpe).dataType + }) + } + + override def argumentNum(scriptCacheKey: ScriptUDFCacheKey): Int = { + val funcDef = getFunctionDef(scriptCacheKey) + require(funcDef.isDefined, s"function ${scriptCacheKey.methodName} not found" + + s" in ${scriptCacheKey.originalCode}") + funcDef.get.vparamss.head.size + } + + /** + * validate the source code + */ + override def check(sourceCode: String): Boolean = { + val tree = tb.parse(sourceCode) + val typeCheckResult = tb.typecheck(tree) + val checkResult = typeCheckResult.isInstanceOf[DefDef] || typeCheckResult.isInstanceOf[ClassDef] + if (!checkResult) { + throw new IllegalArgumentException(s"${sourceCode} isn't a function or class define.") + } + checkResult + } + + /** + * compile the source code. + * + * @param scriptCacheKey + * @return + */ + override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef = { + val tree = tb.parse(prepareScala(scriptCacheKey.wrappedCode, scriptCacheKey.className)) + tb.compile(tree).apply().asInstanceOf[Class[_]] + } + + override def lang: String = "scala" + + override def wrapCode(scriptCacheKey: ScriptUDFCacheKey): ScriptUDFCacheKey = { + check(scriptCacheKey.originalCode) + val tree = tb.parse(scriptCacheKey.originalCode) + tb.typecheck(tree) match { + case dd: DefDef => + val (className, code) = wrapClass(scriptCacheKey.originalCode) + scriptCacheKey.copy(wrappedCode = code, className = className) + case cd: ClassDef => + scriptCacheKey.copy(wrappedCode = scriptCacheKey.originalCode) + case s: Any => + // never happen + throw new IllegalArgumentException(s"script type ${s.getClass} isn't a function or class.") + } + } + + private def getFunctionDef(scriptCacheKey: ScriptUDFCacheKey): Option[DefDef] = { + val tree = tb.parse(scriptCacheKey.wrappedCode) + val classDef = tb.typecheck(tree).asInstanceOf[ClassDef] + classDef.children + .head + .children + .filter(_.isInstanceOf[DefDef]) + .map(_.asInstanceOf[DefDef]) + .filter(_.name.decodedName.toString == scriptCacheKey.methodName) + .headOption + } + + private def wrapClass(function: String): WrappedType = { + val className = s"StreamingProUDF_${UUID.randomUUID().toString.replaceAll("-", "")}" + val newfun = + s""" + |class ${className} { + | + |${function} + | + |} + """.stripMargin + (className, newfun) + } + + def invokeFunctionFromInstance(scriptCacheKey: ScriptUDFCacheKey): (Seq[Object]) => AnyRef = { + + lazy val clz = execute(scriptCacheKey).asInstanceOf[Class[_]] + lazy val instance = newInstance(clz) + lazy val method = SourceCodeCompiler.getMethod(clz, scriptCacheKey.methodName) + + val func: (Seq[Object]) => AnyRef = { + (args: Seq[Object]) => method.invoke(instance, args: _*) + } + func + } +} + +trait ScalaCompileUtils { + var classLoader = Thread.currentThread().getContextClassLoader + if (classLoader == null) { + classLoader = scala.reflect.runtime.universe.getClass.getClassLoader + } + val tb = runtimeMirror(classLoader).mkToolBox() + + def prepareScala(src: String, className: String): String = { + src + "\n" + 
s"scala.reflect.classTag[$className].runtimeClass" + } + + def newInstance(clz: Class[_]): Any = { + SourceCodeCompiler.newInstance(clz) + } + +} + + + diff --git a/streamingpro-mlsql/src/test/scala/org/apache/spark/streaming/BasicSparkOperation.scala b/streamingpro-mlsql/src/test/scala/org/apache/spark/streaming/BasicSparkOperation.scala index 3bac22002..f5a851426 100644 --- a/streamingpro-mlsql/src/test/scala/org/apache/spark/streaming/BasicSparkOperation.scala +++ b/streamingpro-mlsql/src/test/scala/org/apache/spark/streaming/BasicSparkOperation.scala @@ -4,7 +4,7 @@ import java.io.File import net.csdn.common.reflect.ReflectHelper import org.apache.commons.io.FileUtils -import org.scalatest.{FlatSpec, Matchers} +import org.scalatest.{FlatSpec, FunSuite, Matchers} import serviceframework.dispatcher.{Compositor, StrategyDispatcher} import streaming.common.ParamsUtil import streaming.core.strategy.platform.{PlatformManager, SparkRuntime} @@ -12,7 +12,7 @@ import streaming.core.strategy.platform.{PlatformManager, SparkRuntime} /** * Created by allwefantasy on 30/3/2017. */ -class BasicSparkOperation extends FlatSpec with Matchers { +trait BasicSparkOperation extends FlatSpec with Matchers { def withBatchContext[R](runtime: SparkRuntime)(block: SparkRuntime => R): R = { try { diff --git a/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala b/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala index f2ba4bca3..c5218d160 100644 --- a/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala +++ b/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala @@ -74,4 +74,14 @@ trait BasicMLSQLConfig { "-streaming.unittest", "true" ) + + val batchParamsWithoutHive = Array( + "-streaming.master", "local[2]", + "-streaming.name", "unit-test", + "-streaming.rest", "false", + "-streaming.platform", "spark", + "-streaming.spark.service", "false", + "-streaming.udf.clzznames", "streaming.crawler.udf.Functions,streaming.dsl.mmlib.algs.processing.UDFFunctions", + "-streaming.unittest", "true" + ) } diff --git a/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala b/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala new file mode 100644 index 000000000..2bef4075a --- /dev/null +++ b/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala @@ -0,0 +1,261 @@ +package streaming.dsl.mmlib.algs + +import scala.collection.mutable.WrappedArray + +import org.apache.spark.streaming.BasicSparkOperation +import streaming.core.strategy.platform.SparkRuntime +import streaming.core.{BasicMLSQLConfig, SpecFunctions} +import streaming.dsl.ScriptSQLExec + +/** + * Created by fchen on 2018/11/15. + */ +class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicMLSQLConfig { + + "test scala script map return type" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + ScriptSQLExec.parse( + """ + | set echoFun=''' + | + | def apply(m: String) = { + | Map("a" -> Array[Int](1)) + | } + | '''; + | + | load script.`echoFun` as scriptTable; + | + | register ScriptUDF.`scriptTable` as funx + | ; + | + | -- create a data table. 
+ | set data=''' + | {"a":"a"} + | '''; + | load jsonStr.`data` as dataTable; + | + | select funx(a) as res from dataTable as output; + | + """.stripMargin, sq) + + val result = runtime.sparkSession.sql("select * from output").collect() + assert(result.size == 1) + assert(result.head.getAs[Map[String, WrappedArray[Int]]](0)("a").head == 1) + } + } + + + "test scala compile error case" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + val query: (String) => String = (scalaCode) => s""" + | set echoFun=''' + | ${scalaCode} + | '''; + | + | load script.`echoFun` as scriptTable; + | + | register ScriptUDF.`scriptTable` as funx + | ; + | + | -- create a data table. + | set data=''' + | {"a":"a"} + | '''; + | load jsonStr.`data` as dataTable; + | + | select funx(a) as res from dataTable as output; + | + """.stripMargin + + val scalaCode = + """ + | def apply(m: String) = { + | Map("a" -> Array[Int](1)) + | } + | apply("hello") + """.stripMargin + + + assertThrows[IllegalArgumentException] { + ScriptSQLExec.parse(query(scalaCode), sq) + } + + val functionWithOtherName = + """ + | def function(m: String) = { + | Map("a" -> Array[Int](1)) + | } + """.stripMargin + sq = createSSEL + + assertThrows[IllegalArgumentException] { + ScriptSQLExec.parse(query(functionWithOtherName), sq) + } + + + + } + } + + "test python script map return type" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + ScriptSQLExec.parse( + """ + |set dictFun=''' + | + |def apply(self,m): + | dict = {m: m} + | return dict + |'''; + | + |load script.`dictFun` as scriptTable; + |register ScriptUDF.`scriptTable` as dictFun options + |and lang="python" + |and dataType="map(string,string)" + |; + |set data=''' + |{"a":"1"} + |{"a":"2"} + |{"a":"3"} + |{"a":"4"} + |'''; + |load jsonStr.`data` as dataTable; + |select dictFun(a) as res from dataTable as output; + """.stripMargin, sq) + + val result = runtime.sparkSession.sql("select * from output").collect() + assert(result.size == 4) + val sample = result.head.getAs[Map[String, String]](0).head + assert(sample._1 == sample._2) + } + } + + "test scala udaf" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + ScriptSQLExec.parse( + """ + |set plusFun=''' + |import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} + |import org.apache.spark.sql.types._ + |import org.apache.spark.sql.Row + |class SumAggregation extends UserDefinedAggregateFunction with Serializable{ + | def inputSchema: StructType = new StructType().add("a", LongType) + | def bufferSchema: StructType = new StructType().add("total", LongType) + | def dataType: DataType = LongType + | def deterministic: Boolean = true + | def initialize(buffer: MutableAggregationBuffer): Unit = { + | buffer.update(0, 0l) + | } + | def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + | val sum = buffer.getLong(0) + | val newitem = input.getLong(0) + | buffer.update(0, sum + newitem) + | } + | def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + | buffer1.update(0, buffer1.getLong(0) + 
buffer2.getLong(0)) + | } + | def evaluate(buffer: Row): Any = { + | buffer.getLong(0) + | } + |} + |'''; + | + |load script.`plusFun` as scriptTable; + |register ScriptUDF.`scriptTable` as plusFun options + | + |className="SumAggregation" + |and udfType="udaf" + |; + | + |set data=''' + |{"a":1} + |{"a":1} + |{"a":1} + |{"a":1} + |'''; + |load jsonStr.`data` as dataTable; + |select a,plusFun(a) as res from dataTable group by a as output; + | + """.stripMargin, sq) + + val result = runtime.sparkSession.sql("select * from output").collect() + assert(result.size == 1) + assert(result.head.getAs[Long]("res") == 4) + } + } + + "test python udaf" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + ScriptSQLExec.parse( + """ + |set plusFun=''' + |from org.apache.spark.sql.expressions import MutableAggregationBuffer, UserDefinedAggregateFunction + |from org.apache.spark.sql.types import DataTypes,StructType + |from org.apache.spark.sql import Row + |import java.lang.Long as l + |import java.lang.Integer as i + | + |class SumAggregation: + | + | def inputSchema(self): + | return StructType().add("a", DataTypes.LongType) + | + | def bufferSchema(self): + | return StructType().add("total", DataTypes.LongType) + | + | def dataType(self): + | return DataTypes.LongType + | + | def deterministic(self): + | return True + | + | def initialize(self,buffer): + | return buffer.update(i(0), l(0)) + | + | def update(self,buffer, input): + | sum = buffer.getLong(i(0)) + | newitem = input.getLong(i(0)) + | buffer.update(i(0), l(sum + newitem)) + | + | def merge(self,buffer1, buffer2): + | buffer1.update(i(0), l(buffer1.getLong(i(0)) + buffer2.getLong(i(0)))) + | + | def evaluate(self,buffer): + | return buffer.getLong(i(0)) + |'''; + | + | + |load script.`plusFun` as scriptTable; + |register ScriptUDF.`scriptTable` as plusFun options + |className="SumAggregation" + |and udfType="udaf" + |and lang="python" + |; + | + |set data=''' + |{"a":1} + |{"a":1} + |{"a":1} + |{"a":1} + |'''; + |load jsonStr.`data` as dataTable; + | + |select a,plusFun(a) as res from dataTable group by a as output; + """.stripMargin, sq) + + val result = runtime.sparkSession.sql("select * from output").collect() + assert(result.size == 1) + assert(result.head.getAs[Long]("res") == 4) + } + } +} diff --git a/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala b/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala index f3e210a24..9b8b273d9 100644 --- a/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala +++ b/streamingpro-spark-2.3.0-adaptor/src/main/java/org/apache/spark/sql/catalyst/expressions/WowScalaUDF.scala @@ -34,5 +34,3 @@ case class WowScalaUDF(function: AnyRef, } } - -case class RuntimeComplieUDF() From b4e6d3fd244c1feea5e7386c631d3ed1f2e25167 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Wed, 28 Nov 2018 12:44:47 +0800 Subject: [PATCH 3/3] [TEST][ScriptUDF] add more unit test for ScriptUDF --- ...F.scala => PythonRuntimeCompileUDAF.scala} | 2 +- .../udf/RuntimeCompileScriptInterface.scala | 12 +- .../streaming/udf/RuntimeCompileUDF.scala | 2 +- ...AF.scala => ScalaRuntimeCompileUDAF.scala} | 2 +- .../streaming/core/BasicMLSQLConfig.scala | 9 - .../dsl/mmlib/algs/ScriptUDFSuite.scala | 172 +++++++++++++++++- 
.../scala/streaming/test/dsl/DslSpec.scala | 135 +------------- 7 files changed, 182 insertions(+), 152 deletions(-) rename streamingpro-mlsql/src/main/java/streaming/udf/{PythonRuntimCompileUDAF.scala => PythonRuntimeCompileUDAF.scala} (98%) rename streamingpro-mlsql/src/main/java/streaming/udf/{ScalaRuntimCompileUDAF.scala => ScalaRuntimeCompileUDAF.scala} (97%) diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDAF.scala similarity index 98% rename from streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala rename to streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDAF.scala index 9b376a105..00f929c2f 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimCompileUDAF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/PythonRuntimeCompileUDAF.scala @@ -11,7 +11,7 @@ import streaming.jython.{JythonUtils, PythonInterp} /** * Created by fchen on 2018/11/15. */ -object PythonRuntimCompileUDAF extends RuntimeCompileUDAF { +object PythonRuntimeCompileUDAF extends RuntimeCompileUDAF { /** * validate the source code */ diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala index 75b16c336..11e91212d 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileScriptInterface.scala @@ -61,7 +61,7 @@ trait RuntimeCompileScriptInterface[FunType] extends Logging { } } -object RuntimeCompileScriptFactory { +object RuntimeCompileScriptFactory extends Logging { private val _udfCache = HashMap[String, RuntimeCompileUDF]() private val _udafCache = HashMap[String, RuntimeCompileUDAF]() @@ -83,10 +83,14 @@ } def registerUDF(lang: String, runtimeCompileUDF: RuntimeCompileUDF): Unit = { + logInfo(s"register $lang runtime compile udf" + + s" engine ${runtimeCompileUDF.getClass.getCanonicalName}!") _udfCache.put(lang, runtimeCompileUDF) } def registerUDAF(lang: String, runtimeCompileUDAF: RuntimeCompileUDAF): Unit = { + logInfo(s"register $lang runtime compile udaf" + + s" engine ${runtimeCompileUDAF.getClass.getCanonicalName}!") _udafCache.put(lang, runtimeCompileUDAF) } @@ -94,12 +98,16 @@ * load all [[RuntimeCompileUDF]] and [[RuntimeCompileUDAF]] */ def loadAll(): Unit = { + def isRuntimeCompile(className: String): Boolean = { + (className.endsWith("RuntimeCompileUDF") && className != "streaming.udf.RuntimeCompileUDF") || + (className.endsWith("RuntimeCompileUDAF") && className != "streaming.udf.RuntimeCompileUDAF") + } _lock.synchronized { ClassPath.from(this.getClass.getClassLoader) .getTopLevelClasses("streaming.udf") .map(_.load()) .filter(n => { - n.getName.endsWith("RuntimeCompileUDF") && n.getName != "streaming.udf.RuntimeCompileUDF" + isRuntimeCompile(n.getName) }).map(getInstance) .foreach { case udf: RuntimeCompileUDF => registerUDF(udf.lang, udf) diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala index dcd56af6b..99ce99206 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/RuntimeCompileUDF.scala @@ -22,7 +22,7 @@ trait RuntimeCompileUDF 
extends RuntimeCompileScriptInterface[AnyRef] { /** * wrap original source code. - * e.g. in [[ScalaRuntimCompileUDAF]], user pass function code, we should wrap code as a class. + * e.g. in [[ScalaRuntimeCompileUDAF]], user pass function code, we should wrap code as a class. * so the runtime compiler will compile source code as runtime instance. */ def wrapCode(scriptCacheKey: ScriptUDFCacheKey): ScriptUDFCacheKey diff --git a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDAF.scala similarity index 97% rename from streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala rename to streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDAF.scala index bd4670c23..d49bb3f1b 100644 --- a/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimCompileUDAF.scala +++ b/streamingpro-mlsql/src/main/java/streaming/udf/ScalaRuntimeCompileUDAF.scala @@ -13,7 +13,7 @@ import streaming.dsl.mmlib.algs.ScriptUDFCacheKey /** * Created by fchen on 2018/11/14. */ -object ScalaRuntimCompileUDAF extends RuntimeCompileUDAF with ScalaCompileUtils { +object ScalaRuntimeCompileUDAF extends RuntimeCompileUDAF with ScalaCompileUtils { /** * validate the source code */ diff --git a/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala b/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala index c5218d160..7c6413a27 100644 --- a/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala +++ b/streamingpro-mlsql/src/test/scala/streaming/core/BasicMLSQLConfig.scala @@ -75,13 +75,4 @@ trait BasicMLSQLConfig { ) - val batchParamsWithoutHive = Array( - "-streaming.master", "local[2]", - "-streaming.name", "unit-test", - "-streaming.rest", "false", - "-streaming.platform", "spark", - "-streaming.spark.service", "false", - "-streaming.udf.clzznames", "streaming.crawler.udf.Functions,streaming.dsl.mmlib.algs.processing.UDFFunctions", - "-streaming.unittest", "true" - ) } diff --git a/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala b/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala index 2bef4075a..d38761aaa 100644 --- a/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala +++ b/streamingpro-mlsql/src/test/scala/streaming/dsl/mmlib/algs/ScriptUDFSuite.scala @@ -13,7 +13,7 @@ import streaming.dsl.ScriptSQLExec class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicMLSQLConfig { "test scala script map return type" should "work fine" in { - withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession var sq = createSSEL ScriptSQLExec.parse( @@ -48,7 +48,7 @@ class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicML "test scala compile error case" should "work fine" in { - withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession var sq = createSSEL val query: (String) => String = (scalaCode) => s""" @@ -102,7 +102,7 @@ class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicML } "test python script map return type" should 
"work fine" in { - withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession var sq = createSSEL ScriptSQLExec.parse( @@ -136,8 +136,37 @@ class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicML } } + "test python script import" should "work fine" in { + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => + implicit val spark = runtime.sparkSession + var sq = createSSEL + ScriptSQLExec.parse( + """ + |set jsonFun=''' + | + |def apply(self,m): + | import json + | d = json.loads(m) + | return d['key'] + |'''; + | + |load script.`jsonFun` as scriptTable; + |register ScriptUDF.`scriptTable` as jsonFun options + |and lang="python" + |and dataType="string" + |; + |select jsonFun("{\"key\": \"value\"}") as res as output; + """.stripMargin, sq) + + val result = runtime.sparkSession.sql("select * from output").collect() + assert(result.size == 1) + val sample = result.head.getString(0) + assert(sample == "value") + } + } + "test scala udaf" should "work fine" in { - withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession var sq = createSSEL ScriptSQLExec.parse( @@ -193,7 +222,7 @@ class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicML } "test python udaf" should "work fine" in { - withBatchContext(setupBatchContext(batchParamsWithoutHive, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParamsWithoutHive)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession var sq = createSSEL ScriptSQLExec.parse( @@ -258,4 +287,137 @@ class ScriptUDFSuite extends BasicSparkOperation with SpecFunctions with BasicML assert(result.head.getAs[Long]("res") == 4) } } + + "test ScalaRuntimeCompileUDAF" should "work fine" in { + + withBatchContext(setupBatchContext(batchParams)) { runtime: SparkRuntime => + //执行sql + implicit val spark = runtime.sparkSession + val sq = createSSEL + + ScriptSQLExec.parse( + """ + |set plusFun=''' + |import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} + |import org.apache.spark.sql.types._ + |import org.apache.spark.sql.Row + |class SumAggregation extends UserDefinedAggregateFunction with Serializable{ + | def inputSchema: StructType = new StructType().add("a", LongType) + | def bufferSchema: StructType = new StructType().add("total", LongType) + | def dataType: DataType = LongType + | def deterministic: Boolean = true + | def initialize(buffer: MutableAggregationBuffer): Unit = { + | buffer.update(0, 0l) + | } + | def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + | val sum = buffer.getLong(0) + | val newitem = input.getLong(0) + | buffer.update(0, sum + newitem) + | } + | def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + | buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) + | } + | def evaluate(buffer: Row): Any = { + | buffer.getLong(0) + | } + |} + |'''; + | + | + |--加载脚本 + |load script.`plusFun` as scriptTable; + |--注册为UDF函数 名称为plusFun + |register ScriptUDF.`scriptTable` as plusFun options + |className="SumAggregation" + |and udfType="udaf" + |; + | + |set 
data=''' + |{"a":1} + |{"a":1} + |{"a":1} + |{"a":1} + |'''; + |load jsonStr.`data` as dataTable; + | + |-- 使用plusFun + |select a,plusFun(a) as res from dataTable group by a as output; + """.stripMargin, sq) + val res = spark.sql("select * from output").collect().head.get(1) + assert(res == 4) + } + } + + "test PythonRuntimeCompileUDAF" should "work fine" in { + + withBatchContext(setupBatchContext(batchParams)) { runtime: SparkRuntime => + //执行sql + implicit val spark = runtime.sparkSession + val sq = createSSEL + + ScriptSQLExec.parse( + """ + |set plusFun=''' + |from org.apache.spark.sql.expressions import MutableAggregationBuffer, UserDefinedAggregateFunction + |from org.apache.spark.sql.types import DataTypes,StructType + |from org.apache.spark.sql import Row + |import java.lang.Long as l + |import java.lang.Integer as i + | + |class SumAggregation: + | + | def inputSchema(self): + | return StructType().add("a", DataTypes.LongType) + | + | def bufferSchema(self): + | return StructType().add("total", DataTypes.LongType) + | + | def dataType(self): + | return DataTypes.LongType + | + | def deterministic(self): + | return True + | + | def initialize(self,buffer): + | return buffer.update(i(0), l(0)) + | + | def update(self,buffer, input): + | sum = buffer.getLong(i(0)) + | newitem = input.getLong(i(0)) + | buffer.update(i(0), l(sum + newitem)) + | + | def merge(self,buffer1, buffer2): + | buffer1.update(i(0), l(buffer1.getLong(i(0)) + buffer2.getLong(i(0)))) + | + | def evaluate(self,buffer): + | return buffer.getLong(i(0)) + |'''; + | + | + |--加载脚本 + |load script.`plusFun` as scriptTable; + |--注册为UDF函数 名称为plusFun + |register ScriptUDF.`scriptTable` as plusFun options + |className="SumAggregation" + |and udfType="udaf" + |and lang="python" + |; + | + |set data=''' + |{"a":1} + |{"a":1} + |{"a":1} + |{"a":1} + |'''; + |load jsonStr.`data` as dataTable; + | + |-- 使用plusFun + |select a,plusFun(a) as res from dataTable group by a as output; + """.stripMargin, sq) + val res = spark.sql("select * from output").collect().head.get(1) + assume(res == 4) + } + } + + } diff --git a/streamingpro-mlsql/src/test/scala/streaming/test/dsl/DslSpec.scala b/streamingpro-mlsql/src/test/scala/streaming/test/dsl/DslSpec.scala index c9b1c3df9..68d91d5e7 100644 --- a/streamingpro-mlsql/src/test/scala/streaming/test/dsl/DslSpec.scala +++ b/streamingpro-mlsql/src/test/scala/streaming/test/dsl/DslSpec.scala @@ -66,7 +66,7 @@ class DslSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLConf "ScalaScriptUDF" should "work fine" in { - withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParams)) { runtime: SparkRuntime => //执行sql implicit val spark = runtime.sparkSession val sq = createSSEL @@ -169,7 +169,7 @@ class DslSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLConf } } "pyton udf" should "work fine" in { - withBatchContext(setupBatchContext(batchParams ++ Array("-spark.sql.codegen.wholeStage", "false"), "classpath:///test/empty.json")) { runtime: SparkRuntime => + withBatchContext(setupBatchContext(batchParams)) { runtime: SparkRuntime => implicit val spark = runtime.sparkSession val sq = createSSEL @@ -428,137 +428,6 @@ class DslSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLConf } } - "ScalaScriptUDAF" should "work fine" in { - - withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime => - //执行sql - implicit val 
spark = runtime.sparkSession - val sq = createSSEL - - ScriptSQLExec.parse( - """ - |set plusFun=''' - |import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} - |import org.apache.spark.sql.types._ - |import org.apache.spark.sql.Row - |class SumAggregation extends UserDefinedAggregateFunction with Serializable{ - | def inputSchema: StructType = new StructType().add("a", LongType) - | def bufferSchema: StructType = new StructType().add("total", LongType) - | def dataType: DataType = LongType - | def deterministic: Boolean = true - | def initialize(buffer: MutableAggregationBuffer): Unit = { - | buffer.update(0, 0l) - | } - | def update(buffer: MutableAggregationBuffer, input: Row): Unit = { - | val sum = buffer.getLong(0) - | val newitem = input.getLong(0) - | buffer.update(0, sum + newitem) - | } - | def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { - | buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) - | } - | def evaluate(buffer: Row): Any = { - | buffer.getLong(0) - | } - |} - |'''; - | - | - |--加载脚本 - |load script.`plusFun` as scriptTable; - |--注册为UDF函数 名称为plusFun - |register ScriptUDF.`scriptTable` as plusFun options - |className="SumAggregation" - |and udfType="udaf" - |; - | - |set data=''' - |{"a":1} - |{"a":1} - |{"a":1} - |{"a":1} - |'''; - |load jsonStr.`data` as dataTable; - | - |-- 使用plusFun - |select a,plusFun(a) as res from dataTable group by a as output; - """.stripMargin, sq) - val res = spark.sql("select * from output").collect().head.get(1) - assume(res == 4) - } - } - - "PythonScriptUDAF" should "work fine" in { - - withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime => - //执行sql - implicit val spark = runtime.sparkSession - val sq = createSSEL - - ScriptSQLExec.parse( - """ - |set plusFun=''' - |from org.apache.spark.sql.expressions import MutableAggregationBuffer, UserDefinedAggregateFunction - |from org.apache.spark.sql.types import DataTypes,StructType - |from org.apache.spark.sql import Row - |import java.lang.Long as l - |import java.lang.Integer as i - | - |class SumAggregation: - | - | def inputSchema(self): - | return StructType().add("a", DataTypes.LongType) - | - | def bufferSchema(self): - | return StructType().add("total", DataTypes.LongType) - | - | def dataType(self): - | return DataTypes.LongType - | - | def deterministic(self): - | return True - | - | def initialize(self,buffer): - | return buffer.update(i(0), l(0)) - | - | def update(self,buffer, input): - | sum = buffer.getLong(i(0)) - | newitem = input.getLong(i(0)) - | buffer.update(i(0), l(sum + newitem)) - | - | def merge(self,buffer1, buffer2): - | buffer1.update(i(0), l(buffer1.getLong(i(0)) + buffer2.getLong(i(0)))) - | - | def evaluate(self,buffer): - | return buffer.getLong(i(0)) - |'''; - | - | - |--加载脚本 - |load script.`plusFun` as scriptTable; - |--注册为UDF函数 名称为plusFun - |register ScriptUDF.`scriptTable` as plusFun options - |className="SumAggregation" - |and udfType="udaf" - |and lang="python" - |; - | - |set data=''' - |{"a":1} - |{"a":1} - |{"a":1} - |{"a":1} - |'''; - |load jsonStr.`data` as dataTable; - | - |-- 使用plusFun - |select a,plusFun(a) as res from dataTable group by a as output; - """.stripMargin, sq) - val res = spark.sql("select * from output").collect().head.get(1) - assume(res == 4) - } - } - "save-partitionby" should "work fine" in { withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>