Skip to content

Commit

Permalink
[FLINK-7678] [table] Support composite inputs for user-defined functions
Browse files Browse the repository at this point in the history
This closes apache#4726.
  • Loading branch information
twalthr committed Nov 15, 2017
1 parent 11218a3 commit 54eeccf
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
package org.apache.flink.table.api.stream.sql

import org.apache.flink.api.scala._
import org.apache.flink.table.api.Types
import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaVarsArgTableFunc0
import org.apache.flink.table.utils.TableTestUtil._
import org.apache.flink.table.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc2, _}
import org.apache.flink.types.Row
import org.junit.Test

class CorrelateTest extends TableTestBase {
Expand Down Expand Up @@ -226,6 +228,39 @@ class CorrelateTest extends TableTestBase {
util.verifySql(sqlQuery, expected)
}

/**
 * Registers a table whose third column is a nested row, joins it in SQL with
 * `tableFunc5` applied to that column, and hands the expected node tree to `verifySql`.
 */
@Test
def testRowType(): Unit = {
  val streamUtil = streamTestUtil()
  val nestedType = Types.ROW(Types.INT, Types.INT, Types.INT)
  val inputType = Types.ROW(Types.INT, Types.BOOLEAN, nestedType)
  streamUtil.addTable[Row]("MyTable", 'a, 'b, 'c)(inputType)
  val tableFunction = new TableFunc5
  streamUtil.addFunction("tableFunc5", tableFunction)

  val query = "SELECT c, tf.f2 FROM MyTable, LATERAL TABLE(tableFunc5(c)) AS tf"

  // Row type expected on the correlate node: the three table columns followed by the
  // three fields f0, f1, f2.
  val correlateRowType =
    "RecordType(" +
      "INTEGER a, " +
      "BOOLEAN b, " +
      "COMPOSITE(Row(f0: Integer, f1: Integer, f2: Integer)) c, " +
      "INTEGER f0, " +
      "INTEGER f1, " +
      "INTEGER f2)"

  val correlateNode = unaryNode(
    "DataStreamCorrelate",
    streamTableNode(0),
    term("invocation", "tableFunc5($cor0.c)"),
    term("correlate", "table(tableFunc5($cor0.c))"),
    term("select", "a", "b", "c", "f0", "f1", "f2"),
    term("rowType", correlateRowType),
    term("joinType", "INNER")
  )

  // The outer calc keeps only the columns named in the SELECT clause.
  val expectedPlan = unaryNode(
    "DataStreamCalc",
    correlateNode,
    term("select", "c", "f2")
  )

  streamUtil.verifySql(query, expectedPlan)
}

@Test
def testFilter(): Unit = {
val util = streamTestUtil()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.{ExpressionTestBase, _}
import org.apache.flink.table.functions.ScalarFunction
import org.junit.Test
import java.lang.{Boolean => JBoolean}

class UserDefinedScalarFunctionTest extends ExpressionTestBase {

Expand Down Expand Up @@ -107,6 +108,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Nullable(f0)",
"Nullable(f0)",
"42")

// test row type input
testAllApis(
Func19('f14),
"Func19(f14)",
"Func19(f14)",
"12,true,1,2,3"
)
}

@Test
Expand Down Expand Up @@ -368,7 +377,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
// ----------------------------------------------------------------------------------------------

override def testData: Any = {
val testData = new Row(14)
val testData = new Row(15)
testData.setField(0, 42)
testData.setField(1, "Test")
testData.setField(2, null)
Expand All @@ -383,6 +392,11 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
testData.setField(11, 3.toByte)
testData.setField(12, 3.toShort)
testData.setField(13, 3.toFloat)
testData.setField(14, Row.of(
12.asInstanceOf[Integer],
true.asInstanceOf[JBoolean],
Row.of(1.asInstanceOf[Integer], 2.asInstanceOf[Integer], 3.asInstanceOf[Integer]))
)
testData
}

Expand All @@ -401,7 +415,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
Types.BYTE,
Types.SHORT,
Types.FLOAT
Types.FLOAT,
Types.ROW(Types.INT, Types.BOOLEAN, Types.ROW(Types.INT, Types.INT, Types.INT))
).asInstanceOf[TypeInformation[Any]]
}

Expand All @@ -427,6 +442,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func15" -> Func15,
"Func16" -> Func16,
"Func17" -> Func17,
"Func19" -> Func19,
"JavaFunc0" -> new JavaFunc0,
"JavaFunc1" -> new JavaFunc1,
"JavaFunc2" -> new JavaFunc2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import java.sql.{Date, Time, Timestamp}

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.{ScalarFunction, FunctionContext}
import org.apache.flink.table.functions.{FunctionContext, ScalarFunction}
import org.apache.flink.types.Row
import org.junit.Assert

import scala.annotation.varargs
Expand Down Expand Up @@ -274,3 +275,16 @@ object Func18 extends ScalarFunction {
str.startsWith(prefix)
}
}

