Merge pull request byzer-org#672 from cfmcgrady/scriptUdf
[ScriptUDF][update] refactoring ScriptUDF module
cfmcgrady authored Nov 28, 2018
2 parents 52e3bde + b4e6d3f commit e09b20a
Showing 16 changed files with 1,021 additions and 615 deletions.
File: streaming/common/SourceCodeCompiler.scala
@@ -1,5 +1,6 @@
package streaming.common

import scala.reflect.runtime.universe._
import javax.tools._

import scala.collection.JavaConversions._
@@ -51,6 +52,18 @@ object SourceCodeCompiler extends Logging {
}
}

def getFunReturnType(fun: String): Type = {
import scala.tools.reflect.ToolBox
var classLoader = Thread.currentThread().getContextClassLoader
if (classLoader == null) {
classLoader = scala.reflect.runtime.universe.getClass.getClassLoader
}
val tb = runtimeMirror(classLoader).mkToolBox()
val tree = tb.parse(fun)
val defDef = tb.typecheck(tree).asInstanceOf[DefDef]
defDef.tpt.tpe
}
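
For context, a minimal usage sketch of the new helper (hypothetical caller object; assumes scala-reflect and scala-compiler on the classpath, which the toolbox import above already requires):

import scala.reflect.runtime.universe._
import streaming.common.SourceCodeCompiler

object GetFunReturnTypeExample {
  def main(args: Array[String]): Unit = {
    // The toolbox parses and typechecks the standalone def, then the helper
    // reads the declared return type off the typed DefDef (tpt.tpe).
    val tpe: Type = SourceCodeCompiler.getFunReturnType(
      "def apply(a: Double, b: Double): Double = a + b")
    println(tpe) // prints: Double
  }
}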

def compileScala(src: String): Class[_] = {
import scala.reflect.runtime.universe
import scala.tools.reflect.ToolBox
File: streaming/jython/PythonInterp.java
@@ -26,13 +26,4 @@ public static PyObject compilePython(String src, String methodName) {
return pi.get(methodName);
}

public static void main(String[] args) {
String source = "class A:\n" +
"\tdef a(self,k1,k2):\n" +
"\t return k1 + k2";
PyObject ob = compilePython(source, "A");
PyObject instance = ob.__call__();
System.out.println(instance.__getattr__("a").__call__(new PyInteger(2), new PyInteger(3)));
}

}
File: streaming/dsl/mmlib/algs/ScriptUDF.scala
@@ -1,14 +1,13 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF, WowScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.udf.UDFManager
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.udf.{PythonSourceUDAF, PythonSourceUDF, ScalaSourceUDAF, ScalaSourceUDF}
import streaming.udf._

/**
* Created by allwefantasy on 27/8/2018.
@@ -23,9 +22,6 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit
emptyDataFrame()(df)
}

/*
register ScalaScriptUDF.`scriptText` as udf1;
*/
override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
val res = params.get(code.name).getOrElse(sparkSession.table(path).head().getString(0))

@@ -60,40 +56,23 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit
if (!dataType.isValid(l)) thr(dataType, l)
set(dataType, l)
}

val udf = () => {
val (func, returnType) = $(lang) match {
case "python" =>
if (params.contains(className.name)) {
PythonSourceUDF(res, params(className.name), params.get(methodName.name), params(dataType.name))
} else {
PythonSourceUDF(res, params.get(methodName.name), params(dataType.name))
}

case _ =>
if (params.contains(className.name)) {
ScalaSourceUDF(res, params(className.name), params.get(methodName.name))
} else {
ScalaSourceUDF(res, params.get(methodName.name))
}
}
(e: Seq[Expression]) => new WowScalaUDF(func, returnType, e).toScalaUDF
}

val udaf = () => {
val func = $(lang) match {
case "python" =>
PythonSourceUDAF(res, $(className))

case _ =>
ScalaSourceUDAF(res, $(className))
}
(e: Seq[Expression]) => ScalaUDAF(e, func)
}
val scriptCacheKey = ScriptUDFCacheKey(
res, "", $(className), $(udfType), $(methodName), $(dataType), $(lang)
)

$(udfType) match {
case "udaf" => udaf()
case _ => udf()
case "udaf" =>
val udaf = RuntimeCompileScriptFactory.getUDAFCompilerBylang($(lang))
if (udaf.isEmpty) {
throw new IllegalArgumentException(s"unsupported lang: ${$(lang)}")
}
(e: Seq[Expression]) => udaf.get.udaf(e, scriptCacheKey)
case _ =>
val udf = RuntimeCompileScriptFactory.getUDFCompilerBylang($(lang))
if (udf.isEmpty) {
throw new IllegalArgumentException(s"unsupported lang: ${$(lang)}")
}
(e: Seq[Expression]) => udf.get.udf(e, scriptCacheKey)
}
}
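
For orientation, a rough sketch of the compiler contract this dispatch relies on, reconstructed from the call sites above and from PythonRuntimeCompileUDAF further down; the actual traits in streaming.udf may differ in names and signatures:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import streaming.dsl.mmlib.algs.ScriptUDFCacheKey

// Hypothetical reconstruction, not the committed definition.
trait RuntimeCompileUDAF {
  def lang: String                       // "scala", "python", ...
  def check(sourceCode: String): Boolean // validate the source code
  def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef
  def generateFunction(scriptCacheKey: ScriptUDFCacheKey): UserDefinedAggregateFunction
  // Builds the catalyst expression; the old code did this with ScalaUDAF(e, func).
  def udaf(e: Seq[Expression], scriptCacheKey: ScriptUDFCacheKey): Expression
  // An execute(...) member is also used by implementations below, presumably
  // compiling through a cache; its definition is not part of this diff.
}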

@@ -396,9 +375,7 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit


final val lang: Param[String] = new Param[String](this, "lang",
s"""Which type of language you want. [scala|python]""", (s: String) => {
s == "scala" || s == "python"
})
s"""Which type of language you want. [scala|python]""")
setDefault(lang, "scala")

final val udfType: Param[String] = new Param[String](this, "udfType",
Expand All @@ -409,10 +386,23 @@ class ScriptUDF(override val uid: String) extends SQLAlg with MllibFunctions wit

final val className: Param[String] = new Param[String](this, "className",
s"""the className of you defined in code snippet.""")
setDefault(className, "")

final val methodName: Param[String] = new Param[String](this, "methodName",
s"""the methodName of you defined in code snippet. If the name is apply, this parameter is optional""")
setDefault(methodName, "apply")

final val dataType: Param[String] = new Param[String](this, "dataType",
s"""when you use python udf, you should define return type.""")
setDefault(dataType, "")
}

case class ScriptUDFCacheKey(
originalCode: String,
wrappedCode: String,
className: String,
udfType: String,
methodName: String,
dataType: String,
lang: String)

File: streaming/parser/SparkTypePaser.scala
@@ -3,6 +3,9 @@ package streaming.parser
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.ScalaReflection
import streaming.common.SourceCodeCompiler

/**
* Created by allwefantasy on 8/9/2018.
*/
@@ -58,27 +61,29 @@ object SparkTypePaser {
}

//array(array(map(string,string)))
def toSparkType(dt: String): DataType = dt match {
case "boolean" => BooleanType
case "byte" => ByteType
case "short" => ShortType
case "integer" => IntegerType
case "date" => DateType
case "long" => LongType
case "float" => FloatType
case "double" => DoubleType
case "decimal" => DoubleType
case "binary" => BinaryType
case "string" => StringType
case c: String if c.startsWith("array") =>
ArrayType(toSparkType(findInputInArrayBracket(c)))
case c: String if c.startsWith("map") =>
//map(map(string,string),string)
val (key, value) = findKeyAndValue(findInputInArrayBracket(c))
MapType(toSparkType(key), toSparkType(value))

case _ => throw new RuntimeException("dt is not found spark type")
def toSparkType(dt: String): DataType = {
dt match {
case "boolean" => BooleanType
case "byte" => ByteType
case "short" => ShortType
case "integer" => IntegerType
case "date" => DateType
case "long" => LongType
case "float" => FloatType
case "double" => DoubleType
case "decimal" => DoubleType
case "binary" => BinaryType
case "string" => StringType
case c: String if c.startsWith("array") =>
ArrayType(toSparkType(findInputInArrayBracket(c)))
case c: String if c.startsWith("map") =>
//map(map(string,string),string)
val (key, value) = findKeyAndValue(findInputInArrayBracket(c))
MapType(toSparkType(key), toSparkType(value))

case _ => throw new RuntimeException(s"$dt is not found spark type")

}
}
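
A small usage sketch of the parser above (hypothetical driver object; the printed form assumes Spark's defaults containsNull = true and valueContainsNull = true):

import org.apache.spark.sql.types.DataType
import streaming.parser.SparkTypePaser

object ToSparkTypeExample {
  def main(args: Array[String]): Unit = {
    // Nested input exercises findInputInArrayBracket and findKeyAndValue recursively.
    val dt: DataType = SparkTypePaser.toSparkType("array(map(string,integer))")
    println(dt) // ArrayType(MapType(StringType,IntegerType,true),true)
  }
}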

def cleanSparkSchema(wowStructType: WowStructType): StructType = {
@@ -167,12 +172,4 @@ object SparkTypePaser {

}
}

def main(args: Array[String]): Unit = {
val res = toSparkSchema("st(field(name,string),field(name1,st(field(name2,array(string)))))", WowStructType(ArrayBuffer()))
println(cleanSparkSchema(res.asInstanceOf[WowStructType]))
// val wow = ArrayBuffer[String]()
// findFieldArray("field(name,string),field(name1,st(field(name2,array(string))))", wow)
// println(wow)
}
}
File: streaming/udf/PythonRuntimeCompileUDAF.scala
@@ -4,21 +4,31 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}
import org.python.core.{Py, PyObject}
import streaming.common.{ScriptCacheKey, SourceCodeCompiler}
import streaming.dsl.ScriptSQLExec
import streaming.jython.JythonUtils
import streaming.log.{Logging, WowLog}

import streaming.dsl.mmlib.algs.ScriptUDFCacheKey
import streaming.jython.{JythonUtils, PythonInterp}

/**
* Created by allwefantasy on 31/8/2018.
*/
object PythonSourceUDAF extends Logging with WowLog {
def apply(src: String, className: String): UserDefinedAggregateFunction = {
generateAggregateFunction(src, className)
* Created by fchen on 2018/11/15.
*/
object PythonRuntimeCompileUDAF extends RuntimeCompileUDAF {
/**
* validate the source code
*/
override def check(sourceCode: String): Boolean = true

/**
* compile the source code.
*
* @param scriptCacheKey the cache key carrying the script source and its metadata
* @return the compiled artifact (for Python, the class compiled from the source)
*/
override def compile(scriptCacheKey: ScriptUDFCacheKey): AnyRef = {
PythonInterp.compilePython(scriptCacheKey.originalCode, scriptCacheKey.className)
}

private def generateAggregateFunction(src: String, className: String): UserDefinedAggregateFunction = {
override def generateFunction(scriptCacheKey: ScriptUDFCacheKey): UserDefinedAggregateFunction = {

new UserDefinedAggregateFunction with Serializable {

val c = ScriptSQLExec.contextGetOrForTest()
@@ -28,17 +38,16 @@ object PythonSourceUDAF extends Logging with WowLog {
fn()
} catch {
case e: Exception =>
logError(format_exception(e))
throw e
}
}

@transient val objectUsingInDriver = wrap(() => {
SourceCodeCompiler.execute(ScriptCacheKey(src, className, "python")).asInstanceOf[PyObject].__call__()
execute(scriptCacheKey).asInstanceOf[PyObject].__call__()
}).asInstanceOf[PyObject]

lazy val objectUsingInExecutor = wrap(() => {
SourceCodeCompiler.execute(ScriptCacheKey(src, className, "python")).asInstanceOf[PyObject].__call__()
execute(scriptCacheKey).asInstanceOf[PyObject].__call__()
}).asInstanceOf[PyObject]


@@ -112,4 +121,6 @@

}
}

override def lang: String = "python"
}
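
To trace the compile path end to end, a sketch that replays the example from the PythonInterp.main removed earlier in this commit (hypothetical driver object; assumes PythonRuntimeCompileUDAF lives in streaming.udf and Jython is on the classpath):

import org.python.core.{PyInteger, PyObject}
import streaming.dsl.mmlib.algs.ScriptUDFCacheKey
import streaming.udf.PythonRuntimeCompileUDAF

object PythonCompileExample {
  def main(args: Array[String]): Unit = {
    val key = ScriptUDFCacheKey(
      originalCode = "class A:\n    def a(self, k1, k2):\n        return k1 + k2",
      wrappedCode = "",
      className = "A",
      udfType = "udaf",
      methodName = "a",
      dataType = "",
      lang = "python")
    // compile() delegates to PythonInterp.compilePython, returning the Python
    // class object; calling it instantiates A, as the removed Java main did.
    val pyClass = PythonRuntimeCompileUDAF.compile(key).asInstanceOf[PyObject]
    val instance = pyClass.__call__()
    println(instance.__getattr__("a").__call__(new PyInteger(2), new PyInteger(3))) // 5
  }
}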