diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java index c733af51334b1..cfa36bad1df3d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java @@ -1342,11 +1342,12 @@ private void applyParallelismOverrides(JobGraph jobGraph) { for (JobVertex vertex : jobGraph.getVertices()) { String override = overrides.get(vertex.getID().toHexString()); if (override != null) { + int currentParallelism = vertex.getParallelism(); int overrideParallelism = Integer.parseInt(override); log.info( "Changing job vertex {} parallelism from {} to {}", vertex.getID(), - vertex.getParallelism(), + currentParallelism, overrideParallelism); vertex.setParallelism(overrideParallelism); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java index 07181bd01f1c1..5b756b693b367 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.api.writer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.io.IOReadableWritable; import java.io.IOException; @@ -71,4 +72,9 @@ public void broadcastEmit(T record) throws IOException { flushAll(); } } + + @VisibleForTesting + public ChannelSelector getChannelSelector() { + return channelSelector; + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java index 20d357c14e4b4..1042b3962769d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java @@ -59,7 +59,7 @@ public class NonChainedOutput implements Serializable { private final OutputTag outputTag; /** The corresponding data partitioner. */ - private final StreamPartitioner partitioner; + private StreamPartitioner partitioner; /** Target {@link ResultPartitionType}. */ private final ResultPartitionType partitionType; @@ -119,6 +119,10 @@ public OutputTag getOutputTag() { return outputTag; } + public void setPartitioner(StreamPartitioner partitioner) { + this.partitioner = partitioner; + } + public StreamPartitioner getPartitioner() { return partitioner; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index dc6a9dff6e4b8..16666b71804cd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -91,6 +91,8 @@ import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil; import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler; import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer; @@ -1603,6 +1605,7 @@ List>>> createRecordWriters int index = 0; for (NonChainedOutput streamOutput : outputsInOrder) { + replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(environment, streamOutput); recordWriters.add( createRecordWriter( streamOutput, @@ -1614,6 +1617,15 @@ List>>> createRecordWriters return recordWriters; } + private static void replaceForwardPartitionerIfConsumerParallelismDoesNotMatch( + Environment environment, NonChainedOutput streamOutput) { + if (streamOutput.getPartitioner() instanceof ForwardPartitioner + && streamOutput.getConsumerParallelism() + != environment.getTaskInfo().getNumberOfParallelSubtasks()) { + streamOutput.setPartitioner(new RescalePartitioner<>()); + } + } + @SuppressWarnings("unchecked") private static RecordWriter>> createRecordWriter( NonChainedOutput streamOutput, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java index 37520e0515c29..100340a6ac53c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java @@ -35,6 +35,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import java.util.Collections; import java.util.HashMap; @@ -203,6 +204,11 @@ public OWNER finish() { } public OWNER finishForSingletonOperatorChain(TypeSerializer outputSerializer) { + return finishForSingletonOperatorChain(outputSerializer, new BroadcastPartitioner<>()); + } + + public OWNER finishForSingletonOperatorChain( + TypeSerializer outputSerializer, StreamPartitioner partitioner) { checkState(chainIndex == 0, "Use finishForSingletonOperatorChain"); checkState(headConfig == tailConfig); @@ -231,7 +237,7 @@ public OWNER finishForSingletonOperatorChain(TypeSerializer outputSer false, new IntermediateDataSetID(), null, - new BroadcastPartitioner<>(), + partitioner, ResultPartitionType.PIPELINED_BOUNDED)); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java index 08ab5d99e6025..637784d132fe1 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java @@ -53,6 +53,7 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.CompletingCheckpointResponder; import org.apache.flink.util.FlinkRuntimeException; @@ -125,7 +126,8 @@ public void testNotWaitingForAllRecordsProcessedIfCheckpointNotEnabled() throws .addInput(STRING_TYPE_INFO) .addAdditionalOutput(partitionWriters) .setupOperatorChain(new EmptyOperator()) - .finishForSingletonOperatorChain(StringSerializer.INSTANCE) + .finishForSingletonOperatorChain( + StringSerializer.INSTANCE, new BroadcastPartitioner<>()) .build()) { testHarness.endInput(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java index 7fb875ada3566..dcf9ffb3b8c97 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java @@ -58,6 +58,7 @@ import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; import org.apache.flink.util.function.FunctionWithException; @@ -108,6 +109,8 @@ public class StreamTaskMailboxTestHarnessBuilder { private Function modifyGateBuilder = Function.identity(); + private StreamPartitioner partitioner = new BroadcastPartitioner<>(); + public StreamTaskMailboxTestHarnessBuilder( FunctionWithException, Exception> taskFactory, TypeInformation outputType) { @@ -324,11 +327,7 @@ private void initializeNetworkInput( 0, null, null, (StreamOperator) null, null, SourceStreamTask.class); StreamEdge streamEdge = new StreamEdge( - sourceVertexDummy, - targetVertexDummy, - gateIndex + 1, - new BroadcastPartitioner<>(), - null); + sourceVertexDummy, targetVertexDummy, gateIndex + 1, partitioner, null); inPhysicalEdges.add(streamEdge); streamMockEnvironment.addInputGate(inputGates[gateIndex].getInputGate()); @@ -415,7 +414,7 @@ public StreamTaskMailboxTestHarnessBuilder setupOutputForSingletonOperatorC StreamOperatorFactory factory, OperatorID operatorID) { checkState(!setupCalled, "This harness was already setup."); return setupOperatorChain(operatorID, factory) - .finishForSingletonOperatorChain(outputSerializer); + .finishForSingletonOperatorChain(outputSerializer, partitioner); } public StreamConfigChainer> setupOperatorChain( @@ -462,6 +461,12 @@ public StreamTaskMailboxTestHarnessBuilder setTaskStateSnapshot( return this; } + public StreamTaskMailboxTestHarnessBuilder setOutputPartitioner( + StreamPartitioner partitioner) { + this.partitioner = partitioner; + return this; + } + /** * A place holder representation of a {@link SourceInputConfig}. When building the test harness * it is replaced with {@link SourceInputConfig}. diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index 9f5bbf227e90f..30063b0a09e2c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; import org.apache.flink.api.common.state.CheckpointListener; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.IntegerTypeInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ReadableConfig; @@ -53,10 +54,15 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelectorRecordWriter; +import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.api.writer.SingleRecordWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable; import org.apache.flink.runtime.metrics.TimerGauge; @@ -65,6 +71,7 @@ import org.apache.flink.runtime.operators.testutils.ExpectedTestException; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.runtime.shuffle.ShuffleEnvironment; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; @@ -109,6 +116,7 @@ import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.graph.NonChainedOutput; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; @@ -122,6 +130,8 @@ import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; import org.apache.flink.streaming.runtime.io.DataInputStatus; import org.apache.flink.streaming.runtime.io.StreamInputProcessor; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction; import org.apache.flink.streaming.util.MockStreamConfig; @@ -132,6 +142,7 @@ import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FatalExitExceptionHandler; import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.OutputTag; import org.apache.flink.util.TestLogger; import org.apache.flink.util.clock.SystemClock; import org.apache.flink.util.concurrent.FutureUtils; @@ -1799,6 +1810,64 @@ public void testMailboxMetricsMeasurement() throws Exception { } } + @Test + public void testForwardPartitionerIsConvertedToRebalanceOnParallelismChanges() + throws Exception { + StreamTaskMailboxTestHarnessBuilder builder = + new StreamTaskMailboxTestHarnessBuilder<>( + OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO) + .addInput(BasicTypeInfo.INT_TYPE_INFO) + .setOutputPartitioner(new ForwardPartitioner<>()) + .setupOutputForSingletonOperatorChain( + new TestBoundedOneInputStreamOperator()); + + try (StreamTaskMailboxTestHarness harness = builder.build()) { + + RecordWriterDelegate>> recordWriterDelegate = + harness.streamTask.createRecordWriterDelegate( + harness.streamTask.configuration, harness.streamMockEnvironment); + // Prerequisite: We are using the ForwardPartitioner + assertTrue( + ((ChannelSelectorRecordWriter) + ((SingleRecordWriter) recordWriterDelegate) + .getRecordWriter(0)) + .getChannelSelector() + instanceof ForwardPartitioner); + + // Change consumer parallelism + harness.streamTask.configuration.setVertexNonChainedOutputs( + List.of( + new NonChainedOutput( + false, + 0, + // Set a different consumer parallelism to force trigger + // replacing the ForwardPartitioner + 42, + 100, + 1000, + false, + new IntermediateDataSetID(), + new OutputTag<>("output", IntegerTypeInfo.INT_TYPE_INFO), + // Use forward partitioner + new ForwardPartitioner<>(), + ResultPartitionType.PIPELINED))); + harness.streamTask.configuration.serializeAllConfigs(); + + // Re-create outputs + recordWriterDelegate = + harness.streamTask.createRecordWriterDelegate( + harness.streamTask.configuration, harness.streamMockEnvironment); + // We should now have a RescalePartitioner to distribute the load + // for the non-matching downstream parallelism + assertTrue( + ((ChannelSelectorRecordWriter) + ((SingleRecordWriter) recordWriterDelegate) + .getRecordWriter(0)) + .getChannelSelector() + instanceof RescalePartitioner); + } + } + private int getCurrentBufferSize(InputGate inputGate) { return getTestChannel(inputGate, 0).getCurrentBufferSize(); }