/**
 * Scalar function that returns its row argument unchanged.
 *
 * The parameter type and the result type are both declared explicitly as
 * ROW(INT, BOOLEAN, ROW(INT, INT, INT)).
 */
object Func19 extends ScalarFunction {

  /** Builds a fresh instance of the row type used for both the parameter and the result. */
  private def nestedRowType: TypeInformation[_] = {
    val inner = Types.ROW(Types.INT, Types.INT, Types.INT)
    Types.ROW(Types.INT, Types.BOOLEAN, inner)
  }

  /** Identity evaluation: hands back the given row. */
  def eval(row: Row): Row = row

  override def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
    Array(nestedRowType)
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
    nestedRowType
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ import java.sql.{Date, Timestamp}

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.table.api.{TableEnvironment, TableException, Types, ValidationException}
import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2}
import org.apache.flink.table.runtime.utils.TableProgramsClusterTestBase
import org.apache.flink.table.api.{TableEnvironment, Types, ValidationException}
import org.apache.flink.table.expressions.utils.{Func1, Func18, RichFunc2}
import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode
import org.apache.flink.table.runtime.utils._
import org.apache.flink.table.runtime.utils.{TableProgramsClusterTestBase, _}
import org.apache.flink.table.utils._
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.TestBaseUtils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
*/
package org.apache.flink.table.runtime.stream.table

import java.lang.{Boolean => JBoolean}

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.{TableEnvironment, ValidationException}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableEnvironment, Types, ValidationException}
import org.apache.flink.table.expressions.utils.{Func18, RichFunc2}
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData}
import org.apache.flink.table.runtime.utils._
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, _}
import org.apache.flink.table.utils._
import org.apache.flink.types.Row
import org.junit.Assert._
Expand Down Expand Up @@ -231,6 +232,31 @@ class CorrelateITCase extends StreamingMultipleProgramsTestBase {
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

/**
 * Builds a two-element stream of identical nested rows, joins it with `TableFunc5`
 * applied to the nested row column, and compares the strings collected by the sink
 * with the expected output.
 */
@Test
def testRowType(): Unit = {
  // Row.of takes objects, so the primitives are boxed explicitly.
  val nested = Row.of(Int.box(1), Int.box(2), Int.box(3))
  val record = Row.of(Int.box(12), Boolean.box(true), nested)

  val nestedType = Types.ROW(Types.INT, Types.INT, Types.INT)
  val recordType = Types.ROW(Types.INT, Types.BOOLEAN, nestedType)
  val input = env.fromElements(record, record)(recordType).toTable(tEnv).as('a, 'b, 'c)

  val rowFunction = new TableFunc5()
  val joined = input.join(rowFunction('c) as ('f0, 'f1, 'f2))
  val projected = joined.select('c, 'f2)

  projected.addSink(new StreamITCase.StringSink[Row])
  env.execute()

  // One output line per input element: the nested row's fields followed by f2.
  val expected = mutable.MutableList("1,2,3,3", "1,2,3,3")
  assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

private def testData(
env: StreamExecutionEnvironment)
: DataStream[(Int, Long, String)] = {
Expand All @@ -242,5 +268,4 @@ class CorrelateITCase extends StreamingMultipleProgramsTestBase {
data.+=((4, 3L, "nosharp"))
env.fromCollection(data)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ class TableFunc4 extends TableFunction[Row] {
}
}

/**
 * Table function that passes the row it receives straight to `collect`.
 *
 * Both the parameter type and the result type are declared as ROW(INT, INT, INT).
 */
class TableFunc5 extends TableFunction[Row] {

  /** Builds a fresh instance of the three-integer row type used for input and output. */
  private def tripleIntRow: TypeInformation[Row] = {
    Types.ROW(Types.INT, Types.INT, Types.INT)
  }

  /** Emits the incoming row unchanged. */
  def eval(row: Row): Unit = collect(row)

  override def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
    Array(tripleIntRow)
  }

  override def getResultType: TypeInformation[Row] = tripleIntRow
}

/** Table function with a variable-length signature that passes each string argument to `collect`. */
class VarArgsFunc0 extends TableFunction[String] {

  /** Forwards the arguments to `collect` one by one, in the order given. */
  @varargs
  def eval(str: String*): Unit = {
    for (value <- str) {
      collect(value)
    }
  }
}

class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
def eval(user: String) {
if (user.contains("#")) {
Expand Down Expand Up @@ -215,10 +235,3 @@ class RichTableFunc1 extends TableFunction[String] {
separator = None
}
}

/** Table function with a variable-length signature that passes each string argument to `collect`. */
class VarArgsFunc0 extends TableFunction[String] {

  /** Forwards the arguments to `collect` one by one, in the order given. */
  @varargs
  def eval(str: String*): Unit = {
    for (value <- str) {
      collect(value)
    }
  }
}

0 comments on commit 54eeccf

Please sign in to comment.