Skip to content

Commit

Permalink
[FLINK-2951] [Table API] add union operator to Table API.
Browse files Browse the repository at this point in the history
This closes apache#1315.
  • Loading branch information
chengxiang li authored and twalthr committed Nov 6, 2015
1 parent 0ca425a commit b80d89a
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ class JavaBatchTranslator extends PlanTranslator {
val inType = translatedInput.getType.asInstanceOf[CompositeType[Row]]
val filter = new ExpressionFilterFunction[Row](predicate, inType)
translatedInput.filter(filter).name(predicate.toString)

case uni@UnionAll(left, right) =>
val translatedLeft = translateInternal(left)
val translatedRight = translateInternal(right)
translatedLeft.union(translatedRight).name("Union: " + uni)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ class JavaStreamingTranslator extends PlanTranslator {
val inType = translatedInput.getType.asInstanceOf[CompositeType[Row]]
val filter = new ExpressionFilterFunction[Row](predicate, inType)
translatedInput.filter(filter)

case UnionAll(left, right) =>
val translatedLeft = translateInternal(left)
val translatedRight = translateInternal(right)
translatedLeft.union(translatedRight)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,29 @@ case class Table(private[flink] val operation: PlanNode) {
this.copy(operation = Join(operation, right.operation))
}

/**
 * Unions two [[Table]]s. Similar to an SQL UNION ALL: the records of both tables are kept and
 * duplicates are not removed. The output fields of the two tables must be equal.
 *
 * Example:
 *
 * {{{
 *   left.unionAll(right)
 * }}}
 *
 * @param right the table whose records are added to the records of this table
 * @return a new [[Table]] whose operation is the `UnionAll` of both input operations
 * @throws ExpressionException if the output fields of the two tables are not equal
 */
def unionAll(right: Table): Table = {
  val leftInputFields = operation.outputFields
  val rightInputFields = right.operation.outputFields
  // Validate eagerly so that a schema mismatch is reported when the plan is built,
  // not when it is translated or executed.
  if (leftInputFields != rightInputFields) {
    throw new ExpressionException(
      "The fields of union inputs should be fully overlapped, left inputs fields: " +
        leftInputFields.mkString(", ") +
        " and right inputs fields: " +
        rightInputFields.mkString(", ")
    )
  }
  this.copy(operation = UnionAll(operation, right.operation))
}

override def toString: String = s"Expression($operation)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,14 @@ case class Aggregate(

override def toString = s"Aggregate($input, ${aggregations.mkString(",")})"
}

/**
 * Plan node for a union that keeps all elements of both inputs.
 *
 * The node has exactly two children, `left` and `right` (in that order), and exposes the
 * output fields of its left input as its own output fields.
 */
case class UnionAll(left: PlanNode, right: PlanNode) extends PlanNode {

  // Both inputs of the union, left input first.
  val children = Seq(left, right)

  // The output fields are taken from the left input only.
  def outputFields = left.outputFields

  override def toString = "Union(" + left + ", " + right + ")"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.table.test;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.table.TableEnvironment;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.table.ExpressionException;
import org.apache.flink.api.table.Row;
import org.apache.flink.api.table.Table;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.List;

/**
 * Integration tests for the {@code unionAll} operator of the Java Table API.
 */
@RunWith(Parameterized.class)
public class UnionITCase extends MultipleProgramsTestBase {

	public UnionITCase(TestExecutionMode mode) {
		super(mode);
	}

	/**
	 * Unions two tables built from the same small 3-tuple data set and checks that the
	 * rows of both inputs show up in the result (duplicates are kept).
	 */
	@Test
	public void testUnion() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		TableEnvironment tableEnv = new TableEnvironment();

		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
		DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);

		Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
		Table in2 = tableEnv.fromDataSet(ds2, "a, b, c");

		Table selected = in1.unionAll(in2).select("c");
		DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
		List<Row> results = ds.collect();

		String expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hi\n" + "Hello\n" + "Hello world\n";
		compareResultAsText(results, expected);
	}

	/**
	 * Unions a 3-tuple table with a projection of a 5-tuple table and applies a filter on
	 * top of the union.
	 */
	@Test
	public void testUnionWithFilter() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		TableEnvironment tableEnv = new TableEnvironment();

		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
		DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

		Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
		// Project the 5-tuple table down to the fields of the 3-tuple table.
		Table in2 = tableEnv.fromDataSet(ds2, "a, b, d, c, e").select("a, b, c");

		Table selected = in1.unionAll(in2).where("b < 2").select("c");
		DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
		List<Row> results = ds.collect();

		String expected = "Hi\n" + "Hallo\n";
		compareResultAsText(results, expected);
	}

	/**
	 * Unioning tables with different field names and a different number of fields must
	 * be rejected.
	 */
	@Test(expected = ExpressionException.class)
	public void testUnionFieldsNameNotOverlap1() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		TableEnvironment tableEnv = new TableEnvironment();

		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
		DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

		Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
		Table in2 = tableEnv.fromDataSet(ds2, "d, e, f, g, h");

		// The fields are validated when the union is built, so the exception must be
		// raised by this call; nothing needs to be translated or executed.
		in1.unionAll(in2);
	}

	/**
	 * Unioning tables with the same field names but different field types must be
	 * rejected: field "c" is a String on the left and an Integer on the right.
	 */
	@Test(expected = ExpressionException.class)
	public void testUnionFieldsNameNotOverlap2() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		TableEnvironment tableEnv = new TableEnvironment();

		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
		DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

		Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
		Table in2 = tableEnv.fromDataSet(ds2, "a, b, c, d, e").select("a, b, c");

		// The fields are validated when the union is built, so the exception must be
		// raised by this call; nothing needs to be translated or executed.
		in1.unionAll(in2);
	}

	/**
	 * Applies an aggregation on top of a union and checks that the rows of both inputs
	 * are counted.
	 */
	@Test
	public void testUnionWithAggregation() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		TableEnvironment tableEnv = new TableEnvironment();

		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
		DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

		Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
		Table in2 = tableEnv.fromDataSet(ds2, "a, b, d, c, e").select("a, b, c");

		Table selected = in1.unionAll(in2).select("c.count");
		DataSet<Row> ds = tableEnv.toDataSet(selected, Row.class);
		List<Row> results = ds.collect();

		String expected = "18";
		compareResultAsText(results, expected);
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.scala.table.test

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.api.table.{ExpressionException, Row}
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import scala.collection.JavaConversions

