Skip to content

Commit

Permalink
[hotFix] [tableAPI] Improve naming of DataSetRelNodes
Browse files Browse the repository at this point in the history
This closes apache#1799
  • Loading branch information
fhueske authored and vasia committed Mar 18, 2016
1 parent e430dee commit 2c62546
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class DataSetAggregate(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowType: RelDataType,
inputType: RelDataType,
opName: String,
grouping: Array[Int])
extends SingleRel(cluster, traitSet, input)
with DataSetRel {
Expand All @@ -59,12 +58,13 @@ class DataSetAggregate(
namedAggregates,
rowType,
inputType,
opName,
grouping)
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw).item("name", opName)
super.explainTerms(pw)
.itemIf("groupBy",groupingToString, !grouping.isEmpty)
.item("select", aggregationToString)
}

override def translateToPlan(
Expand All @@ -87,28 +87,34 @@ class DataSetAggregate(
.map(n => TypeConverter.sqlTypeToTypeInfo(n))
.toArray

val rowTypeInfo = new RowTypeInfo(fieldTypes, rowType.getFieldNames.asScala)
val aggString = aggregationToString
val mappedInput = inputDS.map(aggregateResult._1).name(s"prepare $aggString")
val prepareOpName = s"prepare select: ($aggString)"
val mappedInput = inputDS
.map(aggregateResult._1)
.name(prepareOpName)

val groupReduceFunction = aggregateResult._2
val rowTypeInfo = new RowTypeInfo(fieldTypes, rowType.getFieldNames.asScala)

val result = {
if (groupingKeys.length > 0) {
val inFields = inputType.getFieldNames.asScala.toList
val groupByString = s"groupBy: (${grouping.map(inFields(_)).mkString(", ")})"
// grouped aggregation
val aggOpName = s"groupBy: ($groupingToString), select:($aggString)"

mappedInput.asInstanceOf[DataSet[Row]]
.groupBy(groupingKeys: _*)
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(groupByString + ", " + aggString)
.name(aggOpName)
.asInstanceOf[DataSet[Any]]
}
else {
// global aggregation
val aggOpName = s"select:($aggString)"
mappedInput.asInstanceOf[DataSet[Row]]
.reduceGroup(groupReduceFunction)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataSet[Any]]
}
}
Expand All @@ -123,24 +129,35 @@ class DataSetAggregate(
}
}

/**
 * Renders the grouping keys as a comma-separated list of input field names.
 *
 * Every index in `grouping` is resolved against the field names of `inputType`.
 */
private def groupingToString: String = {
  val inputFieldNames = inputType.getFieldNames.asScala
  val keyNames = grouping.map(idx => inputFieldNames(idx))
  keyNames.mkString(", ")
}

private def aggregationToString: String = {

val inFields = inputType.getFieldNames.asScala.toList
val outFields = rowType.getFieldNames.asScala.toList
val aggs = namedAggregates.map(_.getKey)
val inFields = inputType.getFieldNames.asScala
val outFields = rowType.getFieldNames.asScala

val groupFieldsString = grouping.map( inFields(_) )
val aggsString = aggs.map( a => s"${a.getAggregation}(${inFields(a.getArgList.get(0))})")
val groupStrings = grouping.map( inFields(_) )

val outFieldsString = (groupFieldsString ++ aggsString).zip(outFields).map {
val aggs = namedAggregates.map(_.getKey)
val aggStrings = aggs.map( a => s"${a.getAggregation}(${
if (a.getArgList.size() > 0) {
inFields(a.getArgList.get(0))
} else {
"*"
}
})")

(groupStrings ++ aggStrings).zip(outFields).map {
case (f, o) => if (f == o) {
f
} else {
s"$f AS $o"
}
}

s"select: (${outFieldsString.mkString(", ")})"
}.mkString(", ")
}

private def typeConversion(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class DataSetCalc(
input: RelNode,
rowType: RelDataType,
calcProgram: RexProgram,
opName: String,
ruleDescription: String)
extends SingleRel(cluster, traitSet, input)
with DataSetRel {
Expand All @@ -58,16 +57,15 @@ class DataSetCalc(
inputs.get(0),
rowType,
calcProgram,
opName,
ruleDescription)
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw).item("name", opName)
super.explainTerms(pw)
.item("select", selectionToString)
.itemIf("where", conditionToString, calcProgram.getCondition != null)
}

override def toString = opName

override def translateToPlan(config: TableConfig,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {

Expand Down Expand Up @@ -150,40 +148,45 @@ class DataSetCalc(
genFunction.code,
genFunction.returnType)

val calcDesc = calcProgramToString()
val calcOpName =
s"${if (condition != null) {
s"where: ($conditionToString), "
} else {
""
}}select: ($selectionToString)"

inputDS.flatMap(mapFunc).name(calcDesc)
inputDS.flatMap(mapFunc).name(calcOpName)
}

