Skip to content

Commit

Permalink
[FLINK-25227][table-planner] Fix LEAST/GREATEST to return primitives
Browse files Browse the repository at this point in the history
Previously, `LEAST` and `GREATEST` functions would return primitive
types in the generated code implementing their logic, producing issues
for operators applied on top of them, and most importantly comparison
operators, i.e.:
```
f0 INT, f1 INT
SELECT GREATEST(f0, f1) = GREATEST(f0, f1)
```
would return `FALSE`, since the generated code would return `Integer`
instead of `int`, as the result of `GREATEST`, and the `=` operator
on `Integer` objects would return false, even if the actual integer
value of them was the same.
  • Loading branch information
matriv authored and MartijnVisser committed Mar 30, 2022
1 parent 317ba33 commit 89cdc6e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1318,10 +1318,11 @@ object ScalarOperatorGens {
elements: Seq[GeneratedExpression],
greatest: Boolean = true)
: GeneratedExpression = {
val Seq(result, cur, nullTerm) = newNames("result", "cur", "nullTerm")
val Seq(result, tmpResult, cur, nullTerm) = newNames("result", "tmpResult", "cur", "nullTerm")
val widerType = toScala(findCommonType(elements.map(element => element.resultType)))
.orElse(throw new CodeGenException(s"Unable to find common type for $elements."))
val resultTypeTerm = boxedTypeTermForType(widerType.get)
val boxedResultTypeTerm = boxedTypeTermForType(widerType.get)
val primitiveResultTypeTerm = primitiveTypeTermForType(widerType.get)

def castIfNumeric(t: GeneratedExpression): String = {
if (isNumeric(widerType.get)) {
Expand All @@ -1335,13 +1336,13 @@ object ScalarOperatorGens {
s"""
| ${element.code}
| if (!$nullTerm) {
| $resultTypeTerm $cur = ${castIfNumeric(element)};
| $boxedResultTypeTerm $cur = ${castIfNumeric(element)};
| if (${element.nullTerm}) {
| $nullTerm = true;
| } else {
| int compareResult = $result.compareTo($cur);
| int compareResult = $tmpResult.compareTo($cur);
| if (($greatest && compareResult < 0) || (compareResult > 0 && !$greatest)) {
| $result = $cur;
| $tmpResult = $cur;
| }
| }
| }
Expand All @@ -1350,11 +1351,12 @@ object ScalarOperatorGens {

val code =
s"""
| $resultTypeTerm $result = ${castIfNumeric(elements.head)};
| $boxedResultTypeTerm $tmpResult = ${castIfNumeric(elements.head)};
| $primitiveResultTypeTerm $result = ${primitiveDefaultValue(widerType.get)};
| boolean $nullTerm = false;
| $elementsCode
| if ($nullTerm) {
| $result = null;
| if (!$nullTerm) {
| $result = $tmpResult;
| }
""".stripMargin
GeneratedExpression(result, nullTerm, code, resultType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,47 @@ Stream<TestSetSpec> getTestSetSpecs() {
"GREATEST(f6, f7)",
LocalDateTime.parse("1970-01-01T00:00:03.001"),
DataTypes.TIMESTAMP(3).notNull())
// assert that primitive types are returned and used in the equality
// operator applied on top of the GREATEST functions
.testResult(
call(
"EQUALS",
call("GREATEST", $("f1"), $("f2")),
call("GREATEST", $("f1"), $("f2"))),
"GREATEST(f1, f2) = GREATEST(f1, f2)",
true,
DataTypes.BOOLEAN().notNull())
.testResult(
call(
"EQUALS",
call("GREATEST", $("f0"), $("f1")),
call("GREATEST", $("f0"), $("f1"))),
"GREATEST(f0, f1) = GREATEST(f0, f1)",
null,
DataTypes.BOOLEAN())
.testSqlValidationError(
"GREATEST(f5, f6)",
"SQL validation failed. Invalid function call:\n"
+ "GREATEST(STRING NOT NULL, TIMESTAMP(3) NOT NULL)"),
TestSetSpec.forFunction(BuiltInFunctionDefinitions.LEAST)
.onFieldsWithData(null, 1, 2, 3.14, "hello", "world")
.onFieldsWithData(
null,
1,
2,
3.14,
"hello",
"world",
LocalDateTime.parse("1970-01-01T00:00:03.001"),
LocalDateTime.parse("1970-01-01T00:00:02.001"))
.andDataTypes(
DataTypes.INT().nullable(),
DataTypes.INT().notNull(),
DataTypes.INT().notNull(),
DataTypes.DECIMAL(3, 2).notNull(),
DataTypes.STRING().notNull(),
DataTypes.STRING().notNull())
DataTypes.STRING().notNull(),
DataTypes.TIMESTAMP(3).notNull(),
DataTypes.TIMESTAMP(3).notNull())
.testSqlValidationError(
"LEAST(f1, f4)",
"SQL validation failed. Invalid function call:\n"
Expand All @@ -110,6 +138,28 @@ Stream<TestSetSpec> getTestSetSpecs() {
call("LEAST", $("f4"), $("f5")),
"LEAST(f4, f5)",
"hello",
DataTypes.STRING().notNull()));
DataTypes.STRING().notNull())
// assert that primitive types are returned and used in the equality
// operator applied on top of the GREATEST functions
.testResult(
call(
"EQUALS",
call("LEAST", $("f1"), $("f2")),
call("LEAST", $("f1"), $("f2"))),
"LEAST(f1, f2) = LEAST(f1, f2)",
true,
DataTypes.BOOLEAN().notNull())
.testResult(
call(
"EQUALS",
call("LEAST", $("f0"), $("f1")),
call("LEAST", $("f0"), $("f1"))),
"LEAST(f0, f1) = LEAST(f0, f1)",
null,
DataTypes.BOOLEAN())
.testSqlValidationError(
"LEAST(f5, f6)",
"SQL validation failed. Invalid function call:\n"
+ "LEAST(STRING NOT NULL, TIMESTAMP(3) NOT NULL)"));
}
}

0 comments on commit 89cdc6e

Please sign in to comment.