Skip to content

Commit

Permalink
[FLINK-14398][table-planner] Further split input unboxing code into s…
Browse files Browse the repository at this point in the history
…eparate methods (apache#10000)
  • Loading branch information
haodang authored and wuchong committed Oct 30, 2019
1 parent 31e89c7 commit 4c87259
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1052,23 +1052,51 @@ abstract class CodeGenerator(
// ----------------------------------------------------------------------------------------------
// generator helping methods
// ----------------------------------------------------------------------------------------------
protected def makeReusableInSplits(expr: GeneratedExpression): GeneratedExpression = {
// prepare declaration in class
val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
}
reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")

protected def makeReusableInSplits(exprs: Iterable[GeneratedExpression]): Unit = {
// add results of expressions to member area such that all split functions can access it
exprs.foreach { expr =>

// declaration
val resultTypeTerm = primitiveTypeTermForTypeInfo(expr.resultType)
if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
reusableMemberStatements.add(s"private boolean ${expr.nullTerm};")
}
reusableMemberStatements.add(s"private $resultTypeTerm ${expr.resultTerm};")

// assignment
// when expr has no code, no need to split it into a method, but still need to assign
if (expr.code.isEmpty) {
if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
reusablePerRecordStatements.add(s"this.${expr.nullTerm} = ${expr.nullTerm};")
}
reusablePerRecordStatements.add(s"this.${expr.resultTerm} = ${expr.resultTerm};")
expr
} else {
// create a method for the unboxing block
val methodName = newName(s"inputUnboxingSplit")
val method =
if (nullCheck && !expr.nullTerm.equals(NEVER_NULL) && !expr.nullTerm.equals(ALWAYS_NULL)) {
s"""
|private final void $methodName() throws Exception {
| ${expr.code}
| this.${expr.nullTerm} = ${expr.nullTerm};
| this.${expr.resultTerm} = ${expr.resultTerm};
|}
""".stripMargin
} else {
s"""
|private final void $methodName() throws Exception {
| ${expr.code}
| this.${expr.resultTerm} = ${expr.resultTerm};
|}
""".stripMargin
}

// add this method to reusable section for later generation
reusableMemberStatements.add(method)

// create method call
GeneratedExpression(
expr.resultTerm,
expr.nullTerm,
s"$methodName();",
expr.resultType)
}
}

Expand All @@ -1081,7 +1109,9 @@ abstract class CodeGenerator(
hasCodeSplits = true

// add input unboxing to member area such that all split functions can access it
makeReusableInSplits(reusableInputUnboxingExprs.values)
reusableInputUnboxingExprs.keys.foreach(
key =>
reusableInputUnboxingExprs(key) = makeReusableInSplits(reusableInputUnboxingExprs(key)))

// add split methods to the member area and return the code necessary to call those methods
val methodCalls = splits.map { split =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,16 @@ class CollectorCodeGenerator(
val input1TypeClass = boxedTypeTermForTypeInfo(input1)
val input2TypeClass = boxedTypeTermForTypeInfo(collectedType)

// declaration in case of code splits
val recordMember = if (hasCodeSplits) {
s"private $input2TypeClass $input2Term;"
} else {
""
}

// assignment in case of code splits
val recordAssignment = if (hasCodeSplits) {
s"$input2Term" // use member
} else {
s"$input2TypeClass $input2Term" // local variable
}

reusableMemberStatements ++= filterGenerator.reusableMemberStatements
reusableInitStatements ++= filterGenerator.reusableInitStatements
reusablePerRecordStatements ++= filterGenerator.reusablePerRecordStatements

val funcCode = j"""
|public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} {
|
| $recordMember
| private $input1TypeClass $input1Term;
| private $input2TypeClass $input2Term;
|
| ${reuseMemberCode()}
|
| public $className() throws Exception {
Expand All @@ -109,8 +97,8 @@ class CollectorCodeGenerator(
| @Override
| public void collect(Object record) throws Exception {
| super.collect(record);
| $input1TypeClass $input1Term = ($input1TypeClass) getInput();
| $recordAssignment = ($input2TypeClass) record;
| $input1Term = ($input1TypeClass) getInput();
| $input2Term = ($input2TypeClass) record;
| ${reuseInputUnboxingCode()}
| ${reusePerRecordCode()}
| $bodyCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,22 @@ class FunctionCodeGenerator(
if (clazz == classOf[FlatMapFunction[_, _]]) {
val baseClass = classOf[RichFlatMapFunction[_, _]]
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
// declaration: make variable accessible for separated method
reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
(baseClass,
s"void flatMap(Object _in1, $collectorTypeTerm $collectorTerm)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
List(s"$input1Term = ($inputTypeTerm) _in1;"))
}

// MapFunction
else if (clazz == classOf[MapFunction[_, _]]) {
val baseClass = classOf[RichMapFunction[_, _]]
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
// declaration: make variable accessible for separated method
reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
(baseClass,
"Object map(Object _in1)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
List(s"$input1Term = ($inputTypeTerm) _in1;"))
}

// FlatJoinFunction
Expand All @@ -121,10 +125,13 @@ class FunctionCodeGenerator(
val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1)
val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse(
throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
// declaration: make variables accessible for separated methods
reusableMemberStatements.add(s"private $inputTypeTerm1 $input1Term;")
reusableMemberStatements.add(s"private $inputTypeTerm2 $input2Term;")
(baseClass,
s"void join(Object _in1, Object _in2, $collectorTypeTerm $collectorTerm)",
List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;",
s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
List(s"$input1Term = ($inputTypeTerm1) _in1;",
s"$input2Term = ($inputTypeTerm2) _in2;"))
}

// JoinFunction
Expand All @@ -133,10 +140,13 @@ class FunctionCodeGenerator(
val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1)
val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse(
throw new CodeGenException("Input 2 for JoinFunction should not be null")))
// declaration: make variables accessible for separated methods
reusableMemberStatements.add(s"private $inputTypeTerm1 $input1Term;")
reusableMemberStatements.add(s"private $inputTypeTerm2 $input2Term;")
(baseClass,
s"Object join(Object _in1, Object _in2)",
List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;",
s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
List(s"$input1Term = ($inputTypeTerm1) _in1;",
s"$input2Term = ($inputTypeTerm2) _in2;"))
}

// ProcessFunction
Expand All @@ -155,10 +165,12 @@ class FunctionCodeGenerator(
Nil
}

// declaration: make variable accessible for separated method
reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
(baseClass,
s"void processElement(Object _in1, $contextTypeTerm $contextTerm, " +
s"$collectorTypeTerm $collectorTerm)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;") ++ globalContext)
List(s"$input1Term = ($inputTypeTerm) _in1;") ++ globalContext)
}
else {
// TODO more functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,20 +316,22 @@ class MatchCodeGenerator(
val baseClass = classOf[RichIterativeCondition[_]]
val inputTypeTerm = boxedTypeTermForTypeInfo(input)
val contextType = classOf[IterativeCondition.Context[_]].getCanonicalName

// declaration: make variable accessible for separated methods
reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
(baseClass,
s"boolean filter(Object _in1, $contextType $contextTerm)",
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
List(s"$input1Term = ($inputTypeTerm) _in1;"))
} else if (clazz == classOf[PatternProcessFunction[_, _]]) {
val baseClass = classOf[PatternProcessFunction[_, _]]
val inputTypeTerm =
s"java.util.Map<String, java.util.List<${boxedTypeTermForTypeInfo(input)}>>"
val contextTypeTerm = classOf[PatternProcessFunction.Context].getCanonicalName

// declaration: make variable accessible for separated method
reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
(baseClass,
s"void processMatch($inputTypeTerm $input1Term, $contextTypeTerm $contextTerm, " +
s"void processMatch($inputTypeTerm _in1, $contextTypeTerm $contextTerm, " +
s"$collectorTypeTerm $collectorTerm)",
List())
List(s"this.$input1Term = ($inputTypeTerm) _in1;"))
} else {
throw new CodeGenException("Unsupported Function.")
}
Expand Down Expand Up @@ -434,7 +436,7 @@ class MatchCodeGenerator(
returnType.fieldNames)
aggregatesPerVariable.values.foreach(_.generateAggFunction())
if (hasCodeSplits) {
makeReusableInSplits(reusableAggregationExpr.values)
makeReusableInSplits()
}

exp
Expand All @@ -444,12 +446,18 @@ class MatchCodeGenerator(
val exp = call.accept(this)
aggregatesPerVariable.values.foreach(_.generateAggFunction())
if (hasCodeSplits) {
makeReusableInSplits(reusableAggregationExpr.values)
makeReusableInSplits()
}

exp
}

private def makeReusableInSplits(): Unit = {
reusableAggregationExpr.keys.foreach(
key =>
reusableAggregationExpr(key) = makeReusableInSplits(reusableAggregationExpr(key)))
}

override def visitCall(call: RexCall): GeneratedExpression = {
call.getOperator match {
case PREV | NEXT =>
Expand Down Expand Up @@ -539,11 +547,13 @@ class MatchCodeGenerator(
} else {
""
}

reusableMemberStatements.add(s"java.util.List $listName = new java.util.ArrayList();")
val listCode = if (patternName == ALL_PATTERN_VARIABLE) {
addReusablePatternNames()
val patternTerm = newName("pattern")
j"""
|java.util.List $listName = new java.util.ArrayList();
|$listName = new java.util.ArrayList();
|for (String $patternTerm : $patternNamesTerm) {
| for ($eventTypeTerm $eventNameTerm :
| $contextTerm.getEventsForPattern($patternTerm)) {
Expand All @@ -554,7 +564,7 @@ class MatchCodeGenerator(
} else {
val escapedPatternName = EncodingUtils.escapeJava(patternName)
j"""
|java.util.List $listName = new java.util.ArrayList();
|$listName = new java.util.ArrayList();
|for ($eventTypeTerm $eventNameTerm :
| $contextTerm.getEventsForPattern("$escapedPatternName")) {
| $listName.add($eventNameTerm);
Expand All @@ -574,13 +584,14 @@ class MatchCodeGenerator(
private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
val listName = newName("patternEvents")

reusableMemberStatements.add(s"java.util.List $listName = new java.util.ArrayList();")
val code = if (patternName == ALL_PATTERN_VARIABLE) {
addReusablePatternNames()

val patternTerm = newName("pattern")

j"""
|java.util.List $listName = new java.util.ArrayList();
|$listName = new java.util.ArrayList();
|for (String $patternTerm : $patternNamesTerm) {
| java.util.List rows = (java.util.List) $input1Term.get($patternTerm);
| if (rows != null) {
Expand All @@ -591,7 +602,7 @@ class MatchCodeGenerator(
} else {
val escapedPatternName = EncodingUtils.escapeJava(patternName)
j"""
|java.util.List $listName = (java.util.List) $input1Term.get("$escapedPatternName");
|$listName = (java.util.List) $input1Term.get("$escapedPatternName");
|if ($listName == null) {
| $listName = java.util.Collections.emptyList();
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.junit.Assert._
import org.junit._

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

class SqlITCase extends StreamingWithStateTestBase {

Expand Down Expand Up @@ -870,6 +871,45 @@ class SqlITCase extends StreamingWithStateTestBase {

assertEquals(List(expected.toString()), StreamITCase.testResults.sorted)
}

@Test
def testProjectionWithManyColumns(): Unit = {

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env)
StreamITCase.clear

// force code split
tEnv.getConfig.setMaxGeneratedCodeLength(1)

val length = 1000
val rowData = List.range(0, length)
val row: Row = new Row(length)
val fieldTypes = new ArrayBuffer[TypeInformation[_]]()
val fieldNames = new ArrayBuffer[String]()
rowData.foreach { i =>
row.setField(i, i)
fieldTypes += Types.INT()
fieldNames += s"f$i"
}

val data = new mutable.MutableList[Row]
data.+=(row)
val t = env.fromCollection(data)(new RowTypeInfo(fieldTypes.toArray: _*)).toTable(tEnv)
tEnv.registerTable("MyTable", t)

val expected = List(rowData.reverse.mkString(","))
val sql =
s"""
|SELECT ${fieldNames.reverse.mkString(", ")} FROM MyTable
""".stripMargin

val result = tEnv.sqlQuery(sql).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()

assertEquals(expected, StreamITCase.testResults)
}
}

object SqlITCase {
Expand Down

0 comments on commit 4c87259

Please sign in to comment.