private def calcProgramToString(): String = {

val cond = calcProgram.getCondition
private def selectionToString: String = {
val proj = calcProgram.getProjectList.asScala.toList
val localExprs = calcProgram.getExprList.asScala.toList
val inFields = calcProgram.getInputRowType.getFieldNames.asScala.toList
val localExprs = calcProgram.getExprList.asScala.toList
val outFields = calcProgram.getInputRowType.getFieldNames.asScala.toList

val projString = s"select: (${
proj
.map(getExpressionString(_, inFields, Some(localExprs)))
.zip(outFields).map { case (e, o) => {
if (e != o) {
e + " AS " + o
} else {
e
}
proj
.map(getExpressionString(_, inFields, Some(localExprs)))
.zip(outFields).map { case (e, o) => {
if (e != o) {
e + " AS " + o
} else {
e
}
}
.mkString(", ")
})"
if (cond != null) {
val condString = s"where: (${getExpressionString(cond, inFields, Some(localExprs))})"
}.mkString(", ")
}

condString + ", " + projString
private def conditionToString: String = {

val cond = calcProgram.getCondition
val inFields = calcProgram.getInputRowType.getFieldNames.asScala.toList
val localExprs = calcProgram.getExprList.asScala.toList

if (cond != null) {
getExpressionString(cond, inFields, Some(localExprs))
} else {
projString
""
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class DataSetJoin(
left: RelNode,
right: RelNode,
rowType: RelDataType,
opName: String,
joinCondition: RexNode,
joinRowType: RelDataType,
joinInfo: JoinInfo,
Expand All @@ -68,7 +67,6 @@ class DataSetJoin(
inputs.get(0),
inputs.get(1),
rowType,
opName,
joinCondition,
joinRowType,
joinInfo,
Expand All @@ -79,7 +77,9 @@ class DataSetJoin(
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw).item("name", opName)
super.explainTerms(pw)
.item("where", joinConditionToString)
.item("join", joinSelectionToString)
}

override def translateToPlan(
Expand Down Expand Up @@ -145,19 +145,20 @@ class DataSetJoin(
genFunction.code,
genFunction.returnType)

val joinOpName = joinConditionToString()
val joinOpName = s"where: ($joinConditionToString), join: ($joinSelectionToString)"

leftDataSet.join(rightDataSet).where(leftKeys.toArray: _*).equalTo(rightKeys.toArray: _*)
.`with`(joinFun).name(joinOpName).asInstanceOf[DataSet[Any]]
}

private def joinConditionToString(): String = {
/** Joins the field names of `rowType` into a single string, separated by ", ". */
private def joinSelectionToString: String = {
  val outFieldNames = rowType.getFieldNames.asScala
  outFieldNames.mkString(", ")
}

val inFields = joinRowType.getFieldNames.asScala.toList
val condString = s"where: (${getExpressionString(joinCondition, inFields, None)})"
val outFieldString = s"join: (${rowType.getFieldNames.asScala.toList.mkString(", ")})"
private def joinConditionToString: String = {

condString + ", " + outFieldString
val inFields = joinRowType.getFieldNames.asScala.toList
getExpressionString(joinCondition, inFields, None)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.TableConfig

import scala.collection.JavaConverters._

/**
* Flink RelNode which is translated into a DataSet union operator.
*
Expand All @@ -34,8 +36,7 @@ class DataSetUnion(
traitSet: RelTraitSet,
left: RelNode,
right: RelNode,
rowType: RelDataType,
opName: String)
rowType: RelDataType)
extends BiRel(cluster, traitSet, left, right)
with DataSetRel {

Expand All @@ -47,13 +48,12 @@ class DataSetUnion(
traitSet,
inputs.get(0),
inputs.get(1),
rowType,
opName
rowType
)
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw).item("name", opName)
super.explainTerms(pw).item("union", unionSelectionToString)
}

override def translateToPlan(
Expand All @@ -65,4 +65,8 @@ class DataSetUnion(
leftDataSet.union(rightDataSet).asInstanceOf[DataSet[Any]]
}

/** Joins the field names of `rowType` into a single string, separated by ", ". */
private def unionSelectionToString: String = {
  val fieldNames = rowType.getFieldNames.asScala
  fieldNames.mkString(", ")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class DataSetAggregateRule
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
agg.toString,
agg.getGroupSet.toArray)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class DataSetCalcRule
convInput,
rel.getRowType,
calc.getProgram,
calc.toString,
description)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ class DataSetJoinRule
convLeft,
convRight,
rel.getRowType,
join.toString,
join.getCondition,
join.getRowType,
joinInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ class DataSetUnionRule
traitSet,
convLeft,
convRight,
rel.getRowType,
union.toString)
rel.getRowType)
}
}

Expand Down

0 comments on commit 2c62546

Please sign in to comment.