Skip to content

Commit

Permalink
[FLINK-20304] Introduce the BroadcastStateTransformation.
Browse files Browse the repository at this point in the history
So far the Broadcast State pattern was a TwoInputTransformation.
This did not allow for special handling when it comes to
translating it for Batch or Streaming execution. This commit
introduces a special transformation just for this.

This closes apache#14216.
  • Loading branch information
kl0u committed Nov 27, 2020
1 parent 4bede4b commit ed82aab
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.operators.co.CoBroadcastWithKeyedOperator;
import org.apache.flink.streaming.api.operators.co.CoBroadcastWithNonKeyedOperator;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.BroadcastStateTransformation;
import org.apache.flink.util.Preconditions;

import java.util.List;
Expand Down Expand Up @@ -222,27 +223,39 @@ private <OUT> SingleOutputStreamOperator<OUT> transform(
nonBroadcastStream.getType();
broadcastStream.getType();

TwoInputTransformation<IN1, IN2, OUT> transform = new TwoInputTransformation<>(
nonBroadcastStream.getTransformation(),
broadcastStream.getTransformation(),
functionName,
operator,
outTypeInfo,
environment.getParallelism());
final BroadcastStateTransformation<IN1, IN2, OUT> transformation =
getBroadcastStateTransformation(functionName, outTypeInfo, operator);

if (nonBroadcastStream instanceof KeyedStream) {
KeyedStream<IN1, ?> keyedInput1 = (KeyedStream<IN1, ?>) nonBroadcastStream;
TypeInformation<?> keyType1 = keyedInput1.getKeyType();
transform.setStateKeySelectors(keyedInput1.getKeySelector(), null);
transform.setStateKeyType(keyType1);
}
@SuppressWarnings({"unchecked", "rawtypes"})
final SingleOutputStreamOperator<OUT> returnStream =
new SingleOutputStreamOperator(environment, transformation);

@SuppressWarnings({ "unchecked", "rawtypes" })
SingleOutputStreamOperator<OUT> returnStream = new SingleOutputStreamOperator(environment, transform);
getExecutionEnvironment().addOperator(transformation);
return returnStream;
}

getExecutionEnvironment().addOperator(transform);
/**
 * Builds the {@link BroadcastStateTransformation} for this pair of connected streams.
 *
 * <p>If {@code nonBroadcastStream} is a {@link KeyedStream}, the keyed static constructor
 * {@code forKeyedStream} is used; otherwise {@code forNonKeyedStream} is used. In both cases the
 * given operator is wrapped with {@code SimpleOperatorFactory.of(...)} and the parallelism is
 * taken from {@code environment.getParallelism()}.
 *
 * @param functionName the name passed on to the created transformation
 * @param outTypeInfo the type information of the elements produced by the transformation
 * @param operator the two-input operator to wrap in an operator factory
 * @param <OUT> the type of the elements produced by the transformation
 * @return the transformation created by the matching static constructor
 */
private <OUT> BroadcastStateTransformation<IN1, IN2, OUT> getBroadcastStateTransformation(
final String functionName,
final TypeInformation<OUT> outTypeInfo,
final TwoInputStreamOperator<IN1, IN2, OUT> operator) {

// NOTE(review): the following line looks like a removed line of the rendered diff rather than
// part of this method (returnStream is not declared here) - confirm against the real file.
return returnStream;
// Keyed inputs need the key type and key selector, which forKeyedStream reads from the stream.
if (nonBroadcastStream instanceof KeyedStream) {
return BroadcastStateTransformation.forKeyedStream(
functionName,
(KeyedStream<IN1, ?>) nonBroadcastStream,
broadcastStream,
SimpleOperatorFactory.of(operator),
outTypeInfo,
environment.getParallelism());
} else {
return BroadcastStateTransformation.forNonKeyedStream(
functionName,
nonBroadcastStream,
broadcastStream,
SimpleOperatorFactory.of(operator),
outTypeInfo,
environment.getParallelism());
}
}

