[FLINK-4263] [table] SQL's VALUES does not work properly
This closes apache#2818.
twalthr committed Nov 17, 2016
1 parent a1362c3 commit 836fe97
Showing 13 changed files with 180 additions and 55 deletions.
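In substance, the change replaces plan-time materialization of literal rows (building a Seq[Row] for the old ValuesInputFormat) with code generation of a GenericInputFormat. A hypothetical query that exercises this VALUES path, sketched against the Table API of that release (env, tEnv, and the imports are illustrative, not part of the commit):

import org.apache.flink.api.scala._
import org.apache.flink.api.table.TableEnvironment

// Hypothetical example: a SQL VALUES clause that is translated through
// DataSetValues and the generated input format shown below
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

val result = tEnv.sql(
  "SELECT * FROM (VALUES (1, 'Hello'), (2, 'World')) AS T(a, b)")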
2 changes: 1 addition & 1 deletion flink-libraries/flink-table/pom.xml
@@ -45,7 +45,7 @@ under the License.
<dependency>
<groupId>org.codehaus.janino</groupId>
<artifactId>janino</artifactId>
<version>2.7.5</version>
<version>3.0.6</version>
</dependency>

<dependency>
CodeGenerator.scala
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
import org.apache.flink.api.common.io.GenericInputFormat
import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, TupleTypeInfo}
@@ -35,7 +36,7 @@ import org.apache.flink.api.table.codegen.Indenter.toISC
import org.apache.flink.api.table.codegen.calls.ScalarFunctions
import org.apache.flink.api.table.codegen.calls.ScalarOperators._
import org.apache.flink.api.table.functions.UserDefinedFunction
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
import org.apache.flink.api.table.typeutils.TypeCheckUtils._
import org.apache.flink.api.table.{FlinkTypeFactory, TableConfig}

@@ -98,6 +99,13 @@ class CodeGenerator(
inputPojoFieldMapping: Array[Int]) =
this(config, nullableInput, input, None, Some(inputPojoFieldMapping))

/**
* A code generator for generating Flink input formats.
*
* @param config configuration that determines runtime behavior
*/
def this(config: TableConfig) =
this(config, false, TypeConverter.DEFAULT_ROW_TYPE, None, None)

// set of member statements that will be added only once
// we use a LinkedHashSet to keep the insertion order
@@ -256,6 +264,61 @@
GeneratedFunction(funcName, returnType, funcCode)
}

/**
* Generates a values input format that can be passed to the Java compiler.
*
* @param name Class name of the input format. Does not need to be unique, but must be a
* valid Java class identifier.
* @param records code for creating records
* @param returnType expected return type
* @tparam T record type produced by the generated input format
* @return instance of GeneratedFunction
*/
def generateValuesInputFormat[T](
name: String,
records: Seq[String],
returnType: TypeInformation[Any])
: GeneratedFunction[GenericInputFormat[T]] = {
val funcName = newName(name)

addReusableOutRecord(returnType)

val funcCode = j"""
public class $funcName extends ${classOf[GenericInputFormat[_]].getCanonicalName} {

private int nextIdx = 0;

${reuseMemberCode()}

public $funcName() throws Exception {
${reuseInitCode()}
}

@Override
public boolean reachedEnd() throws java.io.IOException {
return nextIdx >= ${records.length};
}

@Override
public Object nextRecord(Object reuse) {
switch (nextIdx) {
${records.zipWithIndex.map { case (r, i) =>
s"""
|case $i:
| $r
|break;
""".stripMargin
}.mkString("\n")}
}
nextIdx++;
return $outRecordTerm;
}
}
""".stripMargin

GeneratedFunction[GenericInputFormat[T]](funcName, returnType, funcCode)
}
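For intuition, a minimal sketch of how this method is driven, mirroring the DataSetValues change below; config and returnType are assumed to be in scope, and the record snippets only stand in for what generateResultExpression would emit (the actual name of the reusable output record is chosen by the generator):

// Sketch: wiring generateValuesInputFormat (record code is an assumed shape)
val generator = new CodeGenerator(config)

val records = Seq(
  "out.setField(0, 1); out.setField(1, \"Hello\");",
  "out.setField(0, 2); out.setField(1, \"World\");")

val generated = generator.generateValuesInputFormat[Any](
  "ValuesInputFormat", records, returnType)
// generated.code now holds the Java source of a GenericInputFormat subclass
// that replays the two records; it is compiled at runtime via the Compiler
// trait further down.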

/**
* Generates an expression that converts the first input (and second input) into the given type.
* If two inputs are converted, the second input is appended. If objects or variables can
DataSetValues.scala
@@ -20,19 +20,18 @@ package org.apache.flink.api.table.plan.nodes.dataset

import com.google.common.collect.ImmutableList
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Values
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rex.RexLiteral
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.BatchTableEnvironment
import org.apache.flink.api.table.codegen.CodeGenerator
import org.apache.flink.api.table.runtime.io.ValuesInputFormat
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.typeutils.TypeConverter._
import org.apache.flink.api.table.{BatchTableEnvironment, Row}

import scala.collection.JavaConverters._
import scala.collection.JavaConversions._

/**
* DataSet RelNode for a LogicalValues.
@@ -42,7 +41,8 @@ class DataSetValues(
cluster: RelOptCluster,
traitSet: RelTraitSet,
rowRelDataType: RelDataType,
tuples: ImmutableList[ImmutableList[RexLiteral]])
tuples: ImmutableList[ImmutableList[RexLiteral]],
ruleDescription: String)
extends Values(cluster, rowRelDataType, tuples, traitSet)
with DataSetRel {

@@ -53,7 +53,8 @@
cluster,
traitSet,
getRowType,
getTuples
getTuples,
ruleDescription
)
}

@@ -75,16 +76,29 @@
getRowType,
expectedType,
config.getNullCheck,
config.getEfficientTypeUsage).asInstanceOf[RowTypeInfo]
config.getEfficientTypeUsage)

val generator = new CodeGenerator(config)

// convert List[RexLiteral] to Row
val rows: Seq[Row] = getTuples.asList.map { t =>
val row = new Row(t.size())
t.zipWithIndex.foreach( x => row.setField(x._2, x._1.getValue.asInstanceOf[Any]) )
row
// generate code for every record
val generatedRecords = getTuples.asScala.map { r =>
generator.generateResultExpression(
returnType,
getRowType.getFieldNames.asScala,
r.asScala)
}

val inputFormat = new ValuesInputFormat(rows)
// generate input format
val generatedFunction = generator.generateValuesInputFormat(
ruleDescription,
generatedRecords.map(_.code),
returnType)

val inputFormat = new ValuesInputFormat[Any](
generatedFunction.name,
generatedFunction.code,
generatedFunction.returnType)

tableEnv.execEnv.createInput(inputFormat, returnType).asInstanceOf[DataSet[Any]]
}

DataStreamValues.scala
@@ -25,13 +25,13 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Values
import org.apache.calcite.rex.RexLiteral
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.StreamTableEnvironment
import org.apache.flink.api.table.codegen.CodeGenerator
import org.apache.flink.api.table.runtime.io.ValuesInputFormat
import org.apache.flink.api.table.{Row, StreamTableEnvironment}
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.api.table.typeutils.TypeConverter._
import org.apache.flink.streaming.api.datastream.DataStream

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

/**
* DataStream RelNode for LogicalValues.
@@ -40,7 +40,8 @@ class DataStreamValues(
cluster: RelOptCluster,
traitSet: RelTraitSet,
rowRelDataType: RelDataType,
tuples: ImmutableList[ImmutableList[RexLiteral]])
tuples: ImmutableList[ImmutableList[RexLiteral]],
ruleDescription: String)
extends Values(cluster, rowRelDataType, tuples, traitSet)
with DataStreamRel {

@@ -51,30 +52,45 @@
cluster,
traitSet,
getRowType,
getTuples
getTuples,
ruleDescription
)
}

override def translateToPlan(
tableEnv: StreamTableEnvironment,
expectedType: Option[TypeInformation[Any]]) : DataStream[Any] = {
expectedType: Option[TypeInformation[Any]])
: DataStream[Any] = {

val config = tableEnv.getConfig

val returnType = determineReturnType(
getRowType,
expectedType,
config.getNullCheck,
config.getEfficientTypeUsage).asInstanceOf[RowTypeInfo]
config.getEfficientTypeUsage)

// convert List[RexLiteral] to Row
val rows: Seq[Row] = getTuples.asList.map { t =>
val row = new Row(t.size())
t.zipWithIndex.foreach( x => row.setField(x._2, x._1.getValue.asInstanceOf[Any]) )
row
val generator = new CodeGenerator(config)

// generate code for every record
val generatedRecords = getTuples.asScala.map { r =>
generator.generateResultExpression(
returnType,
getRowType.getFieldNames.asScala,
r.asScala)
}

val inputFormat = new ValuesInputFormat(rows)
// generate input format
val generatedFunction = generator.generateValuesInputFormat(
ruleDescription,
generatedRecords.map(_.code),
returnType)

val inputFormat = new ValuesInputFormat[Any](
generatedFunction.name,
generatedFunction.code,
generatedFunction.returnType)

tableEnv.execEnv.createInput(inputFormat, returnType).asInstanceOf[DataStream[Any]]
}

DataSetValuesRule.scala
@@ -41,7 +41,8 @@ class DataSetValuesRule
rel.getCluster,
traitSet,
rel.getRowType,
values.getTuples)
values.getTuples,
description)
}
}

DataStreamValuesRule.scala
@@ -41,7 +41,8 @@ class DataStreamValuesRule
rel.getCluster,
traitSet,
rel.getRowType,
values.getTuples)
values.getTuples,
description)
}
}

FunctionCompiler.scala → Compiler.scala
@@ -23,7 +23,7 @@ import org.apache.flink.api.common.functions.Function
import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.SimpleCompiler

trait FunctionCompiler[T <: Function] {
trait Compiler[T] {

@throws(classOf[CompileException])
def compile(cl: ClassLoader, name: String, code: String): Class[T] = {
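The diff truncates the body of compile here. A sketch of how such a method is typically written with Janino's SimpleCompiler, based on the imports above; this is an assumption, not the verbatim commit:

// Sketch (assumed): compile a generated Java source string with Janino
@throws(classOf[CompileException])
def compile(cl: ClassLoader, name: String, code: String): Class[T] = {
  require(cl != null, "Classloader must not be null.")
  val compiler = new SimpleCompiler()
  compiler.setParentClassLoader(cl)
  compiler.cook(code)                       // parse and compile the source
  compiler.getClassLoader.loadClass(name).asInstanceOf[Class[T]]
}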
FlatJoinRunner.scala
@@ -31,7 +31,7 @@ class FlatJoinRunner[IN1, IN2, OUT](
@transient returnType: TypeInformation[OUT])
extends RichFlatJoinFunction[IN1, IN2, OUT]
with ResultTypeQueryable[OUT]
with FunctionCompiler[FlatJoinFunction[IN1, IN2, OUT]] {
with Compiler[FlatJoinFunction[IN1, IN2, OUT]] {

val LOG = LoggerFactory.getLogger(this.getClass)

FlatMapRunner.scala
@@ -31,11 +31,11 @@ class FlatMapRunner[IN, OUT](
@transient returnType: TypeInformation[OUT])
extends RichFlatMapFunction[IN, OUT]
with ResultTypeQueryable[OUT]
with FunctionCompiler[FlatMapFunction[IN, OUT]] {
with Compiler[FlatMapFunction[IN, OUT]] {

val LOG = LoggerFactory.getLogger(this.getClass)

private var function: FlatMapFunction[IN, OUT] = null
private var function: FlatMapFunction[IN, OUT] = _

override def open(parameters: Configuration): Unit = {
LOG.debug(s"Compiling FlatMapFunction: $name \n\n Code:\n$code")
MapRunner.scala
@@ -30,7 +30,7 @@ class MapRunner[IN, OUT](
@transient returnType: TypeInformation[OUT])
extends RichMapFunction[IN, OUT]
with ResultTypeQueryable[OUT]
with FunctionCompiler[MapFunction[IN, OUT]] {
with Compiler[MapFunction[IN, OUT]] {

val LOG = LoggerFactory.getLogger(this.getClass)

ValuesInputFormat.scala
@@ -19,25 +19,35 @@
package org.apache.flink.api.table.runtime.io

import org.apache.flink.api.common.io.{GenericInputFormat, NonParallelInput}

Removed:

import org.apache.flink.api.table.Row

class ValuesInputFormat(val rows: Seq[Row])
  extends GenericInputFormat[Row]
  with NonParallelInput {

  var readIdx = 0

  override def reachedEnd(): Boolean = readIdx == rows.size

  override def nextRecord(reuse: Row): Row = {

    if (readIdx == rows.size) {
      return null
    }

    val outRow = rows(readIdx)
    readIdx += 1

    outRow
  }
}

Added:

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.api.table.runtime.Compiler
import org.apache.flink.core.io.GenericInputSplit
import org.slf4j.LoggerFactory

class ValuesInputFormat[OUT](
    name: String,
    code: String,
    @transient returnType: TypeInformation[OUT])
  extends GenericInputFormat[OUT]
  with NonParallelInput
  with ResultTypeQueryable[OUT]
  with Compiler[GenericInputFormat[OUT]] {

  val LOG = LoggerFactory.getLogger(this.getClass)

  private var format: GenericInputFormat[OUT] = _

  override def open(split: GenericInputSplit): Unit = {
    LOG.debug(s"Compiling GenericInputFormat: $name \n\n Code:\n$code")
    val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
    LOG.debug("Instantiating GenericInputFormat.")
    format = clazz.newInstance()
  }

  override def reachedEnd(): Boolean = format.reachedEnd()

  override def nextRecord(reuse: OUT): OUT = format.nextRecord(reuse)

  override def getProducedType: TypeInformation[OUT] = returnType
}
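A design note, as inferred from the diff: only name, code, and returnType travel with the job graph, and the generated class is compiled in open() on the worker, which avoids serializing instances of a dynamically generated class. A minimal usage sketch, reusing generated from the sketch after generateValuesInputFormat (execEnv stands for the Java ExecutionEnvironment, as in the DataSetValues hunk):

// Sketch: the format compiles itself on open(), then delegates record access
val inputFormat = new ValuesInputFormat[Any](
  generated.name, generated.code, generated.returnType)
val values = execEnv.createInput(inputFormat, generated.returnType)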
(The remaining two changed files did not load in this view.)