Skip to content

Commit

Permalink
[FLINK-5226] [table] Use correct DataSetCostFactory and improve DataSetCalc costs.
Browse files Browse the repository at this point in the history

- Improved DataSetCalc costs make projections cheap and help to push them down.

This closes apache#2926.
  • Loading branch information
fhueske committed Dec 8, 2016
1 parent 55d6061 commit 677d0d9
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ object FlinkRelBuilder {
val typeFactory = new FlinkTypeFactory(typeSystem)

// create context instances with Flink type factory
val planner = new VolcanoPlanner(Contexts.empty())
val planner = new VolcanoPlanner(config.getCostFactory, Contexts.empty())
planner.setExecutor(config.getExecutor)
planner.addRelTraitDef(ConventionTraitDef.INSTANCE)
val cluster = RelOptCluster.create(planner, new RexBuilder(typeFactory))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import TypeConverter._
import org.apache.flink.api.table.BatchTableEnvironment
import org.apache.calcite.rex._

import scala.collection.JavaConverters._

/**
* Flink RelNode which matches along with LogicalCalc.
*
Expand Down Expand Up @@ -73,8 +75,16 @@ class DataSetCalc(

val child = this.getInput
val rowCnt = metadata.getRowCount(child)
val exprCnt = calcProgram.getExprCount
planner.getCostFactory.makeCost(rowCnt, rowCnt * exprCnt, 0)

// Count the expressions that are not themselves a plain field reference (RexInputRef) or a
// literal (RexLiteral), i.e. computations, conditions, etc. We only want to account for
// computations, not for simple projections.
val compCnt = calcProgram.getExprList.asScala.toList.count {
case i: RexInputRef => false
case l: RexLiteral => false
case _ => true
}

planner.getCostFactory.makeCost(rowCnt, rowCnt * compCnt, 0)
}

override def estimateRowCount(metadata: RelMetadataQuery): Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,17 @@ class DataSetJoin(

/**
 * Estimates the cost of this join itself from the row counts and the estimated row sizes
 * of its inputs.
 *
 * NOTE(review): this listing is a diff view without +/- markers. Judging by the hunk header
 * (-88,12 +88,17), the `foldLeft` statement is presumably the body removed by the commit and
 * the left/right computation after it the body that replaces it -- confirm against the
 * actual file before treating both as live code.
 *
 * @param planner planner whose cost factory is used to create the returned cost
 * @param metadata metadata query used to obtain the row count of each input
 * @return cost created from the row count, CPU cost and IO cost computed over the inputs
 */
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {

// Presumably the pre-commit body: starts from a zero cost and adds, for every input,
// a cost built from (rowCnt, rowCnt, rowCnt * rowSize).
val children = this.getInputs
children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) =>
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize))
}
// Presumably the post-commit body: the same per-input quantities, spelled out for the
// left and the right input.
val leftRowCnt = metadata.getRowCount(getLeft)
val leftRowSize = estimateRowSize(getLeft.getRowType)

val rightRowCnt = metadata.getRowCount(getRight)
val rightRowSize = estimateRowSize(getRight.getRowType)

// IO cost: row count times estimated row size, summed over both inputs.
val ioCost = (leftRowCnt * leftRowSize) + (rightRowCnt * rightRowSize)
// CPU cost and row count are both the total number of input rows.
val cpuCost = leftRowCnt + rightRowCnt
val rowCnt = leftRowCnt + rightRowCnt

planner.getCostFactory.makeCost(rowCnt, cpuCost, ioCost)
}

override def translateToPlan(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,23 @@ class SetOperatorsTest extends TableTestBase {
"DataSetCalc",
binaryNode(
"DataSetJoin",
batchTableNode(1),
unaryNode(
"DataSetCalc",
batchTableNode(1),
term("select", "b_long")
),
unaryNode(
"DataSetAggregate",
batchTableNode(0),
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "a_long")
),
term("groupBy", "a_long"),
term("select", "a_long")
),
term("where", "=(a_long, b_long)"),
term("join", "b_long", "b_int", "b_string", "a_long"),
term("join", "b_long", "a_long"),
term("joinType", "InnerJoin")
),
term("select", "true AS $f0", "a_long")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,23 @@ class SingleRowJoinTest extends TableTestBase {
"DataSetUnion",
unaryNode(
"DataSetValues",
batchTableNode(0),
tuples(List(null, null)),
term("values", "a1", "a2")
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "a1")
),
tuples(List(null)),
term("values", "a1")
),
term("union","a1","a2")
term("union","a1")
),
term("select", "COUNT(a1) AS cnt")
),
term("where", "true"),
term("where", "=(CAST(a1), cnt)"),
term("join", "a1", "a2", "cnt"),
term("joinType", "NestedLoopJoin")
),
term("select", "a1", "a2"),
term("where", "=(CAST(a1), cnt)")
term("select", "a1", "a2")
)

util.verifySql(query, expected)
Expand Down Expand Up @@ -89,20 +92,23 @@ class SingleRowJoinTest extends TableTestBase {
"DataSetUnion",
unaryNode(
"DataSetValues",
batchTableNode(0),
tuples(List(null, null)),
term("values", "a1", "a2")
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "a1")
),
tuples(List(null)),
term("values", "a1")
),
term("union","a1","a2")
term("union", "a1")
),
term("select", "COUNT(a1) AS cnt")
),
term("where", "true"),
term("where", "<(a1, cnt)"),
term("join", "a1", "a2", "cnt"),
term("joinType", "NestedLoopJoin")
),
term("select", "a1", "a2"),
term("where", "<(a1, cnt)")
term("select", "a1", "a2")
)

util.verifySql(query, expected)
Expand Down

0 comments on commit 677d0d9

Please sign in to comment.