
Commit

[FLINK-15853][hive][table-planner-blink] Use the new type inference for hive udf

This closes apache#13144
lirui-apache authored Aug 31, 2020
1 parent 56ee6b4 commit 7646188
Showing 9 changed files with 173 additions and 151 deletions.
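
In brief: the Hive UDF wrappers stop being registered through the legacy ScalarFunctionDefinition / getHiveResultType path and instead implement ScalarFunction#getTypeInference(DataTypeFactory), the new type inference entry point. A condensed sketch of the pattern this commit introduces (abbreviated from the HiveScalarFunction diff below; not a compilable excerpt):

    // Both strategies funnel into the new subclass hook inferReturnType():
    // the input strategy calls it to validate the argument types,
    // the output strategy calls it to produce the result DataType.
    @Override
    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
        return TypeInference.newBuilder()
                .inputTypeStrategy(new HiveUDFInputStrategy())
                .outputTypeStrategy(new HiveUDFOutputStrategy())
                .build();
    }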
HiveFunctionDefinitionFactory.java
@@ -26,7 +26,6 @@
 import org.apache.flink.table.factories.FunctionDefinitionFactory;
 import org.apache.flink.table.functions.AggregateFunctionDefinition;
 import org.apache.flink.table.functions.FunctionDefinition;
-import org.apache.flink.table.functions.ScalarFunctionDefinition;
 import org.apache.flink.table.functions.TableFunctionDefinition;
 import org.apache.flink.table.functions.UserDefinedFunctionHelper;
 import org.apache.flink.table.functions.hive.HiveFunctionWrapper;
@@ -94,17 +93,11 @@ public FunctionDefinition createFunctionDefinitionFromHiveFunction(String name,
         if (UDF.class.isAssignableFrom(clazz)) {
             LOG.info("Transforming Hive function '{}' into a HiveSimpleUDF", name);
 
-            return new ScalarFunctionDefinition(
-                name,
-                new HiveSimpleUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim)
-            );
+            return new HiveSimpleUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim);
         } else if (GenericUDF.class.isAssignableFrom(clazz)) {
             LOG.info("Transforming Hive function '{}' into a HiveGenericUDF", name);
 
-            return new ScalarFunctionDefinition(
-                name,
-                new HiveGenericUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim)
-            );
+            return new HiveGenericUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim);
         } else if (GenericUDTF.class.isAssignableFrom(clazz)) {
             LOG.info("Transforming Hive function '{}' into a HiveGenericUDTF", name);
 
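The net effect of this file: the factory returns the Hive UDF wrapper itself as the FunctionDefinition, with no ScalarFunctionDefinition boxing. A minimal call-site sketch under the new contract (the factory variable and the function/class names are hypothetical):

    // Hypothetical usage: the returned definition IS the UDF instance.
    FunctionDefinition fd = factory.createFunctionDefinitionFromHiveFunction(
            "my_udf", "com.example.MyUdf");
    if (fd instanceof HiveSimpleUDF) {
        // Typing is now driven by ((HiveSimpleUDF) fd).getTypeInference(typeFactory)
        // instead of ScalarFunctionDefinition#getScalarFunction().
    }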
HiveGenericUDF.java
@@ -95,20 +95,15 @@ public Object evalInternal(Object[] args) {
     }
 
     @Override
-    public DataType getHiveResultType(Object[] constantArguments, DataType[] argTypes) {
+    protected DataType inferReturnType() throws UDFArgumentException {
         LOG.info("Getting result type of HiveGenericUDF from {}", hiveFunctionWrapper.getClassName());
+        ObjectInspector[] argumentInspectors = HiveInspectors.toInspectors(hiveShim, constantArguments, argTypes);
 
-        try {
-            ObjectInspector[] argumentInspectors = HiveInspectors.toInspectors(hiveShim, constantArguments, argTypes);
-
-            ObjectInspector resultObjectInspector =
-                createFunction().initializeAndFoldConstants(argumentInspectors);
+        ObjectInspector resultObjectInspector =
+            createFunction().initializeAndFoldConstants(argumentInspectors);
 
-            return HiveTypeUtil.toFlinkType(
+        return HiveTypeUtil.toFlinkType(
             TypeInfoUtils.getTypeInfoFromObjectInspector(resultObjectInspector));
-        } catch (UDFArgumentException e) {
-            throw new FlinkHiveUDFException(e);
-        }
     }
 
     private GenericUDF createFunction() {
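For illustration, inferReturnType() is no longer called directly by the planner; it is reached through the output type strategy built in HiveScalarFunction. A sketch using the HiveUDFCallContext test helper added later in this commit (GenericUDFAbs and the hiveShim value are assumptions for the example):

    HiveGenericUDF udf = new HiveGenericUDF(
            new HiveFunctionWrapper<>("org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs"), hiveShim);
    CallContext ctx = new HiveUDFCallContext(new Object[]{null}, new DataType[]{DataTypes.DOUBLE()});
    // Triggers initializeAndFoldConstants(...) on the Hive UDF and converts the
    // resulting ObjectInspector back into a Flink DataType (DOUBLE here).
    Optional<DataType> out = udf.getTypeInference(null).getOutputTypeStrategy().inferType(ctx);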
HiveScalarFunction.java
@@ -19,22 +19,36 @@
 package org.apache.flink.table.functions.hive;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.functions.ScalarFunction;
 import org.apache.flink.table.functions.hive.util.HiveFunctionUtil;
-import org.apache.flink.table.runtime.types.TypeInfoDataTypeConverter;
+import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategy;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
 /**
  * Abstract class to provide more information for Hive {@link UDF} and {@link GenericUDF} functions.
  */
 @Internal
-public abstract class HiveScalarFunction<UDFType> extends ScalarFunction implements HiveFunction {
+public abstract class HiveScalarFunction<UDFType> extends ScalarFunction {
 
     protected final HiveFunctionWrapper<UDFType> hiveFunctionWrapper;
 
@@ -50,12 +64,6 @@ public abstract class HiveScalarFunction<UDFType> extends ScalarFunction implements HiveFunction {
         this.hiveFunctionWrapper = hiveFunctionWrapper;
     }
 
-    @Override
-    public void setArgumentTypesAndConstants(Object[] constantArguments, DataType[] argTypes) {
-        this.constantArguments = constantArguments;
-        this.argTypes = argTypes;
-    }
-
     @Override
     public boolean isDeterministic() {
         try {
@@ -69,19 +77,21 @@ public boolean isDeterministic() {
         }
     }
 
-    @Override
-    public TypeInformation getResultType(Class[] signature) {
-        return TypeInfoDataTypeConverter.fromDataTypeToTypeInfo(
-            getHiveResultType(this.constantArguments, this.argTypes));
-    }
-
     @Override
     public void open(FunctionContext context) {
         openInternal();
 
         isArgsSingleArray = HiveFunctionUtil.isSingleBoxedArray(argTypes);
     }
 
+    @Override
+    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
+        TypeInference.Builder builder = TypeInference.newBuilder();
+        builder.inputTypeStrategy(new HiveUDFInputStrategy());
+        builder.outputTypeStrategy(new HiveUDFOutputStrategy());
+        return builder.build();
+    }
+
     /**
      * See {@link ScalarFunction#open(FunctionContext)}.
      */
@@ -104,4 +114,66 @@ public Object eval(Object... args) {
      * Evaluation logical, args will be wrapped when is a single array.
      */
     protected abstract Object evalInternal(Object[] args);
+
+    private void setArguments(CallContext callContext) {
+        DataType[] inputTypes = callContext.getArgumentDataTypes().toArray(new DataType[0]);
+        Object[] constantArgs = new Object[inputTypes.length];
+        for (int i = 0; i < constantArgs.length; i++) {
+            if (callContext.isArgumentLiteral(i)) {
+                constantArgs[i] = callContext.getArgumentValue(
+                        i, ClassLogicalTypeConverter.getDefaultExternalClassForType(inputTypes[i].getLogicalType()))
+                        .orElse(null);
+            }
+        }
+        this.constantArguments = constantArgs;
+        this.argTypes = inputTypes;
+    }
+
+    /**
+     * Infer return type of this function call.
+     */
+    protected abstract DataType inferReturnType() throws UDFArgumentException;
+
+    private class HiveUDFOutputStrategy implements TypeStrategy {
+
+        @Override
+        public Optional<DataType> inferType(CallContext callContext) {
+            setArguments(callContext);
+            try {
+                return Optional.of(inferReturnType());
+            } catch (UDFArgumentException e) {
+                throw new FlinkHiveUDFException(e);
+            }
+        }
+    }
+
+    private class HiveUDFInputStrategy implements InputTypeStrategy {
+
+        @Override
+        public ArgumentCount getArgumentCount() {
+            return ConstantArgumentCount.any();
+        }
+
+        @Override
+        public Optional<List<DataType>> inferInputTypes(CallContext callContext, boolean throwOnFailure) {
+            setArguments(callContext);
+            try {
+                inferReturnType();
+            } catch (UDFArgumentException e) {
+                if (throwOnFailure) {
+                    throw new ValidationException(
+                        String.format("Cannot find a suitable Hive function from %s for the input arguments",
+                            hiveFunctionWrapper.getClassName()), e);
+                } else {
+                    return Optional.empty();
+                }
+            }
+            return Optional.of(callContext.getArgumentDataTypes());
+        }
+
+        @Override
+        public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
+            return Collections.singletonList(Signature.of(Signature.Argument.of("*")));
+        }
+    }
 }
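For readers new to this API: the same TypeInference builder works for any ScalarFunction. A minimal, self-contained sketch with explicit strategies (unrelated to Hive; uses stock Flink strategy factories):

    import org.apache.flink.table.api.DataTypes;
    import org.apache.flink.table.catalog.DataTypeFactory;
    import org.apache.flink.table.functions.ScalarFunction;
    import org.apache.flink.table.types.inference.InputTypeStrategies;
    import org.apache.flink.table.types.inference.TypeInference;
    import org.apache.flink.table.types.inference.TypeStrategies;

    // Toy function with hand-written inference: exactly one STRING in, STRING out.
    public class ReverseString extends ScalarFunction {

        public String eval(String s) {
            return new StringBuilder(s).reverse().toString();
        }

        @Override
        public TypeInference getTypeInference(DataTypeFactory typeFactory) {
            return TypeInference.newBuilder()
                    .inputTypeStrategy(InputTypeStrategies.sequence(
                            InputTypeStrategies.explicit(DataTypes.STRING())))
                    .outputTypeStrategy(TypeStrategies.explicit(DataTypes.STRING()))
                    .build();
        }
    }

HiveScalarFunction does the same wiring, except both strategies delegate to the wrapped Hive function through inferReturnType().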
HiveSimpleUDF.java
@@ -120,18 +120,14 @@ public Object evalInternal(Object[] args) {
     }
 
     @Override
-    public DataType getHiveResultType(Object[] constantArguments, DataType[] argTypes) {
-        try {
-            List<TypeInfo> argTypeInfo = new ArrayList<>();
-            for (DataType argType : argTypes) {
-                argTypeInfo.add(HiveTypeUtil.toHiveTypeInfo(argType, false));
-            }
+    protected DataType inferReturnType() throws UDFArgumentException {
+        List<TypeInfo> argTypeInfo = new ArrayList<>();
+        for (DataType argType : argTypes) {
+            argTypeInfo.add(HiveTypeUtil.toHiveTypeInfo(argType, false));
+        }
 
-            Method evalMethod = hiveFunctionWrapper.createFunction().getResolver().getEvalMethod(argTypeInfo);
-            return HiveTypeUtil.toFlinkType(
+        Method evalMethod = hiveFunctionWrapper.createFunction().getResolver().getEvalMethod(argTypeInfo);
+        return HiveTypeUtil.toFlinkType(
             ObjectInspectorFactory.getReflectionObjectInspector(evalMethod.getGenericReturnType(), ObjectInspectorFactory.ObjectInspectorOptions.JAVA));
-        } catch (UDFArgumentException e) {
-            throw new FlinkHiveUDFException(e);
-        }
     }
 }
HiveGenericUDFTest.java
@@ -21,9 +21,11 @@
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.catalog.hive.client.HiveShim;
 import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
+import org.apache.flink.table.functions.hive.HiveSimpleUDFTest.HiveUDFCallContext;
 import org.apache.flink.table.functions.hive.util.TestGenericUDFArray;
 import org.apache.flink.table.functions.hive.util.TestGenericUDFStructSize;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
 import org.apache.flink.types.Row;
 
 import org.apache.hadoop.hive.ql.udf.UDFUnhex;
@@ -391,8 +393,8 @@ public void testStruct() {
     private static HiveGenericUDF init(Class hiveUdfClass, Object[] constantArgs, DataType[] argTypes) {
         HiveGenericUDF udf = new HiveGenericUDF(new HiveFunctionWrapper(hiveUdfClass.getName()), hiveShim);
 
-        udf.setArgumentTypesAndConstants(constantArgs, argTypes);
-        udf.getHiveResultType(constantArgs, argTypes);
+        CallContext callContext = new HiveUDFCallContext(constantArgs, argTypes);
+        udf.getTypeInference(null).getOutputTypeStrategy().inferType(callContext);
 
         udf.open(null);
 
HiveSimpleUDFTest.java
@@ -19,10 +19,13 @@
 package org.apache.flink.table.functions.hive;
 
 import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.catalog.hive.client.HiveShim;
 import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
+import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.functions.hive.util.TestHiveUDFArray;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
 
 import org.apache.hadoop.hive.ql.exec.UDF;
 import org.apache.hadoop.hive.ql.udf.UDFBase64;
@@ -40,6 +43,9 @@
 import java.math.BigDecimal;
 import java.sql.Date;
 import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -248,8 +254,8 @@ protected static HiveSimpleUDF init(Class hiveUdfClass, DataType[] argTypes) {
         HiveSimpleUDF udf = new HiveSimpleUDF(new HiveFunctionWrapper(hiveUdfClass.getName()), hiveShim);
 
         // Hive UDF won't have literal args
-        udf.setArgumentTypesAndConstants(new Object[0], argTypes);
-        udf.getHiveResultType(new Object[0], argTypes);
+        CallContext callContext = new HiveUDFCallContext(new Object[0], argTypes);
+        udf.getTypeInference(null).getOutputTypeStrategy().inferType(callContext);
 
         udf.open(null);
 
@@ -291,4 +297,58 @@ public String evaluate(String content) {
             return content;
         }
     }
+
+    /**
+     * A CallContext implementation for Hive UDF tests.
+     */
+    public static class HiveUDFCallContext implements CallContext {
+
+        private final Object[] constantArgs;
+        private final DataType[] argTypes;
+
+        public HiveUDFCallContext(Object[] constantArgs, DataType[] argTypes) {
+            this.constantArgs = constantArgs;
+            this.argTypes = argTypes;
+        }
+
+        @Override
+        public DataTypeFactory getDataTypeFactory() {
+            return null;
+        }
+
+        @Override
+        public FunctionDefinition getFunctionDefinition() {
+            return null;
+        }
+
+        @Override
+        public boolean isArgumentLiteral(int pos) {
+            return pos >= 0 && pos < constantArgs.length && constantArgs[pos] != null;
+        }
+
+        @Override
+        public boolean isArgumentNull(int pos) {
+            return false;
+        }
+
+        @Override
+        public <T> Optional<T> getArgumentValue(int pos, Class<T> clazz) {
+            return (Optional<T>) Optional.of(constantArgs[pos]);
+        }
+
+        @Override
+        public String getName() {
+            return null;
+        }
+
+        @Override
+        public List<DataType> getArgumentDataTypes() {
+            return Arrays.asList(argTypes);
+        }
+
+        @Override
+        public Optional<DataType> getOutputDataType() {
+            return Optional.empty();
+        }
+    }
 }
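A usage note on this stub (argument values chosen for illustration): a position reports as a literal exactly when its entry in constantArgs is non-null, which is what HiveScalarFunction#setArguments keys on when capturing constant arguments:

    CallContext ctx = new HiveUDFCallContext(
            new Object[]{"UTF-8", null},
            new DataType[]{DataTypes.STRING(), DataTypes.STRING()});
    ctx.isArgumentLiteral(0);   // true  -> captured as a constant argument
    ctx.isArgumentLiteral(1);   // false -> constantArguments[1] stays null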
HiveModuleTest.java
@@ -22,11 +22,11 @@
 import org.apache.flink.table.catalog.hive.HiveTestUtils;
 import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
 import org.apache.flink.table.functions.FunctionDefinition;
-import org.apache.flink.table.functions.ScalarFunction;
-import org.apache.flink.table.functions.ScalarFunctionDefinition;
 import org.apache.flink.table.functions.hive.HiveSimpleUDF;
+import org.apache.flink.table.functions.hive.HiveSimpleUDFTest.HiveUDFCallContext;
 import org.apache.flink.table.module.CoreModule;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.CallContext;
 import org.apache.flink.types.Row;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
@@ -86,16 +86,14 @@ public void testNumberOfBuiltinFunctions() {
     @Test
     public void testHiveBuiltInFunction() {
         FunctionDefinition fd = new HiveModule().getFunctionDefinition("reverse").get();
-
-        ScalarFunction func = ((ScalarFunctionDefinition) fd).getScalarFunction();
-        HiveSimpleUDF udf = (HiveSimpleUDF) func;
+        HiveSimpleUDF udf = (HiveSimpleUDF) fd;
 
         DataType[] inputType = new DataType[] {
             DataTypes.STRING()
         };
 
-        udf.setArgumentTypesAndConstants(new Object[0], inputType);
-        udf.getHiveResultType(new Object[0], inputType);
+        CallContext callContext = new HiveUDFCallContext(new Object[0], inputType);
+        udf.getTypeInference(null).getOutputTypeStrategy().inferType(callContext);
 
         udf.open(null);
 