protected <F> F clean(F f) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionInternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.sorted.state.BatchExecutionStateBackend;
import org.apache.flink.streaming.api.transformations.BroadcastStateTransformation;
import org.apache.flink.streaming.api.transformations.CoFeedbackTransformation;
import org.apache.flink.streaming.api.transformations.FeedbackTransformation;
import org.apache.flink.streaming.api.transformations.KeyedMultipleInputTransformation;
Expand All @@ -53,6 +54,7 @@
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.UnionTransformation;
import org.apache.flink.streaming.api.transformations.WithBoundedness;
import org.apache.flink.streaming.runtime.translators.BroadcastStateTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.LegacySinkTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.LegacySourceTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.MultiInputTransformationTranslator;
Expand Down Expand Up @@ -167,6 +169,7 @@ public class StreamGraphGenerator {
tmp.put(SideOutputTransformation.class, new SideOutputTransformationTranslator<>());
tmp.put(ReduceTransformation.class, new ReduceTransformationTranslator<>());
tmp.put(TimestampsAndWatermarksTransformation.class, new TimestampsAndWatermarksTransformationTranslator<>());
tmp.put(BroadcastStateTransformation.class, new BroadcastStateTransformationTranslator<>());
translatorMap = Collections.unmodifiableMap(tmp);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.transformations;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.List;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
 * This is the transformation for the Broadcast State pattern. In a nutshell, this transformation
 * takes a broadcasted (non-keyed) stream, connects it with another keyed or non-keyed stream,
 * and applies a function on the resulting connected stream. This function will have access to
 * all the elements that belong to the non-keyed, broadcasted side, as this is kept in Flink's
 * state.
 *
 * <p>For more information see the
 * <a href="https://ci.apache.org/projects/flink/flink-docs-stable/dev/stream/state/broadcast_state.html">
 * Broadcast State Pattern documentation page</a>.
 *
 * <p>Instances are created through {@link #forKeyedStream} or {@link #forNonKeyedStream}; the
 * key type and key selector are only set for the keyed variant.
 *
 * @param <IN1> The type of the elements in the non-broadcasted input.
 * @param <IN2> The type of the elements in the broadcasted input.
 * @param <OUT> The type of the elements that result from this transformation.
 */
@Internal
public class BroadcastStateTransformation<IN1, IN2, OUT> extends PhysicalTransformation<OUT> {

    /** The transformation producing the non-broadcasted (keyed or non-keyed) input. */
    private final Transformation<IN1> nonBroadcastStream;

    /** The transformation producing the broadcasted input. */
    private final Transformation<IN2> broadcastStream;

    /** The factory of the operator that processes both inputs. */
    private final StreamOperatorFactory<OUT> operatorFactory;

    /** The type of the key of the non-broadcasted input; {@code null} if it is not keyed. */
    @Nullable
    private final TypeInformation<?> stateKeyType;

    /** The key selector of the non-broadcasted input; {@code null} if it is not keyed. */
    @Nullable
    private final KeySelector<IN1, ?> keySelector;

    /**
     * Creates the transformation. Use the static constructors instead of calling this directly.
     *
     * @param name the name of the transformation
     * @param inputStream the transformation of the non-broadcasted input, must not be null
     * @param broadcastStream the transformation of the broadcasted input, must not be null
     * @param operatorFactory the factory of the operator to run, must not be null
     * @param keyType the key type of the non-broadcasted input, or {@code null} if not keyed
     * @param keySelector the key selector of the non-broadcasted input, or {@code null} if not
     *     keyed
     * @param outTypeInfo the type information of the produced elements
     * @param parallelism the parallelism of the transformation
     */
    private BroadcastStateTransformation(
            final String name,
            final Transformation<IN1> inputStream,
            final Transformation<IN2> broadcastStream,
            final StreamOperatorFactory<OUT> operatorFactory,
            @Nullable final TypeInformation<?> keyType,
            @Nullable final KeySelector<IN1, ?> keySelector,
            final TypeInformation<OUT> outTypeInfo,
            final int parallelism) {
        super(name, outTypeInfo, parallelism);
        this.nonBroadcastStream = checkNotNull(inputStream);
        this.broadcastStream = checkNotNull(broadcastStream);
        this.operatorFactory = checkNotNull(operatorFactory);

        this.stateKeyType = keyType;
        this.keySelector = keySelector;
        // The presence of a key selector is what marks this transformation as keyed.
        updateManagedMemoryStateBackendUseCase(keySelector != null);
    }

    /** Returns the transformation of the broadcasted input. */
    public Transformation<IN2> getBroadcastStream() {
        return broadcastStream;
    }

    /** Returns the transformation of the non-broadcasted input. */
    public Transformation<IN1> getNonBroadcastStream() {
        return nonBroadcastStream;
    }

    /** Returns the factory of the operator that processes both inputs. */
    public StreamOperatorFactory<OUT> getOperatorFactory() {
        return operatorFactory;
    }

    /**
     * Returns the key type of the non-broadcasted input, or {@code null} if this transformation
     * was created with {@link #forNonKeyedStream}.
     */
    @Nullable
    public TypeInformation<?> getStateKeyType() {
        return stateKeyType;
    }

    /**
     * Returns the key selector of the non-broadcasted input, or {@code null} if this
     * transformation was created with {@link #forNonKeyedStream}.
     */
    @Nullable
    public KeySelector<IN1, ?> getKeySelector() {
        return keySelector;
    }

    /**
     * Forwards the given chaining strategy to the operator factory.
     *
     * @param strategy the chaining strategy to apply
     */
    @Override
    public void setChainingStrategy(ChainingStrategy strategy) {
        // Previously this called getChainingStrategy() and dropped the result, so the requested
        // strategy was silently ignored.
        this.operatorFactory.setChainingStrategy(strategy);
    }

    /**
     * Returns this transformation followed by the transitive predecessors of both inputs.
     *
     * @return a new list; modifying it does not affect this transformation
     */
    @Override
    public List<Transformation<?>> getTransitivePredecessors() {
        final List<Transformation<?>> predecessors = new ArrayList<>();
        predecessors.add(this);
        // Recurse into the inputs: adding only the direct inputs would not be transitive.
        predecessors.addAll(nonBroadcastStream.getTransitivePredecessors());
        predecessors.addAll(broadcastStream.getTransitivePredecessors());
        return predecessors;
    }

    /**
     * Returns the direct inputs: the non-broadcasted input first, then the broadcasted input.
     *
     * @return a new list; modifying it does not affect this transformation
     */
    @Override
    public List<Transformation<?>> getInputs() {
        final List<Transformation<?>> inputs = new ArrayList<>();
        inputs.add(nonBroadcastStream);
        inputs.add(broadcastStream);
        return inputs;
    }

    // ------------------------------- Static Constructors -------------------------------

    /**
     * Creates a transformation whose non-broadcasted input is not keyed. The key type and the
     * key selector of the result are {@code null}.
     *
     * @param name the name of the transformation
     * @param nonBroadcastStream the non-broadcasted stream, must not be null
     * @param broadcastStream the broadcasted stream, must not be null
     * @param operatorFactory the factory of the operator to run, must not be null
     * @param outTypeInfo the type information of the produced elements
     * @param parallelism the parallelism of the transformation
     * @return the new transformation
     */
    public static <IN1, IN2, OUT> BroadcastStateTransformation<IN1, IN2, OUT> forNonKeyedStream(
            final String name,
            final DataStream<IN1> nonBroadcastStream,
            final BroadcastStream<IN2> broadcastStream,
            final StreamOperatorFactory<OUT> operatorFactory,
            final TypeInformation<OUT> outTypeInfo,
            final int parallelism) {
        return new BroadcastStateTransformation<>(
                name,
                checkNotNull(nonBroadcastStream).getTransformation(),
                checkNotNull(broadcastStream).getTransformation(),
                operatorFactory,
                null,
                null,
                outTypeInfo,
                parallelism);
    }

    /**
     * Creates a transformation whose non-broadcasted input is keyed. The key type and the key
     * selector are read from the given keyed stream.
     *
     * @param name the name of the transformation
     * @param nonBroadcastStream the keyed, non-broadcasted stream, must not be null
     * @param broadcastStream the broadcasted stream, must not be null
     * @param operatorFactory the factory of the operator to run, must not be null
     * @param outTypeInfo the type information of the produced elements
     * @param parallelism the parallelism of the transformation
     * @return the new transformation
     */
    public static <IN1, IN2, OUT> BroadcastStateTransformation<IN1, IN2, OUT> forKeyedStream(
            final String name,
            final KeyedStream<IN1, ?> nonBroadcastStream,
            final BroadcastStream<IN2> broadcastStream,
            final StreamOperatorFactory<OUT> operatorFactory,
            final TypeInformation<OUT> outTypeInfo,
            final int parallelism) {
        return new BroadcastStateTransformation<>(
                name,
                checkNotNull(nonBroadcastStream).getTransformation(),
                checkNotNull(broadcastStream).getTransformation(),
                operatorFactory,
                nonBroadcastStream.getKeyType(),
                nonBroadcastStream.getKeySelector(),
                outTypeInfo,
                parallelism);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
*/
abstract class AbstractOneInputTransformationTranslator<IN, OUT, OP extends Transformation<OUT>>
extends SimpleTransformationTranslator<OUT, OP> {

protected Collection<Integer> translateInternal(
final Transformation<OUT> transformation,
final StreamOperatorFactory<OUT> operatorFactory,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.runtime.translators;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.streaming.api.graph.SimpleTransformationTranslator;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.TransformationTranslator;
import org.apache.flink.streaming.api.transformations.BroadcastStateTransformation;

import java.util.Collection;
import java.util.Collections;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
 * A {@link TransformationTranslator} for the {@link BroadcastStateTransformation}.
 *
 * <p>Only streaming translation is implemented; requesting a batch translation throws an
 * {@link UnsupportedOperationException}.
 *
 * @param <IN1> The type of the elements in the non-broadcasted input of the {@link BroadcastStateTransformation}.
 * @param <IN2> The type of the elements in the broadcasted input of the {@link BroadcastStateTransformation}.
 * @param <OUT> The type of the elements that result from the {@link BroadcastStateTransformation}.
 */
@Internal
public class BroadcastStateTransformationTranslator<IN1, IN2, OUT>
        extends SimpleTransformationTranslator<OUT, BroadcastStateTransformation<IN1, IN2, OUT>> {

    /**
     * Always fails, because this translator has no batch translation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected Collection<Integer> translateForBatchInternal(
            final BroadcastStateTransformation<IN1, IN2, OUT> transformation,
            final Context context) {
        throw new UnsupportedOperationException(
                "The Broadcast State Pattern is not supported in BATCH execution mode.");
    }

    /**
     * Registers the transformation as a co-operator in the stream graph, configures its state key
     * (if keyed), its parallelism and max parallelism, and connects both inputs to it.
     *
     * @param transformation the transformation to translate, must not be null
     * @param context the translation context, must not be null
     * @return a singleton collection containing the id of the translated transformation
     */
    @Override
    protected Collection<Integer> translateForStreamingInternal(
            final BroadcastStateTransformation<IN1, IN2, OUT> transformation,
            final Context context) {
        checkNotNull(transformation);
        checkNotNull(context);

        final TypeInformation<IN1> nonBroadcastTypeInfo =
                transformation.getNonBroadcastStream().getOutputType();
        final TypeInformation<IN2> broadcastTypeInfo =
                transformation.getBroadcastStream().getOutputType();

        final StreamGraph streamGraph = context.getStreamGraph();
        final String slotSharingGroup = context.getSlotSharingGroup();
        final int transformationId = transformation.getId();
        final ExecutionConfig executionConfig = streamGraph.getExecutionConfig();

        streamGraph.addCoOperator(
                transformationId,
                slotSharingGroup,
                transformation.getCoLocationGroupKey(),
                transformation.getOperatorFactory(),
                nonBroadcastTypeInfo,
                broadcastTypeInfo,
                transformation.getOutputType(),
                transformation.getName());

        // Only the non-broadcasted input can be keyed, so the second key selector is null.
        if (transformation.getKeySelector() != null) {
            final TypeSerializer<?> keySerializer =
                    transformation.getStateKeyType().createSerializer(executionConfig);

            streamGraph.setTwoInputStateKey(
                    transformationId,
                    transformation.getKeySelector(),
                    null,
                    keySerializer);
        }

        // Fall back to the execution config's parallelism when none was set explicitly.
        final int parallelism = transformation.getParallelism() != ExecutionConfig.PARALLELISM_DEFAULT
                ? transformation.getParallelism()
                : executionConfig.getParallelism();
        streamGraph.setParallelism(transformationId, parallelism);
        streamGraph.setMaxParallelism(transformationId, transformation.getMaxParallelism());

        // Edges from the non-broadcasted input are added with 1 as the third argument ...
        for (Integer inputId : context.getStreamNodeIds(transformation.getNonBroadcastStream())) {
            streamGraph.addEdge(inputId, transformationId, 1);
        }

        // ... and edges from the broadcasted input with 2.
        for (Integer inputId : context.getStreamNodeIds(transformation.getBroadcastStream())) {
            streamGraph.addEdge(inputId, transformationId, 2);
        }

        return Collections.singleton(transformationId);
    }
}
Loading

0 comments on commit ed82aab

Please sign in to comment.