Skip to content

Commit

Permalink
[FLINK-5881] [table] ScalarFunction(UDF) should support variable types and variable arguments
Browse files Browse the repository at this point in the history

This closes apache#3389.
  • Loading branch information
zhuoluoy authored and twalthr committed Mar 13, 2017
1 parent 354a13e commit 9b179be
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,22 @@ class ScalarFunctionCallGen(
codeGenerator: CodeGenerator,
operands: Seq[GeneratedExpression])
: GeneratedExpression = {
// determine function signature and result class
val matchingSignature = getSignature(scalarFunction, signature)
// determine function method and result class
val matchingMethod = getEvalMethod(scalarFunction, signature)
.getOrElse(throw new CodeGenException("No matching signature found."))
val matchingSignature = matchingMethod.getParameterTypes
val resultClass = getResultTypeClass(scalarFunction, matchingSignature)

// zip for variable signatures
var paramToOperands = matchingSignature.zip(operands)
if (operands.length > matchingSignature.length) {
operands.drop(matchingSignature.length).foreach(op =>
paramToOperands = paramToOperands :+ (matchingSignature.last.getComponentType, op)
)
}

// convert parameters for function (output boxing)
val parameters = matchingSignature
.zip(operands)
.map { case (paramClass, operandExpr) =>
val parameters = paramToOperands.map { case (paramClass, operandExpr) =>
if (paramClass.isPrimitive) {
operandExpr
} else if (ClassUtils.isPrimitiveWrapper(paramClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,15 @@ object ScalarSqlFunction {
.getParameterTypes(foundSignature)
.map(typeFactory.createTypeFromTypeInfo)

inferredTypes.zipWithIndex.foreach {
case (inferredType, i) =>
operandTypes(i) = inferredType
for (i <- operandTypes.indices) {
if (i < inferredTypes.length - 1) {
operandTypes(i) = inferredTypes(i)
} else if (null != inferredTypes.last.getComponentType) {
// last argument is a collection, the array type
operandTypes(i) = inferredTypes.last.getComponentType
} else {
operandTypes(i) = inferredTypes.last
}
}
}
}
Expand All @@ -137,8 +143,18 @@ object ScalarSqlFunction {
}

/**
  * Returns the range of operand counts accepted by the registered signatures.
  *
  * A signature whose last parameter is an array is treated as a variable-argument
  * signature: its array parameter is optional (so it contributes `length - 1` to the
  * minimum) and it lifts the maximum to 254. Every other signature contributes exactly
  * its own length to both bounds.
  *
  * NOTE(review): an array-typed last parameter is the only varargs indicator available
  * here, so a fixed-arity method taking a plain array as its last parameter is also
  * treated as variable-argument -- confirm that this is intended.
  */
override def getOperandCountRange: SqlOperandCountRange = {
  // upper bound for variable-argument signatures, according to JVM spec 4.3.3
  val maxVarArgOperands = 254

  // (minimum, maximum) operand count accepted by each individual signature
  val perSignature = signatures.map { sig =>
    if (sig.nonEmpty && sig.last.isArray) {
      (sig.length - 1, maxVarArgOperands)
    } else {
      (sig.length, sig.length)
    }
  }

  // the seeds keep the result defined (255, -1) even when there are no signatures at all
  val min = perSignature.foldLeft(255) { case (acc, (lower, _)) => Math.min(acc, lower) }
  val max = perSignature.foldLeft(-1) { case (acc, (_, upper)) => Math.max(acc, upper) }
  SqlOperandCountRanges.between(min, max)
}

override def checkOperandTypes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,7 @@ object UserDefinedFunctionUtils {
function: UserDefinedFunction,
signature: Seq[TypeInformation[_]])
: Option[Array[Class[_]]] = {
// We compare the raw Java classes not the TypeInformation.
// TypeInformation does not matter during runtime (e.g. within a MapFunction).
val actualSignature = typeInfoToClass(signature)
val signatures = getSignatures(function)

signatures
// go over all signatures and find one matching actual signature
.find { curSig =>
// match parameters of signature to actual parameters
actualSignature.length == curSig.length &&
curSig.zipWithIndex.forall { case (clazz, i) =>
parameterTypeEquals(actualSignature(i), clazz)
}
}
getEvalMethod(function, signature).map(_.getParameterTypes)
}

/**
Expand All @@ -106,16 +93,52 @@ object UserDefinedFunctionUtils {
val actualSignature = typeInfoToClass(signature)
val evalMethods = checkAndExtractEvalMethods(function)

evalMethods
// go over all eval methods and find one matching
.find { cur =>
val signatures = cur.getParameterTypes
// match parameters of signature to actual parameters
actualSignature.length == signatures.length &&
signatures.zipWithIndex.forall { case (clazz, i) =>
parameterTypeEquals(actualSignature(i), clazz)
val filtered = evalMethods
// go over all eval methods and filter out matching methods
.filter {
case cur if !cur.isVarArgs =>
val signatures = cur.getParameterTypes
// match parameters of signature to actual parameters
actualSignature.length == signatures.length &&
signatures.zipWithIndex.forall { case (clazz, i) =>
parameterTypeEquals(actualSignature(i), clazz)
}
case cur if cur.isVarArgs =>
val signatures = cur.getParameterTypes
actualSignature.zipWithIndex.forall {
// non-varargs
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(clazz, signatures(i))
// varargs
case (clazz, i) if i >= signatures.length - 1 =>
parameterTypeEquals(clazz, signatures.last.getComponentType)
} || (actualSignature.isEmpty && signatures.length == 1) // empty varargs
}

// if there is a fixed method, compiler will call this method preferentially
val fixedMethodsCount = filtered.count(!_.isVarArgs)
val found = filtered.filter { cur =>
fixedMethodsCount > 0 && !cur.isVarArgs ||
fixedMethodsCount == 0 && cur.isVarArgs
}

// check if there is a Scala varargs annotation
if (found.isEmpty &&
evalMethods.exists { evalMethod =>
val signatures = evalMethod.getParameterTypes
signatures.zipWithIndex.forall {
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(actualSignature(i), clazz)
case (clazz, i) if i == signatures.length - 1 =>
clazz.getName.equals("scala.collection.Seq")
}
}) {
throw new ValidationException("Scala-style variable arguments in 'eval' methods are not " +
"supported. Please add a @scala.annotation.varargs annotation.")
} else if (found.length > 1) {
throw new ValidationException("Found multiple 'eval' methods which match the signature.")
}
found.headOption
}

/**
Expand All @@ -133,7 +156,7 @@ object UserDefinedFunctionUtils {

/**
* Extracts "eval" methods and throws a [[ValidationException]] if no implementation
* can be found.
* can be found, or implementation does not match the requirements.
*/
def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
val methods = function
Expand All @@ -152,9 +175,9 @@ object UserDefinedFunctionUtils {
s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
s"one method named 'eval' which is public, not abstract and " +
s"(in case of table functions) not static.")
} else {
methods
}

methods
}

def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
Expand Down Expand Up @@ -317,6 +340,7 @@ object UserDefinedFunctionUtils {
private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
candidate == null ||
candidate == expected ||
expected == classOf[Object] ||
expected.isPrimitive && Primitives.wrap(expected) == candidate ||
candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,24 @@ public String eval(Integer a, int b, Long c) {
}
}

/**
 * Scalar function with a Java variable-argument 'eval' method: appends the product of
 * all given integers to the given string.
 */
public static class JavaFunc2 extends ScalarFunction {
    public String eval(String s, Integer... a) {
        int product = 1;
        for (int i = 0; i < a.length; i++) {
            // auto-unboxing of the boxed element, exactly as in an enhanced for loop over int
            product *= a[i];
        }
        return s + product;
    }
}

/**
 * Scalar function that overloads 'eval' with a fixed-arity variant and a
 * variable-argument variant sharing the same leading String parameter.
 */
public static class JavaFunc3 extends ScalarFunction {
    /** Variable-argument variant: returns how many int arguments were passed. */
    public int eval(String a, int... b) {
        final int count = b.length;
        return count;
    }

    /** Fixed-arity variant: returns its argument unchanged. */
    public String eval(String c) {
        return c;
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.types.Row
import org.apache.flink.table.api.Types
import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1}
import org.apache.flink.table.api.{Types, ValidationException}
import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils._
import org.apache.flink.table.functions.ScalarFunction
Expand Down Expand Up @@ -180,6 +180,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"+0 00:00:01.000")
}

// Exercises scalar functions whose 'eval' methods take variable arguments, including
// overload resolution between fixed-arity and variable-argument 'eval' methods.
// Each testAllApis call passes an expression, two string forms of the same call and
// the expected result rendered as a string.
@Test
def testVariableArgs(): Unit = {
// Func14 sums all of its integer arguments: 1 + 2 + 3 + 4 = 10
testAllApis(
Func14(1, 2, 3, 4),
"Func14(1, 2, 3, 4)",
"Func14(1, 2, 3, 4)",
"10")

// Test for empty arguments: the sum of no arguments is 0
testAllApis(
Func14(),
"Func14()",
"Func14()",
"0")

// Test for overloading: Func15 has both eval(String) and a varargs eval(String, Int*);
// a single string argument is expected to hit the fixed-arity method, which returns
// its input unchanged
testAllApis(
Func15("Hello"),
"Func15('Hello')",
"Func15('Hello')",
"Hello"
)

// same overload, but fed from field f1 of the test row, which holds "Test"
testAllApis(
Func15('f1),
"Func15(f1)",
"Func15(f1)",
"Test"
)

// with additional integers the varargs method is used: it appends the number of
// variable arguments (3) to the string
testAllApis(
Func15("Hello", 1, 2, 3),
"Func15('Hello', 1, 2, 3)",
"Func15('Hello', 1, 2, 3)",
"Hello3"
)

// Func16 takes a single Seq[String] parameter (not varargs); field f9 of the test row
// holds Seq("Hello", "World"), which is joined with ", "
testAllApis(
Func16('f9),
"Func16(f9)",
"Func16(f9)",
"Hello, World"
)

// Func17 declares Scala-style varargs without the @varargs annotation, so a
// ValidationException is expected; the RuntimeException below is not caught by the
// handler and therefore fails the test if no ValidationException was thrown
try {
testAllApis(
Func17("Hello", "World"),
"Func17('Hello', 'World')",
"Func17('Hello', 'World')",
"Hello, World"
)
throw new RuntimeException("Shouldn't be reached here!")
} catch {
case ex: ValidationException =>
// ok: this is the expected outcome
}

// Java varargs: JavaFunc2 appends the product of the integers, 1 * 3 * 5 * 7 = 105
val JavaFunc2 = new JavaFunc2
testAllApis(
JavaFunc2("Hi", 1, 3, 5, 7),
"JavaFunc2('Hi', 1, 3, 5, 7)",
"JavaFunc2('Hi', 1, 3, 5, 7)",
"Hi105")

// test overloading in Java: JavaFunc3 has eval(String) and eval(String, int...);
// a single string argument is expected to resolve to eval(String), which returns
// its input unchanged
val JavaFunc3 = new JavaFunc3
testAllApis(
JavaFunc3("Hi"),
"JavaFunc3('Hi')",
"JavaFunc3('Hi')",
"Hi")

// same overload, fed from field f1 ("Test") of the test row
testAllApis(
JavaFunc3('f1),
"JavaFunc3(f1)",
"JavaFunc3(f1)",
"Test")
}

@Test
def testJavaBoxedPrimitives(): Unit = {
val JavaFunc0 = new JavaFunc0()
Expand Down Expand Up @@ -238,7 +317,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
// ----------------------------------------------------------------------------------------------

override def testData: Any = {
val testData = new Row(9)
val testData = new Row(10)
testData.setField(0, 42)
testData.setField(1, "Test")
testData.setField(2, null)
Expand All @@ -248,6 +327,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
testData.setField(7, 12)
testData.setField(8, 1000L)
testData.setField(9, Seq("Hello", "World"))
testData
}

Expand All @@ -261,7 +341,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
Types.TIME,
Types.TIMESTAMP,
Types.INTERVAL_MONTHS,
Types.INTERVAL_MILLIS
Types.INTERVAL_MILLIS,
TypeInformation.of(classOf[Seq[String]])
).asInstanceOf[TypeInformation[Any]]
}

Expand All @@ -279,8 +360,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func10" -> Func10,
"Func11" -> Func11,
"Func12" -> Func12,
"Func14" -> Func14,
"Func15" -> Func15,
"Func16" -> Func16,
"Func17" -> Func17,
"JavaFunc0" -> new JavaFunc0,
"JavaFunc1" -> new JavaFunc1,
"JavaFunc2" -> new JavaFunc2,
"JavaFunc3" -> new JavaFunc3,
"RichFunc0" -> new RichFunc0,
"RichFunc1" -> new RichFunc1,
"RichFunc2" -> new RichFunc2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.junit.Assert
import scala.collection.mutable
import scala.io.Source

import scala.annotation.varargs

case class SimplePojo(name: String, age: Int)

object Func0 extends ScalarFunction {
Expand Down Expand Up @@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction {
}
}

/**
  * Scalar function with a single variable-argument 'eval' method that adds up all of
  * its integer arguments (0 when called without arguments).
  */
object Func14 extends ScalarFunction {

  @varargs
  def eval(a: Int*): Int = a.foldLeft(0)(_ + _)
}

/**
  * Scalar function that overloads 'eval' with a variable-argument variant and a
  * fixed-arity variant sharing the same leading String parameter.
  */
object Func15 extends ScalarFunction {

  /** Appends the number of variable arguments to the given string. */
  @varargs
  def eval(a: String, b: Int*): String = s"$a${b.length}"

  /** Returns the given string unchanged. */
  def eval(a: String): String = a
}

/**
  * Scalar function taking one ordinary (non-varargs) Seq[String] parameter and joining
  * its elements with ", ".
  */
object Func16 extends ScalarFunction {

  def eval(a: Seq[String]): String = a.mkString("", ", ", "")
}

/**
  * Scalar function declaring Scala-style variable arguments on purpose WITHOUT the
  * @varargs annotation; it joins its arguments with ", ".
  */
object Func17 extends ScalarFunction {

  // Without @varargs, we will throw an exception
  def eval(a: String*): String = a.mkString("", ", ", "")
}

0 comments on commit 9b179be

Please sign in to comment.