/**
 * Integration tests for the `unionAll` operator of the Scala Table API.
 */
@RunWith(classOf[Parameterized])
class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {

  /**
   * Unions two tables built from the same small 3-tuple data set and checks that the rows
   * of both inputs show up in the result (duplicates are kept).
   */
  @Test
  def testUnion(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)

    val unionDs = ds1.unionAll(ds2).select('c)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hi\n" + "Hello\n" + "Hello world\n"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }

  /**
   * Unions a 3-tuple table with a projection of a 5-tuple table and applies a filter on top
   * of the union.
   */
  @Test
  def testUnionWithFilter(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    // Project the 5-tuple table down to the fields of the 3-tuple table before the union.
    val unionDs = ds1.unionAll(ds2.select('a, 'b, 'c)).filter('b < 2).select('c)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "Hi\n" + "Hallo\n"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }

  /**
   * Unioning a three-field table with a five-field table must be rejected.
   */
  @Test(expected = classOf[ExpressionException])
  def testUnionFieldsNameNotOverlap1(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    // The fields are validated when the union is built, so the exception must be raised by
    // this call; nothing needs to be translated or executed.
    ds1.unionAll(ds2)
  }

  /**
   * Unioning two tables with the same field names must still be rejected when the inputs
   * differ otherwise: the right side is a projection of the first three fields of the
   * 5-tuple data set.
   */
  @Test(expected = classOf[ExpressionException])
  def testUnionFieldsNameNotOverlap2(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'c, 'd, 'e).select('a, 'b, 'c)

    // The fields are validated when the union is built, so the exception must be raised by
    // this call; nothing needs to be translated or executed.
    ds1.unionAll(ds2)
  }

  /**
   * Applies an aggregation on top of a union and checks that the rows of both inputs are
   * counted.
   */
  @Test
  def testUnionWithAggregation(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
    val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)

    val unionDs = ds1.unionAll(ds2.select('a, 'b, 'c)).select('c.count)

    val results = unionDs.toDataSet[Row].collect()
    val expected = "18"
    TestBaseUtils.compareResultAsText(JavaConversions.seqAsJavaList(results), expected)
  }
}

0 comments on commit b80d89a

Please sign in to comment.