[FLINK-13225][table-planner-blink] Fix type inference for hive udf
JingsongLi authored and KurtYoung committed Aug 7, 2019
1 parent f695a76 commit e08117f
Showing 6 changed files with 205 additions and 5 deletions.
flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java
@@ -26,6 +26,7 @@
 import org.apache.flink.table.functions.ScalarFunctionDefinition;
 import org.apache.flink.table.functions.TableFunctionDefinition;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.functions.utils.HiveScalarSqlFunction;
 import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils;
 import org.apache.flink.table.types.utils.TypeConversions;
 
@@ -40,6 +41,8 @@
 import java.util.List;
 import java.util.Optional;
 
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.isHiveFunc;
+
 /**
  * Thin adapter between {@link SqlOperatorTable} and {@link FunctionCatalog}.
  */
@@ -92,7 +95,16 @@ private Optional<SqlFunction> convertToSqlFunction(
         if (functionDefinition instanceof AggregateFunctionDefinition) {
             return convertAggregateFunction(name, (AggregateFunctionDefinition) functionDefinition);
         } else if (functionDefinition instanceof ScalarFunctionDefinition) {
-            return convertScalarFunction(name, (ScalarFunctionDefinition) functionDefinition);
+            ScalarFunctionDefinition def = (ScalarFunctionDefinition) functionDefinition;
+            if (isHiveFunc(def.getScalarFunction())) {
+                return Optional.of(new HiveScalarSqlFunction(
+                    name,
+                    name,
+                    def.getScalarFunction(),
+                    typeFactory));
+            } else {
+                return convertScalarFunction(name, def);
+            }
         } else if (functionDefinition instanceof TableFunctionDefinition &&
                 category != null &&
                 category.isTableFunction()) {
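
Editor's note: the dispatch above is duck-typed. convertToSqlFunction treats any ScalarFunction whose class declares a setArgumentTypesAndConstants(Object[], DataType[]) method as a Hive function. A minimal sketch of a function that would take the Hive branch (hypothetical class, for illustration only, not part of this commit):

import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.types.DataType;

public class MyHiveLikeUdf extends ScalarFunction {

    // Declaring this exact signature is what HiveFunctionUtils.isHiveFunc probes for.
    public void setArgumentTypesAndConstants(Object[] constantArguments, DataType[] argTypes) {
        // remember argument types and constants for later type inference
    }

    public String eval(String input) {
        return input;
    }
}

With such a class, isHiveFunc(def.getScalarFunction()) returns true and the catalog adapter builds a HiveScalarSqlFunction; every other ScalarFunction still goes through convertScalarFunction as before.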
flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java (new file)
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.utils;
+
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+import org.apache.calcite.rel.type.RelDataType;
+
+import java.io.Serializable;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+
+import static org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType;
+
+/**
+ * Hack utils for Hive functions.
+ */
+public class HiveFunctionUtils {
+
+    public static boolean isHiveFunc(Object function) {
+        try {
+            getSetArgsMethod(function);
+            return true;
+        } catch (NoSuchMethodException e) {
+            return false;
+        }
+    }
+
+    private static Method getSetArgsMethod(Object function) throws NoSuchMethodException {
+        return function.getClass().getMethod(
+                "setArgumentTypesAndConstants", Object[].class, DataType[].class);
+    }
+
+    static Serializable invokeSetArgs(
+            Serializable function, Object[] constantArguments, LogicalType[] argTypes) {
+        try {
+            // See hive HiveFunction
+            Method method = getSetArgsMethod(function);
+            method.invoke(function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes));
+            return function;
+        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    static RelDataType invokeGetResultType(
+            Object function, Object[] constantArguments, LogicalType[] argTypes,
+            FlinkTypeFactory typeFactory) {
+        try {
+            // See hive HiveFunction
+            Method method = function.getClass()
+                    .getMethod("getHiveResultType", Object[].class, DataType[].class);
+            DataType resultType = (DataType) method.invoke(
+                    function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes));
+            return typeFactory.createFieldTypeFromLogicalType(fromDataTypeToLogicalType(resultType));
+        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
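
Editor's note: the two getMethod lookups above define, in effect, a reflective contract. Written out as an interface it would look like the sketch below, reconstructed purely from the signatures probed here (in the Hive connector this role is played by its HiveFunction interface, which the planner cannot reference directly without a module dependency, hence the reflection):

import org.apache.flink.table.types.DataType;

interface HiveFunctionContract {
    // invoked via invokeSetArgs
    void setArgumentTypesAndConstants(Object[] constantArguments, DataType[] argTypes);
    // invoked via invokeGetResultType
    DataType getHiveResultType(Object[] constantArguments, DataType[] argTypes);
}

The "See hive HiveFunction" comments in the methods above refer to that connector-side interface.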
flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java (new file)
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.utils;
+
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+
+import java.io.IOException;
+import java.util.List;
+
+import scala.Some;
+
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeGetResultType;
+import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeSetArgs;
+import static org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType;
+
+/**
+ * Hive {@link ScalarSqlFunction}.
+ * Overrides makeFunction to clone the function and invoke {@code HiveScalarFunction#setArgumentTypesAndConstants}.
+ * Overrides SqlReturnTypeInference to invoke {@code HiveScalarFunction#getHiveResultType} instead of
+ * {@code HiveScalarFunction#getResultType(Class[])}.
+ *
+ * @deprecated TODO hack code, its logic should be integrated into ScalarSqlFunction
+ */
+@Deprecated
+public class HiveScalarSqlFunction extends ScalarSqlFunction {
+
+    private final ScalarFunction function;
+
+    public HiveScalarSqlFunction(
+            String name, String displayName,
+            ScalarFunction function, FlinkTypeFactory typeFactory) {
+        super(name, displayName, function, typeFactory, new Some<>(createReturnTypeInference(function, typeFactory)));
+        this.function = function;
+    }
+
+    @Override
+    public ScalarFunction makeFunction(Object[] constantArguments, LogicalType[] argTypes) {
+        ScalarFunction clone;
+        try {
+            clone = InstantiationUtil.clone(function);
+        } catch (IOException | ClassNotFoundException e) {
+            throw new RuntimeException(e);
+        }
+        return (ScalarFunction) invokeSetArgs(clone, constantArguments, argTypes);
+    }
+
+    private static SqlReturnTypeInference createReturnTypeInference(
+            ScalarFunction function, FlinkTypeFactory typeFactory) {
+        return opBinding -> {
+            List<RelDataType> sqlTypes = opBinding.collectOperandTypes();
+            LogicalType[] parameters = UserDefinedFunctionUtils.getOperandTypeArray(opBinding);
+
+            Object[] constantArguments = new Object[sqlTypes.size()];
+            for (int i = 0; i < sqlTypes.size(); i++) {
+                if (!opBinding.isOperandNull(i, false) && opBinding.isOperandLiteral(i, false)) {
+                    constantArguments[i] = opBinding.getOperandLiteralValue(
+                            i, getDefaultExternalClassForType(parameters[i]));
+                }
+            }
+            return invokeGetResultType(function, constantArguments, parameters, typeFactory);
+        };
+    }
+}
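
Editor's note: the point of threading constants into return-type inference is that Hive can derive a result type from literal arguments (for example a string length or decimal precision), which Flink's plain getResultType(Class[]) cannot express. A hedged sketch of the flow, assuming fn is the HiveScalarSqlFunction built by the catalog adapter and the SQL call is f(name, 10); the driver class and argument values are hypothetical:

import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.planner.functions.utils.HiveScalarSqlFunction;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

public class MakeFunctionSketch {
    static ScalarFunction specialize(HiveScalarSqlFunction fn) {
        LogicalType[] argTypes = {new VarCharType(100), new IntType()};
        Object[] constants = {null, 10}; // null marks the non-literal first operand
        // makeFunction clones the UDF and pushes (constants, argTypes) into it via
        // setArgumentTypesAndConstants, so codegen sees a fully specialized instance.
        return fn.makeFunction(constants, argTypes);
    }
}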
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.table.api.TableException
+import org.apache.flink.table.dataformat.DataFormatConverters.{DataFormatConverter, getConverterForDataType}
 import org.apache.flink.table.dataformat._
 import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexAggLocalVariable, RexDistinctKeyVariable}
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _}
@@ -30,6 +31,7 @@ import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFu
 import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
 import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
 import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
+import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
 import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval}
@@ -730,7 +732,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
         GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)
 
       case ssf: ScalarSqlFunction =>
-        new ScalarFunctionCallGen(ssf.getScalarFunction).generate(ctx, operands, resultType)
+        new ScalarFunctionCallGen(
+          ssf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
+          .generate(ctx, operands, resultType)
 
       case tsf: TableSqlFunction =>
         new TableFunctionCallGen(tsf.getTableFunction).generate(ctx, operands, resultType)
@@ -757,4 +761,16 @@
         throw new CodeGenException(s"Unsupported call: $explainCall")
     }
   }
+
+  def getOperandLiterals(operands: Seq[GeneratedExpression]): Array[AnyRef] = {
+    operands.map { expr =>
+      expr.literalValue match {
+        case None => null
+        case Some(literal) =>
+          getConverterForDataType(fromLogicalTypeToDataType(expr.resultType))
+            .asInstanceOf[DataFormatConverter[AnyRef, AnyRef]]
+            .toExternal(literal.asInstanceOf[AnyRef])
+      }
+    }.toArray
+  }
 }
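
Editor's note: getOperandLiterals converts each literal operand from the blink planner's internal data format to its external Java representation, which is what the Hive-side type inference expects (for example internal BinaryString becomes java.lang.String). An illustrative, self-contained sketch of that single conversion step, not code from the commit:

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.dataformat.BinaryString;
import org.apache.flink.table.dataformat.DataFormatConverters;
import org.apache.flink.table.types.DataType;

public class LiteralConversionSketch {
    public static void main(String[] args) {
        DataType stringType = DataTypes.STRING();
        @SuppressWarnings("unchecked")
        DataFormatConverters.DataFormatConverter<Object, Object> converter =
                (DataFormatConverters.DataFormatConverter<Object, Object>)
                        DataFormatConverters.getConverterForDataType(stringType);
        // internal representation in, external Java object out
        Object external = converter.toExternal(BinaryString.fromString("*"));
        System.out.println(external.getClass()); // class java.lang.String
    }
}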
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala
@@ -26,6 +26,7 @@ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{
 import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
 import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType
+import org.apache.flink.table.types.logical.LogicalType
 
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.sql._
@@ -47,16 +48,18 @@ class ScalarSqlFunction(
     name: String,
     displayName: String,
     scalarFunction: ScalarFunction,
-    typeFactory: FlinkTypeFactory)
+    typeFactory: FlinkTypeFactory,
+    returnTypeInfer: Option[SqlReturnTypeInference] = None)
   extends SqlFunction(
     new SqlIdentifier(name, SqlParserPos.ZERO),
-    createReturnTypeInference(name, scalarFunction, typeFactory),
+    returnTypeInfer.getOrElse(createReturnTypeInference(name, scalarFunction, typeFactory)),
     createOperandTypeInference(name, scalarFunction, typeFactory),
     createOperandTypeChecker(name, scalarFunction),
     null,
     SqlFunctionCategory.USER_DEFINED_FUNCTION) {
 
-  def getScalarFunction: ScalarFunction = scalarFunction
+  def makeFunction(constants: Array[AnyRef], argTypes: Array[LogicalType]): ScalarFunction =
+    scalarFunction
 
   override def isDeterministic: Boolean = scalarFunction.isDeterministic
 
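
Editor's note: two design points here. First, the getScalarFunction getter becomes a makeFunction hook: the base class ignores the constants and argument types and returns the registered function unchanged, while HiveScalarSqlFunction (above) returns a specialized clone. Second, the new returnTypeInfer parameter is a Scala Option with a default, and Scala default arguments are not visible from Java, so Java callers must pass the Option explicitly. A hypothetical helper showing both call patterns from Java:

import org.apache.calcite.sql.type.SqlReturnTypeInference;
import scala.Option;
import scala.Some;

public class ReturnInferenceOptionSketch {
    // Old behavior: no custom inference, falls back to createReturnTypeInference(...).
    static Option<SqlReturnTypeInference> defaultInference() {
        return Option.empty();
    }
    // What HiveScalarSqlFunction's constructor passes via new Some<>(...).
    static Option<SqlReturnTypeInference> customInference(SqlReturnTypeInference inference) {
        return new Some<>(inference);
    }
}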
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala
@@ -754,6 +754,10 @@ object UserDefinedFunctionUtils {
     }
   }
 
+  def getOperandTypeArray(callBinding: SqlOperatorBinding): Array[LogicalType] = {
+    getOperandType(callBinding).toArray
+  }
+
   def getOperandType(callBinding: SqlOperatorBinding): Seq[LogicalType] = {
     val operandTypes = for (i <- 0 until callBinding.getOperandCount)
       yield callBinding.getOperandType(i)
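
Editor's note: getOperandTypeArray is a thin Java-friendly wrapper. The existing getOperandType returns a Scala Seq, which is clumsy to consume from the Java-side HiveScalarSqlFunction, while the Array form needs no Scala collection bridge. Its only call site in this commit is the inference lambda shown earlier:

// From HiveScalarSqlFunction.createReturnTypeInference:
LogicalType[] parameters = UserDefinedFunctionUtils.getOperandTypeArray(opBinding);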
