[FLINK-13314][table-planner-blink] Correct resultType of some PlannerExpression when operands contain DecimalTypeInfo or BigDecimalTypeInfo in Blink planner

This also fixes some minor bugs:
- Fix a minor bug in RexNodeConverter when converting BETWEEN and NOT BETWEEN to RexNode.
- Fix a minor bug in PlannerExpressionConverter when converting DataType to TypeInformation.

This closes apache#9152
beyond1920 authored and wuchong committed Jul 23, 2019
1 parent 399de8b commit 0002032
Showing 7 changed files with 1,053 additions and 24 deletions.
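Before the diffs, a quick sketch of the headline change: decimal Plus/Minus result types now follow Calcite's NULLABLE_SUM precision/scale rule instead of collapsing to a generic BigDecimal type. A minimal, self-contained Scala sketch of that rule (illustrative only; the 38-digit cap is an assumption matching FlinkTypeSystem's maximum numeric precision, and the names below are not part of the commit):

object DecimalSumType {
  // Assumed maximum numeric precision, matching FlinkTypeSystem's decimal limit.
  val MaxPrecision = 38

  // Result (precision, scale) of DECIMAL(p1, s1) +/- DECIMAL(p2, s2), mirroring NULLABLE_SUM.
  def plusMinusResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val scale = math.max(s1, s2)
    val precision = math.min(math.max(p1 - s1, p2 - s2) + scale + 1, MaxPrecision)
    (precision, scale)
  }

  def main(args: Array[String]): Unit = {
    // DECIMAL(10, 2) + DECIMAL(10, 4) widens to DECIMAL(13, 4)
    println(plusMinusResult(10, 2, 10, 4))
  }
}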
@@ -368,16 +368,24 @@ private RexNode convertIsNull(List<Expression> children) {

private RexNode convertNotBetween(List<Expression> children) {
List<RexNode> childrenRexNode = convertCallChildren(children);
Preconditions.checkArgument(childrenRexNode.size() == 3);
RexNode expr = childrenRexNode.get(0);
RexNode lowerBound = childrenRexNode.get(1);
RexNode upperBound = childrenRexNode.get(2);
return relBuilder.or(
relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, childrenRexNode),
relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, childrenRexNode));
relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, expr, lowerBound),
relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, expr, upperBound));
}

private RexNode convertBetween(List<Expression> children) {
List<RexNode> childrenRexNode = convertCallChildren(children);
Preconditions.checkArgument(childrenRexNode.size() == 3);
RexNode expr = childrenRexNode.get(0);
RexNode lowerBound = childrenRexNode.get(1);
RexNode upperBound = childrenRexNode.get(2);
return relBuilder.and(
relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, childrenRexNode),
relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, childrenRexNode));
relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, expr, lowerBound),
relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, expr, upperBound));
}

private RexNode convertCeil(List<Expression> children) {
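The BETWEEN/NOT BETWEEN fix above stops passing the whole three-element operand list to binary comparisons and builds explicit (expr, bound) pairs instead. A self-contained sketch of the intended semantics (plain Scala over Int, purely illustrative; the real code builds RexNode comparisons via relBuilder):

object BetweenSemantics {
  // BETWEEN: expr >= lowerBound AND expr <= upperBound
  def between(expr: Int, lowerBound: Int, upperBound: Int): Boolean =
    expr >= lowerBound && expr <= upperBound

  // NOT BETWEEN: expr < lowerBound OR expr > upperBound
  def notBetween(expr: Int, lowerBound: Int, upperBound: Int): Boolean =
    expr < lowerBound || expr > upperBound

  def main(args: Array[String]): Unit = {
    println(between(5, 1, 10))    // true
    println(notBetween(0, 1, 10)) // true
  }
}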
@@ -25,7 +25,7 @@ import org.apache.flink.table.expressions.{E => PlannerE, UUID => PlannerUUID}
import org.apache.flink.table.functions._
import org.apache.flink.table.types.logical.LogicalTypeRoot.{CHAR, DECIMAL, SYMBOL, TIMESTAMP_WITHOUT_TIME_ZONE}
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks._
import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo
import org.apache.flink.table.types.TypeInfoDataTypeConverter.fromDataTypeToTypeInfo

import _root_.scala.collection.JavaConverters._

@@ -53,14 +53,14 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
assert(children.size == 2)
return Cast(
children.head.accept(this),
fromDataTypeToLegacyInfo(
fromDataTypeToTypeInfo(
children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType))

case REINTERPRET_CAST =>
assert(children.size == 3)
Reinterpret(
children.head.accept(this),
fromDataTypeToLegacyInfo(
fromDataTypeToTypeInfo(
children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType),
getValue[Boolean](children(2).accept(this)))

@@ -749,7 +749,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
}
}

fromDataTypeToLegacyInfo(literal.getOutputDataType)
fromDataTypeToTypeInfo(literal.getOutputDataType)
}

private def getSymbol(symbol: TableSymbol): PlannerSymbol = symbol match {
@@ -786,7 +786,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
override def visit(fieldReference: FieldReferenceExpression): PlannerExpression = {
PlannerResolvedFieldReference(
fieldReference.getName,
fromDataTypeToLegacyInfo(fieldReference.getOutputDataType))
fromDataTypeToTypeInfo(fieldReference.getOutputDataType))
}

override def visit(fieldReference: UnresolvedReferenceExpression)
@@ -834,7 +834,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp

private def translateWindowReference(reference: Expression): PlannerExpression = reference match {
case expr : LocalReferenceExpression =>
WindowReference(expr.getName, Some(fromDataTypeToLegacyInfo(expr.getOutputDataType)))
WindowReference(expr.getName, Some(fromDataTypeToTypeInfo(expr.getOutputDataType)))
//just because how the datastream is converted to table
case expr: UnresolvedReferenceExpression =>
UnresolvedFieldReference(expr.getName)
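Replacing fromDataTypeToLegacyInfo with fromDataTypeToTypeInfo throughout this file is what preserves precision and scale when a DataType such as DECIMAL(10, 2) becomes a TypeInformation. A minimal sketch for comparing the two converters (assumes a Flink 1.9 blink-planner classpath; the exact TypeInformation classes returned are not asserted here, only printed side by side):

import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.types.TypeInfoDataTypeConverter.fromDataTypeToTypeInfo
import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo

object DecimalConversionSketch {
  def main(args: Array[String]): Unit = {
    val decimalType = DataTypes.DECIMAL(10, 2)
    // Legacy path: precision and scale are not carried by the generic BigDecimal type info.
    println(fromDataTypeToLegacyInfo(decimalType))
    // Converter used after this commit: expected to keep a precision-aware decimal type info.
    println(fromDataTypeToTypeInfo(decimalType))
  }
}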
@@ -0,0 +1,217 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.expressions

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.table.api.TableException
import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem}
import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType}
import org.apache.flink.table.types.logical.{DecimalType, LogicalType}
import org.apache.flink.table.typeutils.{BigDecimalTypeInfo, DecimalTypeInfo, TypeCoercion}

import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.`type`.SqlTypeUtil

import scala.collection.JavaConverters._

object ReturnTypeInference {

private lazy val typeSystem = new FlinkTypeSystem
private lazy val typeFactory = new FlinkTypeFactory(typeSystem)

/**
* Infer resultType of [[Minus]] expression.
* The decimal type inference is kept consistent with Calcite
* [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of
* [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MINUS]].
*
* @param minus minus Expression
* @return result type
*/
def inferMinus(minus: Minus): TypeInformation[_] = inferPlusOrMinus(minus)

/**
* Infer resultType of [[Plus]] expression.
* The decimal type inference is kept consistent with Calcite
* [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of
* [[org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS]].
*
* @param plus plus Expression
* @return result type
*/
def inferPlus(plus: Plus): TypeInformation[_] = inferPlusOrMinus(plus)

private def inferPlusOrMinus(op: BinaryArithmetic): TypeInformation[_] = {
val decimalTypeInference = (
leftType: RelDataType,
rightType: RelDataType,
wideResultType: LogicalType) => {
if (SqlTypeUtil.isExactNumeric(leftType) &&
SqlTypeUtil.isExactNumeric(rightType) &&
(SqlTypeUtil.isDecimal(leftType) || SqlTypeUtil.isDecimal(rightType))) {
val lp = leftType.getPrecision
val ls = leftType.getScale
val rp = rightType.getPrecision
val rs = rightType.getScale
val scale = Math.max(ls, rs)
assert(scale <= typeSystem.getMaxNumericScale)
var precision = Math.max(lp - ls, rp - rs) + scale + 1
precision = Math.min(precision, typeSystem.getMaxNumericPrecision)
assert(precision > 0)
fromLogicalTypeToTypeInfo(wideResultType) match {
case _: DecimalTypeInfo => DecimalTypeInfo.of(precision, scale)
case _: BigDecimalTypeInfo => BigDecimalTypeInfo.of(precision, scale)
}
} else {
val resultType = typeFactory.leastRestrictive(
List(leftType, rightType).asJava)
fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType))
}
}
inferBinaryArithmetic(op, decimalTypeInference, t => fromLogicalTypeToTypeInfo(t))
}
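// Worked example of the rule above (illustrative):
//   DECIMAL(10, 2) + DECIMAL(10, 4)
//     scale     = max(2, 4)                   = 4
//     precision = max(10 - 2, 10 - 4) + 4 + 1 = 13   (capped at getMaxNumericPrecision)
//   => DecimalTypeInfo.of(13, 4) or BigDecimalTypeInfo.of(13, 4), depending on the wider result type.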

/**
* Infer resultType of [[Mul]] expression.
* The decimal type inference is kept consistent with Calcite
* [[org.apache.calcite.sql.type.ReturnTypes.PRODUCT_NULLABLE]] which is the return type of
* [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY]].
*
* @param mul mul Expression
* @return result type
*/
def inferMul(mul: Mul): TypeInformation[_] = {
val decimalTypeInference = (
leftType: RelDataType,
rightType: RelDataType) => typeFactory.createDecimalProduct(leftType, rightType)
inferDivOrMul(mul, decimalTypeInference)
}
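// Illustrative product rule (the actual derivation is delegated to Calcite via createDecimalProduct,
// which by default yields DECIMAL(p1 + p2, s1 + s2) capped at the maximum numeric precision);
// a null result falls back to the leastRestrictive branch inside inferDivOrMul.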

/**
* Infer resultType of [[Div]] expression.
* The decimal type inference is kept consistent with
* [[org.apache.flink.table.calcite.type.FlinkReturnTypes.FLINK_QUOTIENT_NULLABLE]] which
* is the return type of [[org.apache.flink.table.functions.sql.FlinkSqlOperatorTable.DIVIDE]].
*
* @param div div Expression
* @return result type
*/
def inferDiv(div: Div): TypeInformation[_] = {
val decimalTypeInference = (
leftType: RelDataType,
rightType: RelDataType) => typeFactory.createDecimalQuotient(leftType, rightType)
inferDivOrMul(div, decimalTypeInference)
}

private def inferDivOrMul(
op: BinaryArithmetic,
decimalTypeInfer: (RelDataType, RelDataType) => RelDataType
): TypeInformation[_] = {
val decimalFunc = (
leftType: RelDataType,
rightType: RelDataType,
_: LogicalType) => {
val decimalType = decimalTypeInfer(leftType, rightType)
if (decimalType != null) {
fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(decimalType))
} else {
val resultType = typeFactory.leastRestrictive(
List(leftType, rightType).asJava)
fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType))
}
}
val nonDecimalType = op match {
case _: Div => (_: LogicalType) => BasicTypeInfo.DOUBLE_TYPE_INFO
case _: Mul => (t: LogicalType) => fromLogicalTypeToTypeInfo(t)
}
inferBinaryArithmetic(op, decimalFunc, nonDecimalType)
}
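// Summary of the branches above: for decimal operands the supplied decimalTypeInfer function
// (Calcite's product rule for Mul, Flink's quotient rule for Div) fixes precision and scale;
// otherwise Div always widens to DOUBLE while Mul keeps the wider operand type.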

private def inferBinaryArithmetic(
binaryOp: BinaryArithmetic,
decimalInfer: (RelDataType, RelDataType, LogicalType) => TypeInformation[_],
nonDecimalInfer: LogicalType => TypeInformation[_]
): TypeInformation[_] = {
val leftType = fromTypeInfoToLogicalType(binaryOp.left.resultType)
val rightType = fromTypeInfoToLogicalType(binaryOp.right.resultType)
TypeCoercion.widerTypeOf(leftType, rightType) match {
case Some(t: DecimalType) =>
val leftRelDataType = typeFactory.createFieldTypeFromLogicalType(leftType)
val rightRelDataType = typeFactory.createFieldTypeFromLogicalType(rightType)
decimalInfer(leftRelDataType, rightRelDataType, t)
case Some(t) => nonDecimalInfer(t)
case None => throw new TableException("This will not happen here!")
}
}

/**
* Infer resultType of [[Round]] expression.
* The decimal type inference is kept consistent with
* [[org.apache.flink.table.calcite.type.FlinkReturnTypes]].ROUND_FUNCTION_NULLABLE
*
* @param round round Expression
* @return result type
*/
def inferRound(round: Round): TypeInformation[_] = {
val numType = round.left.resultType
numType match {
case _: DecimalTypeInfo | _: BigDecimalTypeInfo =>
val lenValue = round.right match {
case Literal(v: Int, BasicTypeInfo.INT_TYPE_INFO) => v
case _ => throw new TableException("This will not happen here!")
}
val numLogicalType = fromTypeInfoToLogicalType(numType)
val numRelDataType = typeFactory.createFieldTypeFromLogicalType(numLogicalType)
val p = numRelDataType.getPrecision
val s = numRelDataType.getScale
val dt = FlinkTypeSystem.inferRoundType(p, s, lenValue)
fromLogicalTypeToTypeInfo(dt)
case t => t
}
}

/**
* Infer resultType of [[Floor]] expression.
* The decimal type inference is kept consistent with Calcite
* [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE
*
* @param floor floor Expression
* @return result type
*/
def inferFloor(floor: Floor): TypeInformation[_] = getArg0OrExactNoScale(floor)

/**
* Infer resultType of [[Ceil]] expression.
* The decimal type inference is kept consistent with Calcite
* [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE
*
* @param ceil ceil Expression
* @return result type
*/
def inferCeil(ceil: Ceil): TypeInformation[_] = getArg0OrExactNoScale(ceil)

private def getArg0OrExactNoScale(op: UnaryExpression) = {
val childType = op.child.resultType
childType match {
case t: DecimalTypeInfo => DecimalTypeInfo.of(t.precision(), 0)
case t: BigDecimalTypeInfo => BigDecimalTypeInfo.of(t.precision(), 0)
case _ => childType
}
}
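// Example: FLOOR(DECIMAL(10, 2)) and CEIL(DECIMAL(10, 2)) keep precision 10 but drop the
// fractional digits, giving DECIMAL(10, 0); non-decimal child types pass through unchanged.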

}
@@ -17,10 +17,10 @@
*/
package org.apache.flink.table.expressions

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType}
import org.apache.flink.table.typeutils.{DecimalTypeInfo, TypeCoercion}
import org.apache.flink.table.typeutils.TypeCoercion
import org.apache.flink.table.typeutils.TypeInfoCheckUtils._
import org.apache.flink.table.validate._

@@ -71,6 +71,10 @@ case class Plus(left: PlannerExpression, right: PlannerExpression) extends Binar
s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.")
}
}

override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferPlus(this)
}
}

case class UnaryMinus(child: PlannerExpression) extends UnaryExpression {
@@ -111,24 +115,31 @@ case class Minus(left: PlannerExpression, right: PlannerExpression) extends Bina
s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.")
}
}

override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferMinus(this)
}
}

case class Div(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
override def toString = s"($left / $right)"

private[flink] val sqlOperator = FlinkSqlOperatorTable.DIVIDE

override private[flink] def resultType: TypeInformation[_] =
super.resultType match {
case dt: DecimalTypeInfo => dt
case _ => BasicTypeInfo.DOUBLE_TYPE_INFO
}
override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferDiv(this)
}

}

case class Mul(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
override def toString = s"($left * $right)"

private[flink] val sqlOperator = FlinkSqlOperatorTable.MULTIPLY

override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferMul(this)
}
}

case class Mod(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic {
@@ -32,7 +32,9 @@ case class Abs(child: PlannerExpression) extends UnaryExpression {
}

case class Ceil(child: PlannerExpression) extends UnaryExpression {
override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO
override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferCeil(this)
}

override private[flink] def validateInput(): ValidationResult =
TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Ceil")
@@ -50,7 +52,9 @@ case class Exp(child: PlannerExpression) extends UnaryExpression with InputTypeS


case class Floor(child: PlannerExpression) extends UnaryExpression {
override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO
override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferFloor(this)
}

override private[flink] def validateInput(): ValidationResult =
TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Floor")
@@ -258,7 +262,9 @@ case class Sign(child: PlannerExpression) extends UnaryExpression {

case class Round(left: PlannerExpression, right: PlannerExpression)
extends BinaryExpression {
override private[flink] def resultType: TypeInformation[_] = left.resultType
override private[flink] def resultType: TypeInformation[_] = {
ReturnTypeInference.inferRound(this)
}

override private[flink] def validateInput(): ValidationResult = {
if (!TypeInfoCheckUtils.isInteger(right.resultType)) {
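Floor and Ceil now report decimal results with the scale dropped (ReturnTypeInference.getArg0OrExactNoScale above), and Round derives its type from the literal rounding length. A value-level analogue, purely illustrative, of why a scale of 0 is the right result type for FLOOR:

import java.math.{BigDecimal => JBigDecimal}
import java.math.RoundingMode

object FloorScaleSketch {
  // Flooring an exact numeric keeps the integral digits (precision) but leaves no fractional
  // digits, which is exactly DECIMAL(p, 0) at the type level.
  def floorNoScale(value: JBigDecimal): JBigDecimal = value.setScale(0, RoundingMode.FLOOR)

  def main(args: Array[String]): Unit = {
    println(floorNoScale(new JBigDecimal("12.75"))) // prints 12
  }
}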