Skip to content

Commit

Permalink
[FLINK-6436] [table] Fix code-gen bug when using a scalar UDF in a UD…
Browse files Browse the repository at this point in the history
…TF join condition.

This closes apache#3815.
  • Loading branch information
godfreyhe authored and fhueske committed May 9, 2017
1 parent f26a911 commit e2cb221
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.flink.table.plan.nodes

import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexShuttle}
import org.apache.calcite.sql.SemiJoinType
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
Expand Down Expand Up @@ -143,6 +143,17 @@ trait CommonCorrelate[T] {
|getCollector().collect(${crossResultExpr.resultTerm});
|""".stripMargin
} else {

// Adjust indices of input refs to the schema the generator expects:
// in the joined row the UDTF output fields come after the physical input
// fields, so every RexInputRef into the UDTF output must be shifted right
// by the input schema's physical arity.
val changeInputRefIndexShuttle = new RexShuttle {
override def visitInputRef(inputRef: RexInputRef): RexNode = {
new RexInputRef(inputSchema.physicalArity + inputRef.getIndex, inputRef.getType)
}
}
// Run generateExpression solely for its side effect: it registers the init
// statements (e.g. ScalarFunction instantiation) of the condition on `generator`,
// so they end up in the generated function. The returned expression itself is
// discarded; the actual filter code is produced by `filterGenerator` below.
generator.generateExpression(condition.get.accept(changeInputRefIndexShuttle))

val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo, None, pojoFieldMapping)
filterGenerator.input1Term = filterGenerator.input2Term
val filterCondition = filterGenerator.generateExpression(condition.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.{ScalarFunction, FunctionContext}
import org.junit.Assert

import scala.annotation.varargs
import scala.collection.mutable
import scala.io.Source

import scala.annotation.varargs

case class SimplePojo(name: String, age: Int)

object Func0 extends ScalarFunction {
Expand Down Expand Up @@ -263,3 +262,9 @@ object Func17 extends ScalarFunction {
a.mkString(", ")
}
}

// Scalar UDF that returns true when `str` starts with `prefix`.
// Used by the UDTF join-condition tests to exercise scalar functions
// inside a join predicate.
object Func18 extends ScalarFunction {
  def eval(str: String, prefix: String): Boolean = str.startsWith(prefix)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2}
import org.apache.flink.table.utils._
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.TestBaseUtils
Expand Down Expand Up @@ -143,7 +143,7 @@ class DataSetUserDefinedFunctionITCase(
val pojo = new PojoTableFunc()
val result = in
.join(pojo('c))
.where(('age > 20))
.where('age > 20)
.select('c, 'name, 'age)
.toDataSet[Row]

Expand All @@ -170,6 +170,24 @@ class DataSetUserDefinedFunctionITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testUserDefinedTableFunctionWithScalarFunctionInCondition(): Unit = {
  // Regression test for FLINK-6436: scalar UDFs used inside a UDTF join
  // condition must have their init statements code-generated correctly.
  val execEnv = ExecutionEnvironment.getExecutionEnvironment
  val tEnv = TableEnvironment.getTableEnvironment(execEnv, config)
  val inputTable = testData(execEnv).toTable(tEnv).as('a, 'b, 'c)
  val tableFunc = new TableFunc0

  // Join against the UDTF and filter with a predicate that mixes several
  // scalar UDF calls; only Jack#22 satisfies all three conditions.
  val joined = inputTable
    .join(tableFunc('c))
    .where(Func18('name, "J") && (Func1('a) < 3) && Func1('age) > 20)
    .select('c, 'name, 'age)

  val actual = joined.toDataSet[Row].collect()
  val expected = "Jack#22,Jack,22"
  TestBaseUtils.compareResultAsText(actual.asJava, expected)
}

@Test
def testLongAndTemporalTypes(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
import org.apache.flink.table.expressions.utils.{Func13, RichFunc2}
import org.apache.flink.table.expressions.utils.{Func13, Func18, RichFunc2}
import org.apache.flink.table.utils._
import org.apache.flink.types.Row
import org.junit.Assert._
Expand Down Expand Up @@ -51,7 +51,7 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB
.join(func0('c) as('d, 'e))
.select('c, 'd, 'e)
.join(pojoFunc0('c))
.where(('age > 20))
.where('age > 20)
.select('c, 'name, 'age)
.toDataStream[Row]

Expand Down Expand Up @@ -81,6 +81,24 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

@Test
def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
  // Streaming counterpart of the FLINK-6436 regression: a scalar UDF in the
  // filter applied to UDTF output must be code-generated without errors.
  val inputTable = testData(env).toTable(tEnv).as('a, 'b, 'c)
  val tableFunc = new TableFunc0

  // Keep only the UDTF rows whose name starts with "J".
  val filtered = inputTable
    .join(tableFunc('c) as('d, 'e))
    .where(Func18('d, "J"))
    .select('c, 'd, 'e)

  filtered.toDataStream[Row].addSink(new StreamITCase.StringSink)
  env.execute()

  val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19")
  assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

@Test
def testUserDefinedTableFunctionWithParameter(): Unit = {
val tableFunc1 = new RichTableFunc1
Expand Down

0 comments on commit e2cb221

Please sign in to comment.