Skip to content

Commit

Permalink
[FLINK-22994][table-planner] Improve the performance of invoking nesti…
Browse files Browse the repository at this point in the history
…ng udf

This closes apache#16163
  • Loading branch information
zicat authored Jul 9, 2021
1 parent 928b689 commit 31fcb6c
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.codegen

import org.apache.flink.table.types.DataType

/**
* Describes an external generated expression, i.e. a generated expression that also carries
* the term and the code of its external result.
*
* @param dataType type of the resultTerm
* @param internalTerm term to access the internal result of the expression
* @param externalTerm term to access the external result of the expression
* @param nullTerm boolean term that indicates if expression is null
* @param internalCode code necessary to produce internalTerm and nullTerm
* @param externalCode code necessary to produce externalTerm
* @param literalValue None if the expression is not a literal. Otherwise it represents the
* original object of the literal.
*
*/
class ExternalGeneratedExpression(
    dataType: DataType,
    internalTerm: String,
    externalTerm: String,
    nullTerm: String,
    internalCode: String,
    externalCode: String,
    literalValue: Option[Any] = None)
  extends GeneratedExpression(
    internalTerm,
    nullTerm,
    internalCode,
    // the parent only receives the logical type; the full DataType stays available below
    dataType.getLogicalType,
    literalValue) {

  /** Returns the data type this expression was created with. */
  def getDataType: DataType = {
    dataType
  }

  /** Returns the term that gives access to the external result of the expression. */
  def getExternalTerm: String = {
    externalTerm
  }

  /** Returns the code that is necessary to produce the external term. */
  def getExternalCode: String = {
    externalCode
  }

}

object ExternalGeneratedExpression {

  /**
   * Builds an [[ExternalGeneratedExpression]] on top of an already generated expression.
   *
   * The result term, null term, code and literal value are taken over from
   * `generatedExpression`; the external term and external code are supplied by the caller.
   *
   * @param dataType data type of the expression
   * @param externalTerm term to access the external result
   * @param externalCode code necessary to produce `externalTerm`
   * @param generatedExpression expression providing the internal term, null term, code
   *                            and literal value
   * @return the combined external generated expression
   */
  def fromGeneratedExpression(
      dataType: DataType,
      externalTerm: String,
      externalCode: String,
      generatedExpression: GeneratedExpression)
    : ExternalGeneratedExpression = {
    val internal = generatedExpression
    // named arguments keep the many String parameters from being mixed up
    new ExternalGeneratedExpression(
      dataType = dataType,
      internalTerm = internal.resultTerm,
      externalTerm = externalTerm,
      nullTerm = internal.nullTerm,
      internalCode = internal.code,
      externalCode = externalCode,
      literalValue = internal.literalValue)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.flink.table.types.inference.{CallContext, TypeInference, TypeI
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts.supportsAvoidingCast
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{hasRoot, isCompositeType}
import org.apache.flink.table.types.logical.{LogicalType, LogicalTypeRoot, RowType}
import org.apache.flink.table.types.utils.DataTypeUtils.{validateInputDataType, validateOutputDataType}
import org.apache.flink.table.types.utils.DataTypeUtils.{isInternal, validateInputDataType, validateOutputDataType}
import org.apache.flink.util.Preconditions

import java.util.concurrent.CompletableFuture
Expand Down Expand Up @@ -288,16 +288,23 @@ object BridgingFunctionGenUtil {
s"($externalResultTypeTerm) (${typeTerm(externalResultClassBoxed)})"
}
val externalResultTerm = ctx.addReusableLocalVariable(externalResultTypeTerm, "externalResult")
val externalCode =
s"""
|${externalOperands.map(_.code).mkString("\n")}
|$externalResultTerm = $externalResultCasting $functionTerm
| .$SCALAR_EVAL(${externalOperands.map(_.resultTerm).mkString(", ")});
|""".stripMargin

val internalExpr = genToInternalConverterAll(ctx, outputDataType, externalResultTerm)

// function call
internalExpr.copy(code =
val copy = internalExpr.copy(code =
s"""
|${externalOperands.map(_.code).mkString("\n")}
|$externalResultTerm = $externalResultCasting $functionTerm
| .$SCALAR_EVAL(${externalOperands.map(_.resultTerm).mkString(", ")});
|${internalExpr.code}
|""".stripMargin)
|$externalCode
|${internalExpr.code}
|""".stripMargin)

ExternalGeneratedExpression.fromGeneratedExpression(
outputDataType, externalResultTerm, externalCode, copy)
}

private def prepareExternalOperands(
Expand All @@ -308,7 +315,15 @@ object BridgingFunctionGenUtil {
operands
.zip(argumentDataTypes)
.map { case (operand, dataType) =>
operand.copy(resultTerm = genToExternalConverterAll(ctx, dataType, operand))
operand match {
case external: ExternalGeneratedExpression
if !isInternal(dataType) && (external.getDataType == dataType) =>
operand.copy(
resultTerm = external.getExternalTerm,
code = external.getExternalCode)
case _ => operand.copy(
resultTerm = genToExternalConverterAll(ctx, dataType, operand))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.table.planner.runtime.stream.table

import java.util

import org.apache.flink.api.scala._
import org.apache.flink.table.annotation.{DataTypeHint, InputGroup}
import org.apache.flink.table.api._
Expand Down Expand Up @@ -391,6 +393,17 @@ class CalcITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

/**
 * Runs `func(func(f1))` over a small table with `NestingFunc` registered as `func`.
 *
 * The test itself contains no result assertion: the check happens inside
 * `NestingFunc.eval(map)`, which asserts that the map it receives is the very same
 * instance as `NestingFunc.expected`.
 */
@Test
def testOptimizeNestingInvokeScalarFunction(): Unit = {
  val source = env.fromElements(1, 2, 3, 4).toTable(tEnv).as("f1")
  tEnv.createTemporaryView("t1", source)
  tEnv.createTemporaryFunction("func", NestingFunc)

  val nestedCall = tEnv.sqlQuery("select func(func(f1)) from t1")
  val sink = new TestingAppendSink
  nestedCall.toAppendStream[Row].addSink(sink)

  env.execute()
}

@SerialVersionUID(1L)
object ValidSubStringFilter extends ScalarFunction {
@varargs
Expand All @@ -409,6 +422,18 @@ class CalcITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode
}
}

/**
 * Scalar function used to check nested invocations.
 *
 * `eval(Integer)` hands out the shared `expected` map, and `eval(Map)` asserts that the
 * map it is given is that very instance (reference equality) before returning it.
 */
@SerialVersionUID(1L)
object NestingFunc extends ScalarFunction {

  /** The single map instance handed out by `eval(Integer)`. */
  val expected: util.HashMap[Integer, Integer] = new util.HashMap[Integer, Integer]()

  /** Ignores its argument and returns the shared `expected` map. */
  def eval(a: Integer): util.Map[Integer, Integer] = expected

  /** Asserts that `map` is the `expected` instance itself, then returns it unchanged. */
  def eval(map: util.Map[Integer, Integer]): util.Map[Integer, Integer] = {
    // reference comparison on purpose: an equal but different map instance must fail
    val sameInstance = map eq expected
    Assert.assertTrue(sameInstance)
    map
  }
}

@Test
def testMapType(): Unit = {
val ds = env.fromCollection(tupleData3).toTable(tEnv).select(map('_1, '_3))
Expand Down

0 comments on commit 31fcb6c

Please sign in to comment.