From 9b179beaea2b623ad3637e417f6d8014b696d038 Mon Sep 17 00:00:00 2001 From: Zhuoluo Yang Date: Wed, 22 Feb 2017 18:53:34 +0800 Subject: [PATCH] [FLINK-5881] [table] ScalarFunction(UDF) should support variable types and variable arguments This closes #3389. --- .../codegen/calls/ScalarFunctionCallGen.scala | 17 +++- .../functions/utils/ScalarSqlFunction.scala | 26 ++++- .../utils/UserDefinedFunctionUtils.scala | 74 ++++++++++----- .../utils/UserDefinedScalarFunctions.java | 20 ++++ .../UserDefinedScalarFunctionTest.scala | 95 ++++++++++++++++++- .../utils/UserDefinedScalarFunctions.scala | 36 +++++++ 6 files changed, 229 insertions(+), 39 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala index 7ff18eb6332a9..b0b4e09a9159a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala @@ -43,15 +43,22 @@ class ScalarFunctionCallGen( codeGenerator: CodeGenerator, operands: Seq[GeneratedExpression]) : GeneratedExpression = { - // determine function signature and result class - val matchingSignature = getSignature(scalarFunction, signature) + // determine function method and result class + val matchingMethod = getEvalMethod(scalarFunction, signature) .getOrElse(throw new CodeGenException("No matching signature found.")) + val matchingSignature = matchingMethod.getParameterTypes val resultClass = getResultTypeClass(scalarFunction, matchingSignature) + // zip for variable signatures + var paramToOperands = matchingSignature.zip(operands) + if (operands.length > matchingSignature.length) { + operands.drop(matchingSignature.length).foreach(op => + paramToOperands = paramToOperands :+ (matchingSignature.last.getComponentType, op) + ) + } + // convert parameters for function (output boxing) - val parameters = matchingSignature - .zip(operands) - .map { case (paramClass, operandExpr) => + val parameters = paramToOperands.map { case (paramClass, operandExpr) => if (paramClass.isPrimitive) { operandExpr } else if (ClassUtils.isPrimitiveWrapper(paramClass) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala index dc6d41f876bac..e2cd272030cc4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala @@ -113,9 +113,15 @@ object ScalarSqlFunction { .getParameterTypes(foundSignature) .map(typeFactory.createTypeFromTypeInfo) - inferredTypes.zipWithIndex.foreach { - case (inferredType, i) => - operandTypes(i) = inferredType + for (i <- operandTypes.indices) { + if (i < inferredTypes.length - 1) { + operandTypes(i) = inferredTypes(i) + } else if (null != inferredTypes.last.getComponentType) { + // last argument is a collection, the array type + operandTypes(i) = inferredTypes.last.getComponentType + } else { + operandTypes(i) = inferredTypes.last + } } } } @@ -137,8 +143,18 @@ object ScalarSqlFunction { } override def getOperandCountRange: SqlOperandCountRange = { - val signatureLengths = 
signatures.map(_.length) - SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max) + var min = 255 + var max = -1 + signatures.foreach( sig => { + var len = sig.length + if (len > 0 && sig(sig.length - 1).isArray) { + max = 254 // according to JVM spec 4.3.3 + len = sig.length - 1 + } + max = Math.max(len, max) + min = Math.min(len, min) + }) + SqlOperandCountRanges.between(min, max) } override def checkOperandTypes( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index 21d28b5e591da..c1cfe0610ae5e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -78,20 +78,7 @@ object UserDefinedFunctionUtils { function: UserDefinedFunction, signature: Seq[TypeInformation[_]]) : Option[Array[Class[_]]] = { - // We compare the raw Java classes not the TypeInformation. - // TypeInformation does not matter during runtime (e.g. within a MapFunction). - val actualSignature = typeInfoToClass(signature) - val signatures = getSignatures(function) - - signatures - // go over all signatures and find one matching actual signature - .find { curSig => - // match parameters of signature to actual parameters - actualSignature.length == curSig.length && - curSig.zipWithIndex.forall { case (clazz, i) => - parameterTypeEquals(actualSignature(i), clazz) - } - } + getEvalMethod(function, signature).map(_.getParameterTypes) } /** @@ -106,16 +93,52 @@ object UserDefinedFunctionUtils { val actualSignature = typeInfoToClass(signature) val evalMethods = checkAndExtractEvalMethods(function) - evalMethods - // go over all eval methods and find one matching - .find { cur => - val signatures = cur.getParameterTypes - // match parameters of signature to actual parameters - actualSignature.length == signatures.length && - signatures.zipWithIndex.forall { case (clazz, i) => - parameterTypeEquals(actualSignature(i), clazz) + val filtered = evalMethods + // go over all eval methods and filter out matching methods + .filter { + case cur if !cur.isVarArgs => + val signatures = cur.getParameterTypes + // match parameters of signature to actual parameters + actualSignature.length == signatures.length && + signatures.zipWithIndex.forall { case (clazz, i) => + parameterTypeEquals(actualSignature(i), clazz) + } + case cur if cur.isVarArgs => + val signatures = cur.getParameterTypes + actualSignature.zipWithIndex.forall { + // non-varargs + case (clazz, i) if i < signatures.length - 1 => + parameterTypeEquals(clazz, signatures(i)) + // varargs + case (clazz, i) if i >= signatures.length - 1 => + parameterTypeEquals(clazz, signatures.last.getComponentType) + } || (actualSignature.isEmpty && signatures.length == 1) // empty varargs + } + + // if there is a fixed method, compiler will call this method preferentially + val fixedMethodsCount = filtered.count(!_.isVarArgs) + val found = filtered.filter { cur => + fixedMethodsCount > 0 && !cur.isVarArgs || + fixedMethodsCount == 0 && cur.isVarArgs + } + + // check if there is a Scala varargs annotation + if (found.isEmpty && + evalMethods.exists { evalMethod => + val signatures = evalMethod.getParameterTypes + signatures.zipWithIndex.forall { + case (clazz, i) if i < signatures.length - 1 => + 
parameterTypeEquals(actualSignature(i), clazz) + case (clazz, i) if i == signatures.length - 1 => + clazz.getName.equals("scala.collection.Seq") } + }) { + throw new ValidationException("Scala-style variable arguments in 'eval' methods are not " + + "supported. Please add a @scala.annotation.varargs annotation.") + } else if (found.length > 1) { + throw new ValidationException("Found multiple 'eval' methods which match the signature.") } + found.headOption } /** @@ -133,7 +156,7 @@ object UserDefinedFunctionUtils { /** * Extracts "eval" methods and throws a [[ValidationException]] if no implementation - * can be found. + * can be found, or implementation does not match the requirements. */ def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = { val methods = function @@ -152,9 +175,9 @@ object UserDefinedFunctionUtils { s"Function class '${function.getClass.getCanonicalName}' does not implement at least " + s"one method named 'eval' which is public, not abstract and " + s"(in case of table functions) not static.") - } else { - methods } + + methods } def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = { @@ -317,6 +340,7 @@ object UserDefinedFunctionUtils { private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean = candidate == null || candidate == expected || + expected == classOf[Object] || expected.isPrimitive && Primitives.wrap(expected) == candidate || candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) || candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) || diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java index e817f06b4e1be..56f866d2b1135 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java @@ -33,4 +33,24 @@ public String eval(Integer a, int b, Long c) { } } + public static class JavaFunc2 extends ScalarFunction { + public String eval(String s, Integer... a) { + int m = 1; + for (int n : a) { + m *= n; + } + return s + m; + } + } + + public static class JavaFunc3 extends ScalarFunction { + public int eval(String a, int... 
b) { + return b.length; + } + + public String eval(String c) { + return c; + } + } + } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index a6c1760c9b896..51583c3f21a40 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.types.Row -import org.apache.flink.table.api.Types -import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1} +import org.apache.flink.table.api.{Types, ValidationException} +import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3} import org.apache.flink.table.api.scala._ import org.apache.flink.table.expressions.utils._ import org.apache.flink.table.functions.ScalarFunction @@ -180,6 +180,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "+0 00:00:01.000") } + @Test + def testVariableArgs(): Unit = { + testAllApis( + Func14(1, 2, 3, 4), + "Func14(1, 2, 3, 4)", + "Func14(1, 2, 3, 4)", + "10") + + // Test for empty arguments + testAllApis( + Func14(), + "Func14()", + "Func14()", + "0") + + // Test for override + testAllApis( + Func15("Hello"), + "Func15('Hello')", + "Func15('Hello')", + "Hello" + ) + + testAllApis( + Func15('f1), + "Func15(f1)", + "Func15(f1)", + "Test" + ) + + testAllApis( + Func15("Hello", 1, 2, 3), + "Func15('Hello', 1, 2, 3)", + "Func15('Hello', 1, 2, 3)", + "Hello3" + ) + + testAllApis( + Func16('f9), + "Func16(f9)", + "Func16(f9)", + "Hello, World" + ) + + try { + testAllApis( + Func17("Hello", "World"), + "Func17('Hello', 'World')", + "Func17('Hello', 'World')", + "Hello, World" + ) + throw new RuntimeException("Shouldn't be reached here!") + } catch { + case ex: ValidationException => + // ok + } + + val JavaFunc2 = new JavaFunc2 + testAllApis( + JavaFunc2("Hi", 1, 3, 5, 7), + "JavaFunc2('Hi', 1, 3, 5, 7)", + "JavaFunc2('Hi', 1, 3, 5, 7)", + "Hi105") + + // test overloading + val JavaFunc3 = new JavaFunc3 + testAllApis( + JavaFunc3("Hi"), + "JavaFunc3('Hi')", + "JavaFunc3('Hi')", + "Hi") + + testAllApis( + JavaFunc3('f1), + "JavaFunc3(f1)", + "JavaFunc3(f1)", + "Test") + } + @Test def testJavaBoxedPrimitives(): Unit = { val JavaFunc0 = new JavaFunc0() @@ -238,7 +317,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { // ---------------------------------------------------------------------------------------------- override def testData: Any = { - val testData = new Row(9) + val testData = new Row(10) testData.setField(0, 42) testData.setField(1, "Test") testData.setField(2, null) @@ -248,6 +327,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10")) testData.setField(7, 12) testData.setField(8, 1000L) + testData.setField(9, Seq("Hello", "World")) testData } @@ -261,7 +341,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { Types.TIME, Types.TIMESTAMP, Types.INTERVAL_MONTHS, - Types.INTERVAL_MILLIS + Types.INTERVAL_MILLIS, + 
TypeInformation.of(classOf[Seq[String]]) ).asInstanceOf[TypeInformation[Any]] } @@ -279,8 +360,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "Func10" -> Func10, "Func11" -> Func11, "Func12" -> Func12, + "Func14" -> Func14, + "Func15" -> Func15, + "Func16" -> Func16, + "Func17" -> Func17, "JavaFunc0" -> new JavaFunc0, "JavaFunc1" -> new JavaFunc1, + "JavaFunc2" -> new JavaFunc2, + "JavaFunc3" -> new JavaFunc3, "RichFunc0" -> new RichFunc0, "RichFunc1" -> new RichFunc1, "RichFunc2" -> new RichFunc2 diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index 1258137df7ed4..e858187912164 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -28,6 +28,8 @@ import org.junit.Assert import scala.collection.mutable import scala.io.Source +import scala.annotation.varargs + case class SimplePojo(name: String, age: Int) object Func0 extends ScalarFunction { @@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction { } } +object Func14 extends ScalarFunction { + + @varargs + def eval(a: Int*): Int = { + a.sum + } +} + +object Func15 extends ScalarFunction { + + @varargs + def eval(a: String, b: Int*): String = { + a + b.length + } + + def eval(a: String): String = { + a + } +} + +object Func16 extends ScalarFunction { + + def eval(a: Seq[String]): String = { + a.mkString(", ") + } +} + +object Func17 extends ScalarFunction { + + // Without @varargs, we will throw an exception + def eval(a: String*): String = { + a.mkString(", ") + } +}
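
Usage sketch (not part of the patch): the snippet below mirrors Func15 from the new tests and illustrates how a variable-argument scalar UDF is declared once this change is in place. The names VarArgConcat and VarArgConcatDemo are illustrative only; registration via tableEnv.registerFunction(...) is mentioned in a comment rather than shown, since this sketch has no TableEnvironment set up.

import org.apache.flink.table.functions.ScalarFunction

import scala.annotation.varargs

// A variable-argument scalar UDF. The @scala.annotation.varargs annotation is
// required so that the Scala compiler also emits a Java-style
// eval(String, String...) method, which is what the matching logic added to
// UserDefinedFunctionUtils resolves against. A Scala-style `String*` parameter
// without the annotation is rejected with a ValidationException (see Func17 in
// the new tests).
object VarArgConcat extends ScalarFunction {
  @varargs
  def eval(separator: String, parts: String*): String = parts.mkString(separator)
}

object VarArgConcatDemo {
  def main(args: Array[String]): Unit = {
    // Direct calls just to illustrate the semantics; in a job the function
    // would be registered (e.g. tableEnv.registerFunction("varArgConcat",
    // VarArgConcat)) and then invoked from the Table API or SQL.
    println(VarArgConcat.eval("-", "a", "b", "c")) // a-b-c
    println(VarArgConcat.eval("-"))                // empty varargs -> ""
  }
}

As in the patch's tests, overloading a fixed-arity eval with a varargs eval is allowed; the fixed-arity method is preferred when both match a call.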