Skip to content

Commit

Permalink
[FLINK-33412] Implement type inference for reinterpret_cast function
Browse files Browse the repository at this point in the history
  • Loading branch information
dawidwys committed Nov 3, 2023
1 parent 7295c3b commit f6b662f
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2254,7 +2254,8 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
BuiltInFunctionDefinition.newBuilder()
.name("reinterpretCast")
.kind(SCALAR)
.outputTypeStrategy(TypeStrategies.MISSING)
.inputTypeStrategy(SpecificInputTypeStrategies.REINTERPRET_CAST)
.outputTypeStrategy(TypeStrategies.argument(1))
.build();

public static final BuiltInFunctionDefinition AS =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
* {@link InputTypeStrategy} specific for {@link BuiltInFunctionDefinitions#REINTERPRET_CAST}.
*
* <p>It expects three arguments where the type of first one must be reinterpretable as the type of
* the second one. The second one must be a type literal. The third a BOOLEAN literal if the
* reinterpretation may result in an overflow.
*/
@Internal
public final class ReinterpretCastInputTypeStrategy implements InputTypeStrategy {
@Override
public ArgumentCount getArgumentCount() {
return ConstantArgumentCount.of(3);
}

@Override
public Optional<List<DataType>> inferInputTypes(
CallContext callContext, boolean throwOnFailure) {
final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();

// check for type literal
if (!callContext.isArgumentLiteral(1)
|| !callContext.getArgumentValue(1, DataType.class).isPresent()) {
return callContext.fail(
throwOnFailure, "Expected type literal for the second argument.");
}

if (!argumentDataTypes.get(2).getLogicalType().is(LogicalTypeRoot.BOOLEAN)
|| !callContext.isArgumentLiteral(2)
|| callContext.isArgumentNull(2)) {
return callContext.fail(
throwOnFailure, "Not null boolean literal expected for overflow.");
}

final LogicalType fromType = argumentDataTypes.get(0).getLogicalType();
final LogicalType toType = argumentDataTypes.get(1).getLogicalType();

// A hack to support legacy types. To be removed when we drop the legacy types.
if (fromType instanceof LegacyTypeInformationType) {
return Optional.of(argumentDataTypes);
}
if (!LogicalTypeCasts.supportsReinterpretCast(fromType, toType)) {
return callContext.fail(
throwOnFailure,
"Unsupported reinterpret cast from '%s' to '%s'.",
fromType,
toType);
}

return Optional.of(argumentDataTypes);
}

@Override
public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return Collections.singletonList(
Signature.of(
Signature.Argument.ofGroup("ANY"),
Signature.Argument.ofGroup("TYPE LITERAL"),
Signature.Argument.ofGroup("TRUE | FALSE")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public final class SpecificInputTypeStrategies {
/** See {@link CastInputTypeStrategy}. */
public static final InputTypeStrategy CAST = new CastInputTypeStrategy();

public static final InputTypeStrategy REINTERPRET_CAST = new ReinterpretCastInputTypeStrategy();

/** See {@link MapInputTypeStrategy}. */
public static final InputTypeStrategy MAP = new MapInputTypeStrategy();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,54 @@ public static boolean supportsExplicitCast(LogicalType sourceType, LogicalType t
return supportsCasting(sourceType, targetType, true);
}

/**
* Returns whether the source type can be reinterpreted as the target type.
*
* <p>Reinterpret casts correspond to the SQL reinterpret_cast and represent the logic behind a
* {@code REINTERPRET_CAST(sourceType AS targetType)} operation.
*/
public static boolean supportsReinterpretCast(LogicalType sourceType, LogicalType targetType) {
if (sourceType.getTypeRoot() == targetType.getTypeRoot()) {
return true;
}

switch (sourceType.getTypeRoot()) {
case INTEGER:
switch (targetType.getTypeRoot()) {
case DATE:
case TIME_WITHOUT_TIME_ZONE:
case INTERVAL_YEAR_MONTH:
return true;
default:
return false;
}
case BIGINT:
switch (targetType.getTypeRoot()) {
case TIMESTAMP_WITHOUT_TIME_ZONE:
case INTERVAL_DAY_TIME:
return true;
default:
return false;
}
case DATE:
case TIME_WITHOUT_TIME_ZONE:
case INTERVAL_YEAR_MONTH:
switch (targetType.getTypeRoot()) {
case INTEGER:
case BIGINT:
return true;
default:
return false;
}

case TIMESTAMP_WITHOUT_TIME_ZONE:
case INTERVAL_DAY_TIME:
return targetType.getTypeRoot() == BIGINT;
default:
return false;
}
}

// --------------------------------------------------------------------------------------------

private static boolean supportsCasting(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,34 @@ ANY, explicit(DataTypes.INT())
.expectSignature("f(<ARRAY>, <ARRAY ELEMENT>)")
.expectArgumentTypes(
DataTypes.ARRAY(DataTypes.INT().notNull()).notNull(),
DataTypes.INT()));
DataTypes.INT()),
TestSpec.forStrategy(
"Reinterpret_cast strategy",
SpecificInputTypeStrategies.REINTERPRET_CAST)
.calledWithArgumentTypes(
DataTypes.DATE(), DataTypes.BIGINT(), DataTypes.BOOLEAN().notNull())
.calledWithLiteralAt(1, DataTypes.BIGINT())
.calledWithLiteralAt(2, true)
.expectSignature("f(<ANY>, <TYPE LITERAL>, <TRUE | FALSE>)")
.expectArgumentTypes(
DataTypes.DATE(),
DataTypes.BIGINT(),
DataTypes.BOOLEAN().notNull()),
TestSpec.forStrategy(
"Reinterpret_cast strategy non literal overflow",
SpecificInputTypeStrategies.REINTERPRET_CAST)
.calledWithArgumentTypes(
DataTypes.DATE(), DataTypes.BIGINT(), DataTypes.BOOLEAN().notNull())
.calledWithLiteralAt(1, DataTypes.BIGINT())
.expectErrorMessage("Not null boolean literal expected for overflow."),
TestSpec.forStrategy(
"Reinterpret_cast strategy not supported cast",
SpecificInputTypeStrategies.REINTERPRET_CAST)
.calledWithArgumentTypes(
DataTypes.INT(), DataTypes.BIGINT(), DataTypes.BOOLEAN().notNull())
.calledWithLiteralAt(1, DataTypes.BIGINT())
.calledWithLiteralAt(2, true)
.expectErrorMessage("Unsupported reinterpret cast from 'INT' to 'BIGINT'"));
}

/** Simple pojo that should be converted to a Structured type. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -92,15 +94,11 @@ private TypeInferenceUtil.Result runTypeInference(
callContextMock.argumentDataTypes = actualArgumentTypes;
callContextMock.argumentLiterals =
IntStream.range(0, actualArgumentTypes.size())
.mapToObj(i -> testSpec.literalPos != null && i == testSpec.literalPos)
.mapToObj(i -> testSpec.literals.containsKey(i))
.collect(Collectors.toList());
callContextMock.argumentValues =
IntStream.range(0, actualArgumentTypes.size())
.mapToObj(
i ->
(testSpec.literalPos != null && i == testSpec.literalPos)
? Optional.ofNullable(testSpec.literalValue)
: Optional.empty())
.mapToObj(i -> Optional.ofNullable(testSpec.literals.get(i)))
.collect(Collectors.toList());
callContextMock.argumentNulls =
IntStream.range(0, actualArgumentTypes.size())
Expand Down Expand Up @@ -161,9 +159,7 @@ protected static class TestSpec {

private List<List<DataType>> actualArgumentTypes = new ArrayList<>();

private @Nullable Integer literalPos;

private @Nullable Object literalValue;
private Map<Integer, Object> literals = new HashMap<>();

private @Nullable InputTypeStrategy surroundingStrategy;

Expand Down Expand Up @@ -207,13 +203,12 @@ public TestSpec calledWithArgumentTypes(AbstractDataType<?>... dataTypes) {
}

public TestSpec calledWithLiteralAt(int pos) {
this.literalPos = pos;
this.literals.put(pos, null);
return this;
}

public TestSpec calledWithLiteralAt(int pos, Object value) {
this.literalPos = pos;
this.literalValue = value;
this.literals.put(pos, value);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,6 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
// special case: requires individual handling of child expressions
func match {

case REINTERPRET_CAST =>
assert(children.size == 3)
return Reinterpret(
children.head.accept(this),
fromDataTypeToTypeInfo(children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType),
getValue[Boolean](children(2).accept(this))
)

case WINDOW_START =>
assert(children.size == 1)
val windowReference = translateWindowReference(children.head)
Expand Down

This file was deleted.

Loading

0 comments on commit f6b662f

Please sign in to comment.