Skip to content

Commit

Permalink
[FLINK-30213] Change ForwardPartitioner to RebalancePartitioner on pa…
Browse files Browse the repository at this point in the history
…rallelism changes

In case of parallelism changes to the JobGraph, as done via the AdaptiveScheduler
or through providing JobVertexId overrides in PipelineOptions#PARALLELISM_OVERRIDES, the inner
serialized PartitionStrategy of a StreamTask may not be suitable anymore.

This is the case for the ForwardPartitioner strategy, which uses a fixed local channel for
transmitting data. Whenever the consumer parallelism doesn't match the local parallelism, the
ForwardPartitioner should be replaced with the RebalancePartitioner.
  • Loading branch information
mxm authored and gyfora committed Jan 2, 2023
1 parent ded2df5 commit 61d6e78
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1342,11 +1342,12 @@ private void applyParallelismOverrides(JobGraph jobGraph) {
for (JobVertex vertex : jobGraph.getVertices()) {
String override = overrides.get(vertex.getID().toHexString());
if (override != null) {
int currentParallelism = vertex.getParallelism();
int overrideParallelism = Integer.parseInt(override);
log.info(
"Changing job vertex {} parallelism from {} to {}",
vertex.getID(),
vertex.getParallelism(),
currentParallelism,
overrideParallelism);
vertex.setParallelism(overrideParallelism);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.runtime.io.network.api.writer;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.io.IOReadableWritable;

import java.io.IOException;
Expand Down Expand Up @@ -71,4 +72,9 @@ public void broadcastEmit(T record) throws IOException {
flushAll();
}
}

/**
 * Returns the {@link ChannelSelector} held by this record writer.
 *
 * <p>Marked {@link VisibleForTesting}: the accessor exists so tests can inspect which selector
 * a writer ended up with.
 *
 * @return the value of the {@code channelSelector} field
 */
@VisibleForTesting
public ChannelSelector<T> getChannelSelector() {
return channelSelector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class NonChainedOutput implements Serializable {
private final OutputTag<?> outputTag;

/** The corresponding data partitioner. */
private final StreamPartitioner<?> partitioner;
private StreamPartitioner<?> partitioner;

/** Target {@link ResultPartitionType}. */
private final ResultPartitionType partitionType;
Expand Down Expand Up @@ -119,6 +119,10 @@ public OutputTag<?> getOutputTag() {
return outputTag;
}

/**
 * Replaces the data partitioner of this output.
 *
 * <p>The value is stored as-is; no validation is performed here.
 *
 * @param partitioner the partitioner to store; it is what {@link #getPartitioner()} returns
 *     afterwards
 */
public void setPartitioner(StreamPartitioner<?> partitioner) {
this.partitioner = partitioner;
}

/**
 * Returns the data partitioner currently assigned to this output.
 *
 * <p>Because {@link #setPartitioner(StreamPartitioner)} can overwrite the field, this is the
 * most recently assigned partitioner, not necessarily the one given at construction.
 *
 * @return the current partitioner
 */
public StreamPartitioner<?> getPartitioner() {
return partitioner;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@
import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil;
import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler;
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
Expand Down Expand Up @@ -1603,6 +1605,7 @@ List<RecordWriter<SerializationDelegate<StreamRecord<OUT>>>> createRecordWriters

int index = 0;
for (NonChainedOutput streamOutput : outputsInOrder) {
replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(environment, streamOutput);
recordWriters.add(
createRecordWriter(
streamOutput,
Expand All @@ -1614,6 +1617,15 @@ List<RecordWriter<SerializationDelegate<StreamRecord<OUT>>>> createRecordWriters
return recordWriters;
}

/**
 * Swaps the partitioner of {@code streamOutput} for a {@link RescalePartitioner} when the
 * output currently uses a {@link ForwardPartitioner} and its consumer parallelism differs from
 * this task's number of parallel subtasks.
 *
 * <p>Outputs with any other partitioner, or whose consumer parallelism equals the task's
 * parallelism, are left untouched. The replacement is written back into {@code streamOutput}
 * via {@code setPartitioner}.
 *
 * @param environment source of the task's number of parallel subtasks
 * @param streamOutput the non-chained output whose partitioner is checked and possibly replaced
 */
private static void replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(
        Environment environment, NonChainedOutput streamOutput) {
    // Only forward partitioning is affected; every other strategy is kept as configured.
    if (!(streamOutput.getPartitioner() instanceof ForwardPartitioner)) {
        return;
    }

    final int consumerParallelism = streamOutput.getConsumerParallelism();
    final int taskParallelism = environment.getTaskInfo().getNumberOfParallelSubtasks();
    if (consumerParallelism == taskParallelism) {
        return;
    }

    // NOTE(review): the change description and the test name speak of a "rebalance"
    // partitioner, while this installs a RescalePartitioner - confirm which is intended.
    streamOutput.setPartitioner(new RescalePartitioner<>());
}

@SuppressWarnings("unchecked")
private static <OUT> RecordWriter<SerializationDelegate<StreamRecord<OUT>>> createRecordWriter(
NonChainedOutput streamOutput,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;

import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -203,6 +204,11 @@ public OWNER finish() {
}

/**
 * Finishes a singleton operator chain whose output uses a {@link BroadcastPartitioner}.
 *
 * <p>Convenience overload: it delegates to the two-argument variant that accepts an explicit
 * {@link StreamPartitioner}, passing a freshly created {@link BroadcastPartitioner}.
 *
 * @param <OUT> element type handled by {@code outputSerializer}
 * @param outputSerializer serializer forwarded unchanged to the two-argument overload
 * @return the result of the two-argument overload
 */
public <OUT> OWNER finishForSingletonOperatorChain(TypeSerializer<OUT> outputSerializer) {
    final StreamPartitioner<?> defaultPartitioner = new BroadcastPartitioner<>();
    return finishForSingletonOperatorChain(outputSerializer, defaultPartitioner);
}

public <OUT> OWNER finishForSingletonOperatorChain(
TypeSerializer<OUT> outputSerializer, StreamPartitioner<?> partitioner) {

checkState(chainIndex == 0, "Use finishForSingletonOperatorChain");
checkState(headConfig == tailConfig);
Expand Down Expand Up @@ -231,7 +237,7 @@ public <OUT> OWNER finishForSingletonOperatorChain(TypeSerializer<OUT> outputSer
false,
new IntermediateDataSetID(),
null,
new BroadcastPartitioner<>(),
partitioner,
ResultPartitionType.PIPELINED_BOUNDED));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.CompletingCheckpointResponder;
import org.apache.flink.util.FlinkRuntimeException;
Expand Down Expand Up @@ -125,7 +126,8 @@ public void testNotWaitingForAllRecordsProcessedIfCheckpointNotEnabled() throws
.addInput(STRING_TYPE_INFO)
.addAdditionalOutput(partitionWriters)
.setupOperatorChain(new EmptyOperator())
.finishForSingletonOperatorChain(StringSerializer.INSTANCE)
.finishForSingletonOperatorChain(
StringSerializer.INSTANCE, new BroadcastPartitioner<>())
.build()) {
testHarness.endInput();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.util.function.FunctionWithException;

Expand Down Expand Up @@ -108,6 +109,8 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> {
private Function<SingleInputGateBuilder, SingleInputGateBuilder> modifyGateBuilder =
Function.identity();

private StreamPartitioner<?> partitioner = new BroadcastPartitioner<>();

public StreamTaskMailboxTestHarnessBuilder(
FunctionWithException<Environment, ? extends StreamTask<OUT, ?>, Exception> taskFactory,
TypeInformation<OUT> outputType) {
Expand Down Expand Up @@ -324,11 +327,7 @@ private void initializeNetworkInput(
0, null, null, (StreamOperator<?>) null, null, SourceStreamTask.class);
StreamEdge streamEdge =
new StreamEdge(
sourceVertexDummy,
targetVertexDummy,
gateIndex + 1,
new BroadcastPartitioner<>(),
null);
sourceVertexDummy, targetVertexDummy, gateIndex + 1, partitioner, null);

inPhysicalEdges.add(streamEdge);
streamMockEnvironment.addInputGate(inputGates[gateIndex].getInputGate());
Expand Down Expand Up @@ -415,7 +414,7 @@ public StreamTaskMailboxTestHarnessBuilder<OUT> setupOutputForSingletonOperatorC
StreamOperatorFactory<?> factory, OperatorID operatorID) {
checkState(!setupCalled, "This harness was already setup.");
return setupOperatorChain(operatorID, factory)
.finishForSingletonOperatorChain(outputSerializer);
.finishForSingletonOperatorChain(outputSerializer, partitioner);
}

public StreamConfigChainer<StreamTaskMailboxTestHarnessBuilder<OUT>> setupOperatorChain(
Expand Down Expand Up @@ -462,6 +461,12 @@ public StreamTaskMailboxTestHarnessBuilder<OUT> setTaskStateSnapshot(
return this;
}

/**
 * Sets the partitioner this builder uses for the stream edges and the singleton-chain output
 * it creates.
 *
 * <p>The builder starts out with a {@code BroadcastPartitioner}. The stored value is read when
 * the output is set up, so call this before {@code setupOutputForSingletonOperatorChain(...)}
 * for it to take effect on that output.
 *
 * <p>The parameter is declared as {@code StreamPartitioner<?>} rather than the raw type so it
 * matches the field and the {@code finishForSingletonOperatorChain} overload it is passed to,
 * and does not trigger raw-type warnings. The erasure is unchanged, so existing callers keep
 * compiling and linking.
 *
 * @param partitioner the partitioner to use for outputs built by this builder
 * @return this builder, for call chaining
 */
public StreamTaskMailboxTestHarnessBuilder<OUT> setOutputPartitioner(
        StreamPartitioner<?> partitioner) {
    this.partitioner = partitioner;
    return this;
}

/**
* A place holder representation of a {@link SourceInputConfig}. When building the test harness
* it is replaced with {@link SourceInputConfig}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.IntegerTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
Expand Down Expand Up @@ -53,10 +54,15 @@
import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
import org.apache.flink.runtime.io.network.api.StopMode;
import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelectorRecordWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.SingleRecordWriter;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
import org.apache.flink.runtime.metrics.TimerGauge;
Expand All @@ -65,6 +71,7 @@
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
Expand Down Expand Up @@ -109,6 +116,7 @@
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.graph.NonChainedOutput;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
Expand All @@ -122,6 +130,8 @@
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.runtime.io.DataInputStatus;
import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
import org.apache.flink.streaming.util.MockStreamConfig;
Expand All @@ -132,6 +142,7 @@
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FatalExitExceptionHandler;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.clock.SystemClock;
import org.apache.flink.util.concurrent.FutureUtils;
Expand Down Expand Up @@ -1799,6 +1810,64 @@ public void testMailboxMetricsMeasurement() throws Exception {
}
}

/**
 * Verifies that a {@link ForwardPartitioner} configured on a task output is replaced when the
 * record writers are created for an output whose consumer parallelism differs from the task's
 * parallelism.
 *
 * <p>The test runs in two phases: it first builds the record writer delegate from the harness'
 * initial configuration and checks that the forward partitioner is still in place, then it
 * overwrites the non-chained outputs with one that has a different consumer parallelism,
 * re-creates the delegate and checks the resulting channel selector.
 *
 * <p>NOTE(review): the method name says "Rebalance" but the final assertion expects a
 * {@link RescalePartitioner} - confirm which one is intended.
 */
@Test
public void testForwardPartitionerIsConvertedToRebalanceOnParallelismChanges()
throws Exception {
// Harness for a one-input task whose output is set up with a ForwardPartitioner.
StreamTaskMailboxTestHarnessBuilder<Integer> builder =
new StreamTaskMailboxTestHarnessBuilder<>(
OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
.addInput(BasicTypeInfo.INT_TYPE_INFO)
.setOutputPartitioner(new ForwardPartitioner<>())
.setupOutputForSingletonOperatorChain(
new TestBoundedOneInputStreamOperator());

try (StreamTaskMailboxTestHarness<Integer> harness = builder.build()) {

// Phase 1: build the writers from the unmodified configuration.
RecordWriterDelegate<SerializationDelegate<StreamRecord<Object>>> recordWriterDelegate =
harness.streamTask.createRecordWriterDelegate(
harness.streamTask.configuration, harness.streamMockEnvironment);
// Prerequisite: We are using the ForwardPartitioner
// (the casts assume a single writer that is a ChannelSelectorRecordWriter).
assertTrue(
((ChannelSelectorRecordWriter)
((SingleRecordWriter) recordWriterDelegate)
.getRecordWriter(0))
.getChannelSelector()
instanceof ForwardPartitioner);

// Phase 2: change the consumer parallelism by replacing the configured outputs.
harness.streamTask.configuration.setVertexNonChainedOutputs(
List.of(
new NonChainedOutput(
false,
0,
// Set a different consumer parallelism to force trigger
// replacing the ForwardPartitioner
42,
100,
1000,
false,
new IntermediateDataSetID(),
new OutputTag<>("output", IntegerTypeInfo.INT_TYPE_INFO),
// Use forward partitioner
new ForwardPartitioner<>(),
ResultPartitionType.PIPELINED)));
// Presumably required so the replaced outputs are visible when the writers are
// re-created below - TODO confirm.
harness.streamTask.configuration.serializeAllConfigs();

// Re-create outputs
recordWriterDelegate =
harness.streamTask.createRecordWriterDelegate(
harness.streamTask.configuration, harness.streamMockEnvironment);
// We should now have a RescalePartitioner to distribute the load
// for the non-matching downstream parallelism
assertTrue(
((ChannelSelectorRecordWriter)
((SingleRecordWriter) recordWriterDelegate)
.getRecordWriter(0))
.getChannelSelector()
instanceof RescalePartitioner);
}
}

private int getCurrentBufferSize(InputGate inputGate) {
return getTestChannel(inputGate, 0).getCurrentBufferSize();
}
Expand Down

0 comments on commit 61d6e78

Please sign in to comment.