Skip to content

Commit

Permalink
[FLINK-7942] [table] Reduce aliasing in RexNodes
Browse files Browse the repository at this point in the history
This closes apache#5019.
  • Loading branch information
twalthr committed Nov 15, 2017
1 parent 59df4b7 commit b6a2dc3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,13 @@ class OverWindowedTable(

new Table(
table.tableEnv,
Project(expandedOverFields.map(UnresolvedAlias), table.logicalPlan).validate(table.tableEnv))
Project(
expandedOverFields.map(UnresolvedAlias),
table.logicalPlan,
// required for proper projection push down
explicitAlias = true)
.validate(table.tableEnv)
)
}

def select(fields: String): Table = {
Expand Down Expand Up @@ -1150,7 +1156,9 @@ class WindowGroupedTable(
propNames.map(a => Alias(a._1, a._2)).toSeq,
aggNames.map(a => Alias(a._1, a._2)).toSeq,
Project(projectFields, table.logicalPlan).validate(table.tableEnv)
).validate(table.tableEnv)
).validate(table.tableEnv),
// required for proper resolution of the time attribute in multi-windows
explicitAlias = true
).validate(table.tableEnv))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess}
import scala.collection.JavaConverters._
import scala.collection.mutable

case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extends UnaryNode {
case class Project(
projectList: Seq[NamedExpression],
child: LogicalNode,
explicitAlias: Boolean = false)
extends UnaryNode {

override def output: Seq[Attribute] = projectList.map(_.toAttribute)

override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
Expand All @@ -61,7 +66,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend
throw new RuntimeException("This should never be called and probably points to a bug.")
}
}
Project(newProjectList, child)
Project(newProjectList, child, explicitAlias)
}

override def validate(tableEnv: TableEnvironment): LogicalNode = {
Expand Down Expand Up @@ -90,8 +95,19 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend

override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
child.construct(relBuilder)

val exprs = if (explicitAlias) {
projectList
} else {
// remove AS expressions, according to Calcite they should not be in a final RexNode
projectList.map {
case Alias(e: Expression, _, _) => e
case e: Expression => e
}
}

relBuilder.project(
projectList.map(_.toRexNode(relBuilder)).asJava,
exprs.map(_.toRexNode(relBuilder)).asJava,
projectList.map(_.name).asJava,
true)
}
Expand All @@ -116,7 +132,9 @@ case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends Una
val input = child.output
Project(
names.zip(input).map { case (name, attr) =>
Alias(attr, name)} ++ input.drop(names.length), child)
Alias(attr, name)} ++ input.drop(names.length),
child,
explicitAlias = true)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
package org.apache.flink.table.api.batch.table

import org.apache.flink.api.scala._
import org.apache.flink.table.api.batch.table.JoinTest.Merger
import org.apache.flink.table.api.scala._
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.utils.TableTestBase
import org.apache.flink.table.utils.TableTestUtil._
import org.junit.Test
Expand Down Expand Up @@ -301,4 +303,49 @@ class JoinTest extends TableTestBase {

util.verifyTable(joined, expected)
}

@Test
def testFilterJoinRule(): Unit = {
val util = batchTestUtil()
val t1 = util.addTable[(String, Int, Int)]('a, 'b, 'c)
val t2 = util.addTable[(String, Int, Int)]('d, 'e, 'f)
val results = t1
.leftOuterJoin(t2, 'b === 'e)
.select('c, Merger('c, 'f) as 'c0)
.select(Merger('c, 'c0) as 'c1)
.where('c1 >= 0)

val expected = unaryNode(
"DataSetCalc",
binaryNode(
"DataSetJoin",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "b", "c")
),
unaryNode(
"DataSetCalc",
batchTableNode(1),
term("select", "e", "f")
),
term("where", "=(b, e)"),
term("join", "b", "c", "e", "f"),
term("joinType", "LeftOuterJoin")
),
term("select", "Merger$(c, Merger$(c, f)) AS c1"),
term("where", ">=(Merger$(c, Merger$(c, f)), 0)")
)

util.verifyTable(results, expected)
}
}

object JoinTest {

object Merger extends ScalarFunction {
def eval(f0: Int, f1: Int): Int = {
f0 + f1
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,18 @@ class RetractionRulesTest extends TableTestBase {

val expected =
unaryNode(
"DataStreamCalc",
"DataStreamGroupAggregate",
unaryNode(
"DataStreamGroupAggregate",
"DataStreamCalc",
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamGroupAggregate",
"DataStreamScan(true, Acc)",
"true, AccRetract"
),
"DataStreamGroupAggregate",
"DataStreamScan(true, Acc)",
"true, AccRetract"
),
s"$defaultStatus"
"true, AccRetract"
),
s"$defaultStatus"
)
)

util.verifyTableTrait(resultTable, expected)
}
Expand Down Expand Up @@ -253,28 +249,20 @@ class RetractionRulesTest extends TableTestBase {

val expected =
unaryNode(
"DataStreamCalc",
"DataStreamGroupAggregate",
unaryNode(
"DataStreamGroupAggregate",
unaryNode(
"DataStreamCalc",
binaryNode(
"DataStreamUnion",
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamGroupAggregate",
"DataStreamScan(true, Acc)",
"true, AccRetract"
),
"true, AccRetract"
),
"DataStreamCalc",
binaryNode(
"DataStreamUnion",
unaryNode(
"DataStreamGroupAggregate",
"DataStreamScan(true, Acc)",
"true, AccRetract"
),
"DataStreamScan(true, Acc)",
"true, AccRetract"
),
s"$defaultStatus"
"true, AccRetract"
),
s"$defaultStatus"
)
Expand Down

0 comments on commit b6a2dc3

Please sign in to comment.