Skip to content

Commit

Permalink
[FLINK-20517] Support mixed keyed/non-keyed operations in BATCH mode
Browse files Browse the repository at this point in the history
  • Loading branch information
aljoscha committed Jan 7, 2021
1 parent 606c44b commit 9bec335
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,25 +436,35 @@ public <R> SingleOutputStreamOperator<R> transform(
outTypeInfo,
environment.getParallelism());

if (inputStream1 instanceof KeyedStream && inputStream2 instanceof KeyedStream) {
TypeInformation<?> keyType = null;
if (inputStream1 instanceof KeyedStream) {
KeyedStream<IN1, ?> keyedInput1 = (KeyedStream<IN1, ?>) inputStream1;

keyType = keyedInput1.getKeyType();

transform.setStateKeySelectors(keyedInput1.getKeySelector(), null);
transform.setStateKeyType(keyType);
}
if (inputStream2 instanceof KeyedStream) {
KeyedStream<IN2, ?> keyedInput2 = (KeyedStream<IN2, ?>) inputStream2;

TypeInformation<?> keyType1 = keyedInput1.getKeyType();
TypeInformation<?> keyType2 = keyedInput2.getKeyType();
if (!(keyType1.canEqual(keyType2) && keyType1.equals(keyType2))) {

if (keyType != null && !(keyType.canEqual(keyType2) && keyType.equals(keyType2))) {
throw new UnsupportedOperationException(
"Key types if input KeyedStreams "
+ "don't match: "
+ keyType1
+ keyType
+ " and "
+ keyType2
+ ".");
}

transform.setStateKeySelectors(
keyedInput1.getKeySelector(), keyedInput2.getKeySelector());
transform.setStateKeyType(keyType1);
transform.getStateKeySelector1(), keyedInput2.getKeySelector());

// we might be overwriting the one that's already set, but it's the same
transform.setStateKeyType(keyType2);
}

@SuppressWarnings({"unchecked", "rawtypes"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.graph.SimpleTransformationTranslator;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.graph.StreamGraph;
Expand All @@ -34,6 +35,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.IntStream;

import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
Expand All @@ -52,13 +54,24 @@ public class MultiInputTransformationTranslator<OUT>
protected Collection<Integer> translateForBatchInternal(
final AbstractMultipleInputTransformation<OUT> transformation, final Context context) {
Collection<Integer> ids = translateInternal(transformation, context);
boolean isKeyed = transformation instanceof KeyedMultipleInputTransformation;
if (isKeyed) {
if (transformation instanceof KeyedMultipleInputTransformation) {
KeyedMultipleInputTransformation<OUT> keyedTransformation =
(KeyedMultipleInputTransformation<OUT>) transformation;
List<Transformation<?>> inputs = transformation.getInputs();
List<KeySelector<?, ?>> keySelectors = keyedTransformation.getStateKeySelectors();

StreamConfig.InputRequirement[] inputRequirements =
inputs.stream()
.map((input) -> StreamConfig.InputRequirement.SORTED)
IntStream.range(0, inputs.size())
.mapToObj(
idx -> {
if (keySelectors.get(idx) != null) {
return StreamConfig.InputRequirement.SORTED;
} else {
return StreamConfig.InputRequirement.PASS_THROUGH;
}
})
.toArray(StreamConfig.InputRequirement[]::new);

BatchExecutionUtils.applyBatchExecutionSettings(
transformation.getId(), context, inputRequirements);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,21 @@ public class TwoInputTransformationTranslator<IN1, IN2, OUT>
protected Collection<Integer> translateForBatchInternal(
final TwoInputTransformation<IN1, IN2, OUT> transformation, final Context context) {
Collection<Integer> ids = translateInternal(transformation, context);
boolean isKeyed =

StreamConfig.InputRequirement input1Requirement =
transformation.getStateKeySelector1() != null
&& transformation.getStateKeySelector2() != null;
if (isKeyed) {
? StreamConfig.InputRequirement.SORTED
: StreamConfig.InputRequirement.PASS_THROUGH;

StreamConfig.InputRequirement input2Requirement =
transformation.getStateKeySelector2() != null
? StreamConfig.InputRequirement.SORTED
: StreamConfig.InputRequirement.PASS_THROUGH;

if (input1Requirement == StreamConfig.InputRequirement.SORTED
|| input2Requirement == StreamConfig.InputRequirement.SORTED) {
BatchExecutionUtils.applyBatchExecutionSettings(
transformation.getId(),
context,
StreamConfig.InputRequirement.SORTED,
StreamConfig.InputRequirement.SORTED);
transformation.getId(), context, input1Requirement, input2Requirement);
}
return ids;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.flink.api.common.state.ReadOnlyBroadcastState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
Expand All @@ -38,6 +39,9 @@
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.CollectionUtil;
Expand Down Expand Up @@ -191,6 +195,119 @@ public void batchSumSingleResultPerKey() throws Exception {
}
}

/**
* Verifies that all regular input is processed before keyed input.
*
* <p>Here, the first input is keyed while the second input is not keyed.
*/
@Test
public void batchKeyedNonKeyedTwoInputOperator() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);
env.setRuntimeMode(RuntimeExecutionMode.BATCH);

DataStream<Tuple2<String, Integer>> keyedInput =
env.fromElements(
Tuple2.of("regular2", 4),
Tuple2.of("regular1", 3),
Tuple2.of("regular1", 2),
Tuple2.of("regular2", 1))
.assignTimestampsAndWatermarks(
WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps()
.withTimestampAssigner((in, ts) -> in.f1));

DataStream<Tuple2<String, Integer>> regularInput =
env.fromElements(
Tuple2.of("regular4", 4),
Tuple2.of("regular3", 3),
Tuple2.of("regular3", 2),
Tuple2.of("regular4", 1))
.assignTimestampsAndWatermarks(
WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps()
.withTimestampAssigner((in, ts) -> in.f1));

DataStream<String> result =
keyedInput
.keyBy(in -> in.f0)
.connect(regularInput)
.transform(
"operator",
BasicTypeInfo.STRING_TYPE_INFO,
new TwoInputIdentityOperator());

try (CloseableIterator<String> resultIterator = result.executeAndCollect()) {
List<String> results = CollectionUtil.iteratorToList(resultIterator);
assertThat(
results,
equalTo(
Arrays.asList(
"(regular4,4)",
"(regular3,3)",
"(regular3,2)",
"(regular4,1)",
"(regular1,2)",
"(regular1,3)",
"(regular2,1)",
"(regular2,4)")));
}
}

/**
* Verifies that all regular input is processed before keyed input.
*
* <p>Here, the first input is not keyed while the second input is keyed.
*/
@Test
public void batchNonKeyedKeyedTwoInputOperator() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);
env.setRuntimeMode(RuntimeExecutionMode.BATCH);

DataStream<Tuple2<String, Integer>> keyedInput =
env.fromElements(
Tuple2.of("regular2", 4),
Tuple2.of("regular1", 3),
Tuple2.of("regular1", 2),
Tuple2.of("regular2", 1))
.assignTimestampsAndWatermarks(
WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps()
.withTimestampAssigner((in, ts) -> in.f1));

DataStream<Tuple2<String, Integer>> regularInput =
env.fromElements(
Tuple2.of("regular4", 4),
Tuple2.of("regular3", 3),
Tuple2.of("regular3", 2),
Tuple2.of("regular4", 1))
.assignTimestampsAndWatermarks(
WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps()
.withTimestampAssigner((in, ts) -> in.f1));

DataStream<String> result =
regularInput
.connect(keyedInput.keyBy(in -> in.f0))
.transform(
"operator",
BasicTypeInfo.STRING_TYPE_INFO,
new TwoInputIdentityOperator());

try (CloseableIterator<String> resultIterator = result.executeAndCollect()) {
List<String> results = CollectionUtil.iteratorToList(resultIterator);
assertThat(
results,
equalTo(
Arrays.asList(
"(regular4,4)",
"(regular3,3)",
"(regular3,2)",
"(regular4,1)",
"(regular1,2)",
"(regular1,3)",
"(regular2,1)",
"(regular2,4)")));
}
}

/** Verifies that all broadcast input is processed before keyed input. */
@Test
public void batchKeyedBroadcastExecution() throws Exception {
Expand Down Expand Up @@ -402,4 +519,22 @@ public void processBroadcastElement(
state.put(value.f0, value.f0);
}
}

private static class TwoInputIdentityOperator extends AbstractStreamOperator<String>
implements TwoInputStreamOperator<
Tuple2<String, Integer>, Tuple2<String, Integer>, String> {
@Override
public void processElement1(StreamRecord<Tuple2<String, Integer>> element)
throws Exception {
output.collect(
new StreamRecord<>(element.getValue().toString(), element.getTimestamp()));
}

@Override
public void processElement2(StreamRecord<Tuple2<String, Integer>> element)
throws Exception {
output.collect(
new StreamRecord<>(element.getValue().toString(), element.getTimestamp()));
}
}
}

0 comments on commit 9bec335

Please sign in to comment.