[FLINK-3842] [tableApi] Fix handling null record/row in generated code
This closes apache#1974
twalthr authored and fhueske committed May 10, 2016
1 parent 7ed0793 commit 08e8054
Showing 10 changed files with 150 additions and 78 deletions.
CodeGenerator.scala
@@ -42,6 +42,7 @@ import scala.collection.mutable
   * A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s.
   *
   * @param config configuration that determines runtime behavior
+  * @param nullableInput input(s) can be null.
   * @param input1 type information about the first input of the Function
   * @param input2 type information about the second input if the Function is binary
   * @param inputPojoFieldMapping additional mapping information if input1 is a POJO (POJO types
@@ -50,11 +51,17 @@ import scala.collection.mutable
   */
 class CodeGenerator(
     config: TableConfig,
+    nullableInput: Boolean,
     input1: TypeInformation[Any],
     input2: Option[TypeInformation[Any]] = None,
     inputPojoFieldMapping: Option[Array[Int]] = None)
   extends RexVisitor[GeneratedExpression] {
 
+  // check if nullCheck is enabled when inputs can be null
+  if (nullableInput && !config.getNullCheck) {
+    throw new CodeGenException("Null check must be enabled if entire rows can be null.")
+  }
+
   // check for POJO input mapping
   input1 match {
     case pt: PojoTypeInfo[_] =>
@@ -65,7 +72,7 @@ class CodeGenerator(
 
   // check that input2 is never a POJO
   input2 match {
-    case pt: PojoTypeInfo[_] =>
+    case Some(pt: PojoTypeInfo[_]) =>
       throw new CodeGenException("Second input must not be a POJO type.")
     case _ => // ok
   }
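
A side note on the hunk above: `input2` is an `Option[TypeInformation[Any]]`, so the old pattern `case pt: PojoTypeInfo[_]` could never match and the POJO guard was silently dead code; matching on `Some(...)` makes it fire. A minimal standalone sketch of the pitfall, using hypothetical values rather than the patch's types:

```scala
// Matching an Option against a bare type pattern never fires;
// the compiler even warns about the fruitless type test.
object OptionMatchPitfall extends App {
  val input2: Option[Any] = Some("some value")

  input2 match {
    case s: String       => println(s"dead branch: $s")  // unreachable
    case Some(s: String) => println(s"matched: $s")      // the corrected shape
    case _               => println("fallback")
  }
}
```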
@@ -75,12 +82,17 @@ class CodeGenerator(
    * [[org.apache.flink.api.common.functions.Function]]s with one input.
    *
    * @param config configuration that determines runtime behavior
+   * @param nullableInput input(s) can be null.
    * @param input type information about the input of the Function
    * @param inputPojoFieldMapping additional mapping information necessary if input is a
    *                              POJO (POJO types have no deterministic field order).
    */
-  def this(config: TableConfig, input: TypeInformation[Any], inputPojoFieldMapping: Array[Int]) =
-    this(config, input, None, Some(inputPojoFieldMapping))
+  def this(
+      config: TableConfig,
+      nullableInput: Boolean,
+      input: TypeInformation[Any],
+      inputPojoFieldMapping: Array[Int]) =
+    this(config, nullableInput, input, None, Some(inputPojoFieldMapping))
 
 
   // set of member statements that will be added only once
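
To see the new guard in action: a generator for nullable input rows demands that null checks are switched on in the `TableConfig`. A minimal sketch, assuming `TableConfig.setNullCheck` and an atomic input type (illustrative choices, not taken from this patch):

```scala
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.TableConfig
import org.apache.flink.api.table.codegen.CodeGenerator

object NullCheckGuardSketch extends App {
  // An atomic input type stands in for a real row type here.
  val inputType = BasicTypeInfo.INT_TYPE_INFO.asInstanceOf[TypeInformation[Any]]

  val config = new TableConfig()
  config.setNullCheck(false)
  // new CodeGenerator(config, true, inputType) would throw:
  // CodeGenException("Null check must be enabled if entire rows can be null.")

  config.setNullCheck(true)
  val gen = new CodeGenerator(config, true, inputType) // passes the guard
}
```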
@@ -212,7 +224,7 @@ class CodeGenerator(
 
       ${reuseMemberCode()}
 
-      public $funcName() throws Exception{
+      public $funcName() throws Exception {
         ${reuseInitCode()}
       }
 
@@ -785,73 +797,128 @@ class CodeGenerator(
 
       // generate input access and boxing if necessary
       case None =>
-        val newExpr = inputType match {
-
-          case ct: CompositeType[_] =>
-            val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
-              inputPojoFieldMapping.get(index)
-            }
-            else {
-              index
-            }
-            val accessor = fieldAccessorFor(ct, fieldIndex)
-            val fieldType: TypeInformation[Any] = ct.getTypeAt(fieldIndex)
-            val fieldTypeTerm = boxedTypeTermForTypeInfo(fieldType)
-
-            accessor match {
-              case ObjectFieldAccessor(field) =>
-                // primitive
-                if (isFieldPrimitive(field)) {
-                  generateNonNullLiteral(fieldType, s"$inputTerm.${field.getName}")
-                }
-                // Object
-                else {
-                  generateNullableLiteral(
-                    fieldType,
-                    s"($fieldTypeTerm) $inputTerm.${field.getName}")
-                }
-
-              case ObjectGenericFieldAccessor(fieldName) =>
-                // Object
-                val inputCode = s"($fieldTypeTerm) $inputTerm.$fieldName"
-                generateNullableLiteral(fieldType, inputCode)
-
-              case ObjectMethodAccessor(methodName) =>
-                // Object
-                val inputCode = s"($fieldTypeTerm) $inputTerm.$methodName()"
-                generateNullableLiteral(fieldType, inputCode)
-
-              case ProductAccessor(i) =>
-                // Object
-                val inputCode = s"($fieldTypeTerm) $inputTerm.productElement($i)"
-                generateNullableLiteral(fieldType, inputCode)
-
-              case ObjectPrivateFieldAccessor(field) =>
-                val fieldTerm = addReusablePrivateFieldAccess(ct.getTypeClass, field.getName)
-                val reflectiveAccessCode = reflectiveFieldReadAccess(fieldTerm, field, inputTerm)
-                // primitive
-                if (isFieldPrimitive(field)) {
-                  generateNonNullLiteral(fieldType, reflectiveAccessCode)
-                }
-                // Object
-                else {
-                  generateNullableLiteral(fieldType, reflectiveAccessCode)
-                }
-            }
-
-          case at: AtomicType[_] =>
-            val fieldTypeTerm = boxedTypeTermForTypeInfo(at)
-            val inputCode = s"($fieldTypeTerm) $inputTerm"
-            generateNullableLiteral(at, inputCode)
-
-          case _ =>
-            throw new CodeGenException("Unsupported type for input access.")
-        }
-        reusableInputUnboxingExprs((inputTerm, index)) = newExpr
-        newExpr
+        val expr = if (nullableInput) {
+          generateNullableInputFieldAccess(inputType, inputTerm, index)
+        }
+        else {
+          generateFieldAccess(inputType, inputTerm, index)
+        }
+
+        reusableInputUnboxingExprs((inputTerm, index)) = expr
+        expr
     }
     // hide the generated code as it will be executed only once
     GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, "", inputExpr.resultType)
   }
 
+  private def generateNullableInputFieldAccess(
+      inputType: TypeInformation[Any],
+      inputTerm: String,
+      index: Int)
+    : GeneratedExpression = {
+    val resultTerm = newName("result")
+    val nullTerm = newName("isNull")
+
+    val fieldType = inputType match {
+      case ct: CompositeType[_] =>
+        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
+          inputPojoFieldMapping.get(index)
+        }
+        else {
+          index
+        }
+        ct.getTypeAt(fieldIndex)
+      case at: AtomicType[_] => at
+      case _ => throw new CodeGenException("Unsupported type for input field access.")
+    }
+    val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
+    val defaultValue = primitiveDefaultValue(fieldType)
+    val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index)
+
+    val inputCheckCode =
+      s"""
+        |$resultTypeTerm $resultTerm;
+        |boolean $nullTerm;
+        |if ($inputTerm == null) {
+        |  $resultTerm = $defaultValue;
+        |  $nullTerm = true;
+        |}
+        |else {
+        |  ${fieldAccessExpr.code}
+        |  $resultTerm = ${fieldAccessExpr.resultTerm};
+        |  $nullTerm = ${fieldAccessExpr.nullTerm};
+        |}
+        |""".stripMargin
+
+    GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
+  }
+
+  private def generateFieldAccess(
+      inputType: TypeInformation[Any],
+      inputTerm: String,
+      index: Int)
+    : GeneratedExpression = {
+    inputType match {
+      case ct: CompositeType[_] =>
+        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
+          inputPojoFieldMapping.get(index)
+        }
+        else {
+          index
+        }
+        val accessor = fieldAccessorFor(ct, fieldIndex)
+        val fieldType: TypeInformation[Any] = ct.getTypeAt(fieldIndex)
+        val fieldTypeTerm = boxedTypeTermForTypeInfo(fieldType)
+
+        accessor match {
+          case ObjectFieldAccessor(field) =>
+            // primitive
+            if (isFieldPrimitive(field)) {
+              generateNonNullLiteral(fieldType, s"$inputTerm.${field.getName}")
+            }
+            // Object
+            else {
+              generateNullableLiteral(
+                fieldType,
+                s"($fieldTypeTerm) $inputTerm.${field.getName}")
+            }
+
+          case ObjectGenericFieldAccessor(fieldName) =>
+            // Object
+            val inputCode = s"($fieldTypeTerm) $inputTerm.$fieldName"
+            generateNullableLiteral(fieldType, inputCode)
+
+          case ObjectMethodAccessor(methodName) =>
+            // Object
+            val inputCode = s"($fieldTypeTerm) $inputTerm.$methodName()"
+            generateNullableLiteral(fieldType, inputCode)
+
+          case ProductAccessor(i) =>
+            // Object
+            val inputCode = s"($fieldTypeTerm) $inputTerm.productElement($i)"
+            generateNullableLiteral(fieldType, inputCode)
+
+          case ObjectPrivateFieldAccessor(field) =>
+            val fieldTerm = addReusablePrivateFieldAccess(ct.getTypeClass, field.getName)
+            val reflectiveAccessCode = reflectiveFieldReadAccess(fieldTerm, field, inputTerm)
+            // primitive
+            if (isFieldPrimitive(field)) {
+              generateNonNullLiteral(fieldType, reflectiveAccessCode)
+            }
+            // Object
+            else {
+              generateNullableLiteral(fieldType, reflectiveAccessCode)
+            }
+        }
+
+      case at: AtomicType[_] =>
+        val fieldTypeTerm = boxedTypeTermForTypeInfo(at)
+        val inputCode = s"($fieldTypeTerm) $inputTerm"
+        generateNullableLiteral(at, inputCode)
+
+      case _ =>
+        throw new CodeGenException("Unsupported type for input field access.")
+    }
+  }
+
   private def generateNullableLiteral(
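
For intuition about the template in `generateNullableInputFieldAccess`: with a nullable input term `in1` and an `Int` field, `inputCheckCode` expands to Java along these lines. Everything here is illustrative — the terms come from `newName()` and the default from `primitiveDefaultValue`:

```scala
// Illustrative expansion of inputCheckCode, not literal generator output.
val exampleInputCheckCode =
  """
    |int result$2;
    |boolean isNull$2;
    |if (in1 == null) {
    |  result$2 = -1;                 // primitiveDefaultValue for the field type
    |  isNull$2 = true;
    |}
    |else {
    |  int result$1 = (int) in1.f0;   // fieldAccessExpr.code: the field access
    |  result$2 = result$1;
    |  isNull$2 = false;              // fieldAccessExpr.nullTerm
    |}
    |""".stripMargin
```

The wrapped field access runs only in the else branch, so a null row is never dereferenced.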
BatchScan.scala
@@ -22,14 +22,11 @@ import org.apache.calcite.plan._
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.TableScan
 import org.apache.calcite.rel.metadata.RelMetadataQuery
-import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.java.typeutils.PojoTypeInfo
 import org.apache.flink.api.table.TableConfig
 import org.apache.flink.api.table.codegen.CodeGenerator
 import org.apache.flink.api.table.plan.schema.FlinkTable
-import org.apache.flink.api.table.runtime.MapRunner
 import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType
 
 import scala.collection.JavaConversions._
@@ -84,6 +81,7 @@ abstract class BatchScan(
 
     val mapFunc = getConversionMapper(
       config,
+      false,
       inputType,
       determinedType,
       "DataSetSourceConversion",
DataSetAggregate.scala
@@ -18,20 +18,17 @@
 
 package org.apache.flink.api.table.plan.nodes.dataset
 
-import org.apache.calcite.plan.{RelOptCost, RelOptPlanner, RelOptCluster, RelTraitSet}
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.AggregateCall
 import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
-import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.table.codegen.CodeGenerator
-import org.apache.flink.api.table.runtime.MapRunner
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil.CalcitePair
-import org.apache.flink.api.table.typeutils.{TypeConverter, RowTypeInfo}
-import org.apache.flink.api.table.{BatchTableEnvironment, Row, TableConfig}
+import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
+import org.apache.flink.api.table.{BatchTableEnvironment, Row}
 
 import scala.collection.JavaConverters._
@@ -143,7 +140,9 @@ class DataSetAggregate(
     expectedType match {
       case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
         val mapName = s"convert: (${rowType.getFieldNames.asScala.toList.mkString(", ")})"
-        result.map(getConversionMapper(config,
+        result.map(getConversionMapper(
+          config,
+          false,
           rowTypeInfo.asInstanceOf[TypeInformation[Any]],
           expectedType.get,
           "AggregateOutputConversion",
DataSetCalc.scala
@@ -103,7 +103,7 @@ class DataSetCalc(
       config.getNullCheck,
       config.getEfficientTypeUsage)
 
-    val generator = new CodeGenerator(config, inputDS.getType)
+    val generator = new CodeGenerator(config, false, inputDS.getType)
 
     val body = functionBody(
       generator,
DataSetJoin.scala
@@ -159,7 +159,11 @@ class DataSetJoin(
     val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
     val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
 
-    val generator = new CodeGenerator(config, leftDataSet.getType, Some(rightDataSet.getType))
+    val generator = new CodeGenerator(
+      config,
+      false,
+      leftDataSet.getType,
+      Some(rightDataSet.getType))
     val conversion = generator.generateConverterResultExpression(
       returnType,
       joinRowType.getFieldNames)
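
Every call site in this commit passes `false`, so existing plans are unchanged; the flag is groundwork for plans where an entire input row can be null. A hypothetical outer-join translation, reusing the names from the snippet above (not part of this patch), would construct the generator as:

```scala
// Hypothetical: rows from the right side of an outer join may be null,
// so the generator is told so; config.getNullCheck must then be true.
val generator = new CodeGenerator(
  config,
  true,
  leftDataSet.getType,
  Some(rightDataSet.getType))
```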
DataSetRel.scala
@@ -25,9 +25,9 @@ import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.table.codegen.CodeGenerator
-import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig, TableEnvironment}
 import org.apache.flink.api.table.plan.nodes.FlinkRel
 import org.apache.flink.api.table.runtime.MapRunner
+import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig}
 
 import scala.collection.JavaConversions._
@@ -69,6 +69,7 @@ trait DataSetRel extends RelNode with FlinkRel {
 
   private[dataset] def getConversionMapper(
       config: TableConfig,
+      nullableInput: Boolean,
       inputType: TypeInformation[Any],
       expectedType: TypeInformation[Any],
       conversionOperatorName: String,
@@ -77,6 +78,7 @@
 
     val generator = new CodeGenerator(
       config,
+      nullableInput,
       inputType,
       None,
       inputPojoFieldMapping)
DataSetSort.scala
@@ -27,7 +27,6 @@ import org.apache.calcite.rel.{RelCollation, RelNode, RelWriter, SingleRel}
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.java.typeutils.PojoTypeInfo
 import org.apache.flink.api.table.BatchTableEnvironment
 import org.apache.flink.api.table.typeutils.TypeConverter._
 
@@ -88,7 +87,9 @@ class DataSetSort(
     // conversion
     if (determinedType != inputType) {
 
-      val mapFunc = getConversionMapper(config,
+      val mapFunc = getConversionMapper(
+        config,
+        false,
         partitionedDs.getType,
         determinedType,
         "DataSetSortConversion",
DataStreamCalc.scala
@@ -82,7 +82,7 @@ class DataStreamCalc(
       config.getNullCheck,
       config.getEfficientTypeUsage)
 
-    val generator = new CodeGenerator(config, inputDataStream.getType)
+    val generator = new CodeGenerator(config, false, inputDataStream.getType)
 
     val body = functionBody(
       generator,
StreamScan.scala
@@ -72,6 +72,7 @@ abstract class StreamScan(
     if (determinedType != inputType) {
       val generator = new CodeGenerator(
         config,
+        false,
         input.getType,
         flinkTable.fieldIndexes)
 