[FLINK-7853] [table] Reject table function outer joins with predicates in Table API.

This closes apache#4842.
xccui authored and fhueske committed Oct 18, 2017
1 parent eaa5a46 commit c81a6db
Showing 8 changed files with 169 additions and 26 deletions.
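In short: after this commit, a table function left outer join in the Table API only accepts an empty join predicate or a literal true; any other predicate is rejected with a ValidationException (the restriction is to be lifted again in FLINK-7865). A minimal sketch of the new behavior, reusing the test utilities touched below (TableTestBase, TableFunc1); the class and method names here are illustrative only:

import org.apache.flink.api.scala._
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.{TableFunc1, TableTestBase}

// Illustrative sketch only; mirrors the batch tests changed in this commit.
class TableFunctionOuterJoinSketch extends TableTestBase {

  def sketch(): Unit = {
    val util = batchTestUtil()
    val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
    val func1 = util.addFunction("func1", new TableFunc1)

    // Still accepted: no join predicate, or a literal `true` predicate.
    table.leftOuterJoin(func1('c) as 's).select('c, 's)
    table.leftOuterJoin(func1('c) as 's, true).select('c, 's)

    // Rejected by this commit: any other predicate, e.g. an equality condition.
    try {
      table.leftOuterJoin(func1('c) as 's, 'c === 's).select('c, 's)
    } catch {
      case _: ValidationException => // expected until FLINK-7865 removes the restriction
    }
  }
}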
10 changes: 6 additions & 4 deletions docs/dev/table/tableApi.md
@@ -548,7 +548,7 @@ Table result = left.join(right)
</tr>
<tr>
<td>
<strong>TableFunction Join</strong><br>
<strong>TableFunction Inner Join</strong><br>
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
</td>
<td>
@@ -562,7 +562,7 @@ tEnv.registerFunction("split", split);
// join
Table orders = tableEnv.scan("Orders");
Table result = orders
.join(new Table(tEnv, "split(c)").as("s", "t", "v")))
.join(new Table(tEnv, "split(c)").as("s", "t", "v"))
.select("a, b, s, t, v");
{% endhighlight %}
</td>
@@ -574,6 +574,7 @@ Table result = orders
</td>
<td>
        <p>Joins a table with the results of a table function. Each row of the left (outer) table is joined with all rows produced by the corresponding call of the table function. If a table function call returns an empty result, the corresponding outer row is preserved and the result is padded with null values.
<p><b>Note:</b> Currently, the predicate of a table function left outer join can only be empty or literal <code>true</code>.</p>
</p>
{% highlight java %}
// register function
@@ -583,7 +584,7 @@ tEnv.registerFunction("split", split);
// join
Table orders = tableEnv.scan("Orders");
Table result = orders
.leftOuterJoin(new Table(tEnv, "split(c)").as("s", "t", "v")))
.leftOuterJoin(new Table(tEnv, "split(c)").as("s", "t", "v"))
.select("a, b, s, t, v");
{% endhighlight %}
</td>
@@ -664,7 +665,7 @@ val result = left.join(right)
</tr>
<tr>
<td>
<strong>TableFunction Join</strong><br>
<strong>TableFunction Inner Join</strong><br>
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
</td>
<td>
@@ -687,6 +688,7 @@ val result: Table = table
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td>
<td>
        <p>Joins a table with the results of a table function. Each row of the left (outer) table is joined with all rows produced by the corresponding call of the table function. If a table function call returns an empty result, the corresponding outer row is preserved and the result is padded with null values.
<p><b>Note:</b> Currently, the predicate of a table function left outer join can only be empty or literal <code>true</code>.</p>
</p>
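        <p>For example (a sketch, assuming a table <code>table</code> with a field <code>c</code> and an instantiated <code>split</code> table function as in the Java example above), an explicit literal <code>true</code> predicate is also accepted and behaves like the predicate-free form:</p>
{% highlight scala %}
// left outer join with a literal true predicate (equivalent to omitting the predicate)
val result: Table = table
  .leftOuterJoin(split('c) as ('s, 't, 'v), true)
  .select('a, 'b, 's, 't, 'v)
{% endhighlight %}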
{% highlight scala %}
// instantiate function
@@ -29,7 +29,7 @@ import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment, UnresolvedException}
import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment, Types, UnresolvedException}
import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
import org.apache.flink.table.expressions.ExpressionUtils.isRowCountLiteral
import org.apache.flink.table.expressions._
@@ -362,7 +362,7 @@ case class Join(
left: LogicalNode,
right: LogicalNode) extends Attribute {

val isFromLeftInput = left.output.map(_.name).contains(name)
val isFromLeftInput: Boolean = left.output.map(_.name).contains(name)

val (indexInInput, indexInJoin) = if (isFromLeftInput) {
val indexInLeft = left.output.map(_.name).indexOf(name)
@@ -459,6 +459,11 @@ case class Join(
var equiJoinPredicateFound = false
var nonEquiJoinPredicateFound = false
var localPredicateFound = false
// Whether the predicate is literal true.
val alwaysTrue = expression match {
case x: Literal if x.value.equals(true) => true
case _ => false
}

def validateConditions(exp: Expression, isAndBranch: Boolean): Unit = exp match {
case x: And => x.children.foreach(validateConditions(_, isAndBranch))
@@ -476,20 +481,30 @@
} else {
nonEquiJoinPredicateFound = true
}
// The boolean literal should be a valid condition type.
case x: Literal if x.resultType == Types.BOOLEAN =>
case x => failValidation(
s"Unsupported condition type: ${x.getClass.getSimpleName}. Condition: $x")
}

validateConditions(expression, isAndBranch = true)
if (!equiJoinPredicateFound) {
failValidation(
s"Invalid join condition: $expression. At least one equi-join predicate is " +
s"required.")
}
if (joinType != JoinType.INNER && (nonEquiJoinPredicateFound || localPredicateFound)) {
failValidation(
s"Invalid join condition: $expression. Non-equality join predicates or local" +
s" predicates are not supported in outer joins.")

// Due to a bug in Apache Calcite (see CALCITE-2004 and FLINK-7865) we cannot accept join
// predicates except literal true for TableFunction left outer join.
    if (correlated && right.isInstanceOf[LogicalTableFunctionCall] && joinType != JoinType.INNER) {
if (!alwaysTrue) failValidation("TableFunction left outer join predicate can only be " +
"empty or literal true.")
} else {
if (!equiJoinPredicateFound) {
failValidation(
s"Invalid join condition: $expression. At least one equi-join predicate is " +
s"required.")
}
if (joinType != JoinType.INNER && (nonEquiJoinPredicateFound || localPredicateFound)) {
failValidation(
s"Invalid join condition: $expression. Non-equality join predicates or local" +
s" predicates are not supported in outer joins.")
}
}
}
}
@@ -19,6 +19,7 @@
package org.apache.flink.table.api.batch.table

import org.apache.flink.api.scala._
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.TableTestUtil._
import org.apache.flink.table.utils.{TableFunc1, TableTestBase}
@@ -74,12 +75,12 @@ class CorrelateTest extends TableTestBase {
}

@Test
def testLeftOuterJoin(): Unit = {
def testLeftOuterJoinWithoutJoinPredicates(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func1", new TableFunc1)

val result = table.leftOuterJoin(function('c) as 's).select('c, 's)
val result = table.leftOuterJoin(function('c) as 's).select('c, 's).where('s > "")

val expected = unaryNode(
"DataSetCalc",
@@ -93,9 +94,35 @@
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "LEFT")
),
term("select", "c", "s")
term("select", "c", "s"),
term("where", ">(s, '')")
)

util.verifyTable(result, expected)
}

@Test
def testLeftOuterJoinWithLiteralTrue(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func1", new TableFunc1)

val result = table.leftOuterJoin(function('c) as 's, true).select('c, 's)

val expected = unaryNode(
"DataSetCalc",
unaryNode(
"DataSetCorrelate",
batchTableNode(0),
term("invocation", s"${function.functionIdentifier}($$2)"),
term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
term("select", "a", "b", "c", "s"),
term("rowType",
"RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, VARCHAR(65536) s)"),
term("joinType", "LEFT")
),
term("select", "c", "s"))

util.verifyTable(result, expected)
}
}
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.api.batch.table.validation

import org.apache.flink.api.scala._
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.{TableFunc1, TableTestBase}
import org.junit.Test

class CorrelateValidationTest extends TableTestBase {

/**
* Due to the improper translation of TableFunction left outer join (see CALCITE-2004), the
* join predicate can only be empty or literal true (the restriction should be removed in
* FLINK-7865).
*/
@Test (expected = classOf[ValidationException])
def testLeftOuterJoinWithPredicates(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func1", new TableFunc1)
val result = table
.leftOuterJoin(function('c) as 's, 'c === 's)
.select('c, 's)
util.verifyTable(result, "")
}
}
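While the restriction holds, a condition can still be applied as a regular filter after the predicate-free outer join, as the adjusted testLeftOuterJoinWithoutJoinPredicates above does with where('s > ""). Note that a post-join filter is not equivalent to an outer-join predicate: it also drops the null-padded rows that fail the condition. A hedged sketch along the same lines (names are illustrative):

import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.utils.{TableFunc1, TableTestBase}

// Illustrative sketch only: the predicate is applied as a filter instead of a join condition.
class CorrelateFilterWorkaroundSketch extends TableTestBase {
  def sketch(): Unit = {
    val util = batchTestUtil()
    val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
    val function = util.addFunction("func1", new TableFunc1)

    // Predicate-free table function left outer join, followed by a where() filter.
    // Unlike an outer-join predicate, the filter also removes null-padded rows.
    val result = table
      .leftOuterJoin(function('c) as 's)
      .select('c, 's)
      .where('s > "")
  }
}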
@@ -18,6 +18,7 @@
package org.apache.flink.table.api.stream.table

import org.apache.flink.api.scala._
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.Func13
import org.apache.flink.table.utils.TableTestUtil._
@@ -74,12 +75,12 @@ class CorrelateTest extends TableTestBase {
}

@Test
def testLeftOuterJoin(): Unit = {
def testLeftOuterJoinWithLiteralTrue(): Unit = {
val util = streamTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func1", new TableFunc1)

val result = table.leftOuterJoin(function('c) as 's).select('c, 's)
val result = table.leftOuterJoin(function('c) as 's, true).select('c, 's)

val expected = unaryNode(
"DataStreamCalc",
@@ -21,7 +21,6 @@ import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils._
import org.apache.flink.table.runtime.utils._
import org.apache.flink.table.utils.{ObjectTableFunction, TableFunc1, TableFunc2, TableTestBase}
import org.junit.Assert.{assertTrue, fail}
import org.junit.Test
@@ -176,6 +175,22 @@ class CorrelateValidationTest extends TableTestBase {
"Given parameters of function 'func2' do not match any signature.")
}

/**
* Due to the improper translation of TableFunction left outer join (see CALCITE-2004), the
* join predicate can only be empty or literal true (the restriction should be removed in
* FLINK-7865).
*/
@Test (expected = classOf[ValidationException])
def testLeftOuterJoinWithPredicates(): Unit = {
val util = streamTestUtil()
val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
val function = util.addFunction("func1", new TableFunc1)

val result = table.leftOuterJoin(function('c) as 's, 'c === 's).select('c, 's).where('a > 10)

util.verifyTable(result, "")
}

// ----------------------------------------------------------------------------------------------

private def expectExceptionThrown(
@@ -196,5 +211,4 @@
case e: Throwable => fail(s"Expected throw ${clazz.getSimpleName}, but is $e.")
}
}

}
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2}
@@ -69,7 +69,7 @@ class CorrelateITCase(
}

@Test
def testLeftOuterJoin(): Unit = {
def testLeftOuterJoinWithoutPredicates(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env, config)
val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
@@ -82,6 +82,25 @@
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

/**
* Common join predicates are temporarily forbidden (see FLINK-7865).
*/
@Test (expected = classOf[ValidationException])
def testLeftOuterJoinWithPredicates(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env, config)
val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)

val func2 = new TableFunc2
val result = in
.leftOuterJoin(func2('c) as ('s, 'l), 'a === 'l)
.select('c, 's, 'l)
.toDataSet[Row]
val results = result.collect()
val expected = "John#19,19,2\n" + "nosharp,null,null"
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testWithFilter(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
@@ -20,7 +20,7 @@ package org.apache.flink.table.runtime.stream.table
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.{TableEnvironment, ValidationException}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.{Func18, RichFunc2}
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData}
@@ -64,7 +64,7 @@ class CorrelateITCase extends StreamingMultipleProgramsTestBase {
}

@Test
def testLeftOuterJoin(): Unit = {
def testLeftOuterJoinWithoutPredicates(): Unit = {
val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
val func0 = new TableFunc0

@@ -82,6 +82,27 @@
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

/**
* Common join predicates are temporarily forbidden (see FLINK-7865).
*/
@Test (expected = classOf[ValidationException])
def testLeftOuterJoinWithPredicates(): Unit = {
val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
val func0 = new TableFunc0

val result = t
.leftOuterJoin(func0('c) as ('s, 'l), 'a === 'l)
.select('c, 's, 'l)
.toAppendStream[Row]

result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = "John#19,null,null\n" + "John#22,null,null\n" + "Anna44,null,null\n" +
"nosharp,null,null"
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

@Test
def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
val t = testData(env).toTable(tEnv).as('a, 'b, 'c)
