Skip to content

Commit

Permalink
[FLINK-3108] [java] JoinOperator's with() calls the wrong TypeExtractor method
Browse files Browse the repository at this point in the history

This closes apache#1440
  • Loading branch information
twalthr authored and fhueske committed Dec 7, 2015
1 parent 9547f08 commit 4dbb10f
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ public <R> CoGroupOperator<I1, I2, R> with(CoGroupFunction<I1, I2, R> function)
if (function == null) {
throw new NullPointerException("CoGroup function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getCoGroupReturnTypes(function, input1.getType(), input2.getType());
TypeInformation<R> returnType = TypeExtractor.getCoGroupReturnTypes(function, input1.getType(), input2.getType(),
Utils.getCallLocationName(), true);

return new CoGroupOperator<I1, I2, R>(input1, input2, keys1, keys2, input1.clean(function), returnType,
groupSortKeyOrderFirst, groupSortKeyOrderSecond,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,16 +559,16 @@ public <R> EquiJoin<I1, I2, R> with(FlatJoinFunction<I1, I2, R> function) {
if (function == null) {
throw new NullPointerException("Join function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getFlatJoinReturnTypes(function, getInput1Type(), getInput2Type());
TypeInformation<R> returnType = TypeExtractor.getFlatJoinReturnTypes(function, getInput1Type(), getInput2Type(), Utils.getCallLocationName(), true);
return new EquiJoin<>(getInput1(), getInput2(), getKeys1(), getKeys2(), clean(function), returnType, getJoinHint(), Utils.getCallLocationName(), joinType);
}

public <R> EquiJoin<I1, I2, R> with (JoinFunction<I1, I2, R> function) {
public <R> EquiJoin<I1, I2, R> with(JoinFunction<I1, I2, R> function) {
if (function == null) {
throw new NullPointerException("Join function must not be null.");
}
FlatJoinFunction<I1, I2, R> generatedFunction = new WrappingFlatJoinFunction<>(clean(function));
TypeInformation<R> returnType = TypeExtractor.getJoinReturnTypes(function, getInput1Type(), getInput2Type());
TypeInformation<R> returnType = TypeExtractor.getJoinReturnTypes(function, getInput1Type(), getInput2Type(), Utils.getCallLocationName(), true);
return new EquiJoin<>(getInput1(), getInput2(), getKeys1(), getKeys2(), generatedFunction, function, returnType, getJoinHint(), Utils.getCallLocationName(), joinType);
}

Expand All @@ -582,7 +582,7 @@ public WrappingFlatJoinFunction(JoinFunction<IN1, IN2, OUT> wrappedFunction) {

@Override
public void join(IN1 left, IN2 right, Collector<OUT> out) throws Exception {
out.collect (this.wrappedFunction.join(left, right));
out.collect(this.wrappedFunction.join(left, right));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ public <R> GroupReduceOperator<T, R> reduceGroup(GroupReduceFunction<T, R> reduc
throw new NullPointerException("GroupReduce function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer,
this.getDataSet().getType());
return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName() );
this.getDataSet().getType(), Utils.getCallLocationName(), true);
return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName());
}

/**
Expand All @@ -182,7 +182,8 @@ public <R> GroupCombineOperator<T, R> combineGroup(GroupCombineFunction<T, R> co
if (combiner == null) {
throw new NullPointerException("GroupCombine function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner,
this.getDataSet().getType(), Utils.getCallLocationName(), true);

return new GroupCombineOperator<T, R>(this, resultType, dataSet.clean(combiner), Utils.getCallLocationName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ public <R> GroupReduceOperator<T, R> reduceGroup(GroupReduceFunction<T, R> reduc
if (reducer == null) {
throw new NullPointerException("GroupReduce function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer,
this.getDataSet().getType(), Utils.getCallLocationName(), true);

return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName());
}
Expand All @@ -177,7 +178,8 @@ public <R> GroupCombineOperator<T, R> combineGroup(GroupCombineFunction<T, R> co
if (combiner == null) {
throw new NullPointerException("GroupCombine function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner,
this.getDataSet().getType(), Utils.getCallLocationName(), true);

return new GroupCombineOperator<T, R>(this, resultType, dataSet.clean(combiner), Utils.getCallLocationName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
import java.util.LinkedList;
import java.util.List;

import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
Expand All @@ -42,7 +48,7 @@
@RunWith(Parameterized.class)
public class TypeHintITCase extends JavaProgramTestBase {

private static int NUM_PROGRAMS = 3;
private static int NUM_PROGRAMS = 9;

private int curProgId = config.getInteger("ProgramId", -1);

Expand Down Expand Up @@ -114,9 +120,9 @@ public static void runProgram(int progId) throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> identityMapDs = ds.
flatMap(new FlatMapper<Tuple3<Integer, Long, String>, Integer>())
.returns(Integer.class);
DataSet<Integer> identityMapDs = ds
.flatMap(new FlatMapper<Tuple3<Integer, Long, String>, Integer>())
.returns(Integer.class);
List<Integer> result = identityMapDs.collect();

String expectedResult = "2\n" +
Expand All @@ -126,6 +132,124 @@ public static void runProgram(int progId) throws Exception {
compareResultAsText(result, expectedResult);
break;
}
// Test join with type information type hint
case 4: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.join(ds2)
.where(0)
.equalTo(0)
.with(new Joiner<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
// Test flat join with type information type hint
case 5: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.join(ds2)
.where(0)
.equalTo(0)
.with(new FlatJoiner<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
// Test unsorted group reduce with type information type hint
case 6: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
// Test sorted group reduce with type information type hint
case 7: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.sortGroup(0, Order.ASCENDING)
.reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
// Test combine group with type information type hint
case 8: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.combineGroup(new GroupCombiner<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
// Test cogroup with type information type hint
case 9: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.coGroup(ds2)
.where(0)
.equalTo(0)
.with(new CoGrouper<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();

String expectedResult = "2\n" +
"3\n" +
"1\n";

compareResultAsText(result, expectedResult);
break;
}
default:
throw new IllegalArgumentException("Invalid program id");
}
Expand Down Expand Up @@ -154,4 +278,49 @@ public void flatMap(T value, Collector<V> out) throws Exception {
}
}

/**
 * Generic {@link JoinFunction} used by the type-hint tests: because its output type
 * is a free type variable, the result type has to be supplied via {@code returns(...)}.
 *
 * <p>The function returns field {@code f0} of the first join input. The first input
 * must be a {@code Tuple3} at runtime, otherwise the cast fails with a
 * {@link ClassCastException}.
 *
 * @param <IN1> type of the first join input (expected to be a {@code Tuple3})
 * @param <IN2> type of the second join input (ignored)
 * @param <OUT> result type; must match the runtime type of {@code f0}
 */
public static class Joiner<IN1, IN2, OUT> implements JoinFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Returns field {@code f0} of {@code first}.
     *
     * @param first element of the first input, cast to {@code Tuple3}
     * @param second element of the second input, unused
     * @return {@code f0} of {@code first}, cast to {@code OUT}
     */
    @Override
    @SuppressWarnings("unchecked") // OUT is erased; the test fixes it through a type hint
    public OUT join(IN1 first, IN2 second) throws Exception {
        // Wildcard-parameterized cast instead of a raw Tuple3.
        final Tuple3<?, ?, ?> tuple = (Tuple3<?, ?, ?>) first;
        return (OUT) tuple.f0;
    }
}

/**
 * Generic {@link FlatJoinFunction} used by the type-hint tests: because its output type
 * is a free type variable, the result type has to be supplied via {@code returns(...)}.
 *
 * <p>For every joined pair it emits field {@code f0} of the first input. The first
 * input must be a {@code Tuple3} at runtime, otherwise the cast fails with a
 * {@link ClassCastException}.
 *
 * @param <IN1> type of the first join input (expected to be a {@code Tuple3})
 * @param <IN2> type of the second join input (ignored)
 * @param <OUT> emitted type; must match the runtime type of {@code f0}
 */
public static class FlatJoiner<IN1, IN2, OUT> implements FlatJoinFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Emits field {@code f0} of {@code first} to the collector.
     *
     * @param first element of the first input, cast to {@code Tuple3}
     * @param second element of the second input, unused
     * @param out collector receiving exactly one record per call
     */
    @Override
    @SuppressWarnings("unchecked") // OUT is erased; the test fixes it through a type hint
    public void join(IN1 first, IN2 second, Collector<OUT> out) throws Exception {
        // Wildcard-parameterized cast instead of a raw Tuple3.
        final Tuple3<?, ?, ?> tuple = (Tuple3<?, ?, ?>) first;
        out.collect((OUT) tuple.f0);
    }
}

/**
 * Generic {@link GroupReduceFunction} used by the type-hint tests: because its output
 * type is a free type variable, the result type has to be supplied via
 * {@code returns(...)}.
 *
 * <p>For each group it emits field {@code f0} of the first element the iterable yields.
 * That element must be a {@code Tuple3} at runtime, otherwise the cast fails with a
 * {@link ClassCastException}.
 *
 * @param <IN> element type of the group (expected to be a {@code Tuple3})
 * @param <OUT> emitted type; must match the runtime type of {@code f0}
 */
public static class GroupReducer<IN, OUT> implements GroupReduceFunction<IN, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Emits field {@code f0} of the first element of {@code values}.
     *
     * @param values elements of the group; only the first one is read
     * @param out collector receiving exactly one record per call
     * @throws java.util.NoSuchElementException if {@code values} is empty
     */
    @Override
    @SuppressWarnings("unchecked") // OUT is erased; the test fixes it through a type hint
    public void reduce(Iterable<IN> values, Collector<OUT> out) throws Exception {
        // Wildcard-parameterized cast instead of a raw Tuple3.
        final Tuple3<?, ?, ?> head = (Tuple3<?, ?, ?>) values.iterator().next();
        out.collect((OUT) head.f0);
    }
}

/**
 * Generic {@link GroupCombineFunction} used by the type-hint tests: because its output
 * type is a free type variable, the result type has to be supplied via
 * {@code returns(...)}.
 *
 * <p>For each group it emits field {@code f0} of the first element the iterable yields.
 * That element must be a {@code Tuple3} at runtime, otherwise the cast fails with a
 * {@link ClassCastException}.
 *
 * @param <IN> element type of the group (expected to be a {@code Tuple3})
 * @param <OUT> emitted type; must match the runtime type of {@code f0}
 */
public static class GroupCombiner<IN, OUT> implements GroupCombineFunction<IN, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Emits field {@code f0} of the first element of {@code values}.
     *
     * @param values elements to combine; only the first one is read
     * @param out collector receiving exactly one record per call
     * @throws java.util.NoSuchElementException if {@code values} is empty
     */
    @Override
    @SuppressWarnings("unchecked") // OUT is erased; the test fixes it through a type hint
    public void combine(Iterable<IN> values, Collector<OUT> out) throws Exception {
        // Wildcard-parameterized cast instead of a raw Tuple3.
        final Tuple3<?, ?, ?> head = (Tuple3<?, ?, ?>) values.iterator().next();
        out.collect((OUT) head.f0);
    }
}

/**
 * Generic {@link CoGroupFunction} used by the type-hint tests: because its output type
 * is a free type variable, the result type has to be supplied via {@code returns(...)}.
 *
 * <p>For each co-group it emits field {@code f0} of the first available element,
 * preferring the first input. The element read must be a {@code Tuple3} at runtime,
 * otherwise the cast fails with a {@link ClassCastException}.
 *
 * @param <IN1> element type of the first input (expected to be a {@code Tuple3})
 * @param <IN2> element type of the second input (expected to be a {@code Tuple3}
 *              when it is consulted)
 * @param <OUT> emitted type; must match the runtime type of {@code f0}
 */
public static class CoGrouper<IN1, IN2, OUT> implements CoGroupFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Emits field {@code f0} of the first element of {@code first}. If {@code first}
     * is empty, the first element of {@code second} is used instead, so an empty
     * first side no longer raises {@code NoSuchElementException}. Nothing is emitted
     * when both sides are empty.
     *
     * @param first elements of the first input for this key; only the first one is read
     * @param second elements of the second input for this key; read only if
     *               {@code first} is empty
     * @param out collector receiving at most one record per call
     */
    @Override
    @SuppressWarnings("unchecked") // OUT is erased; the test fixes it through a type hint
    public void coGroup(Iterable<IN1> first, Iterable<IN2> second, Collector<OUT> out) throws Exception {
        // Loops are used only to read the head element without calling next() on an
        // empty iterator; each returns after the first element.
        for (IN1 value : first) {
            out.collect((OUT) ((Tuple3<?, ?, ?>) value).f0);
            return;
        }
        // NOTE(review): falling back to the second side assumes both inputs carry the
        // key in f0, as in the where(0).equalTo(0) test program — confirm for other uses.
        for (IN2 value : second) {
            out.collect((OUT) ((Tuple3<?, ?, ?>) value).f0);
            return;
        }
    }
}

}

0 comments on commit 4dbb10f

Please sign in to comment.