diff --git a/flink-core/src/test/java/org/apache/flink/api/connector/source/mocks/MockSplitEnumerator.java b/flink-core/src/test/java/org/apache/flink/api/connector/source/mocks/MockSplitEnumerator.java
index 2b9dbd35edf10..7a3763bc5f978 100644
--- a/flink-core/src/test/java/org/apache/flink/api/connector/source/mocks/MockSplitEnumerator.java
+++ b/flink-core/src/test/java/org/apache/flink/api/connector/source/mocks/MockSplitEnumerator.java
@@ -22,6 +22,7 @@
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
 import org.apache.flink.api.connector.source.SplitsAssignment;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 
 import javax.annotation.Nullable;
 
@@ -38,7 +39,8 @@ import java.util.TreeSet;
 
 /** A mock {@link SplitEnumerator} for unit tests. */
-public class MockSplitEnumerator implements SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> {
+public class MockSplitEnumerator
+        implements SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>>, SupportsBatchSnapshot {
     private final SortedSet<MockSourceSplit> unassignedSplits;
     private final SplitEnumeratorContext<MockSourceSplit> enumContext;
     private final List<SourceEvent> handledSourceEvent;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 0fb840a25a785..c53e227ec1ee0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -1480,7 +1480,8 @@ private void releasePartitionGroups(
         }
     }
 
-    ResultPartitionID createResultPartitionId(
+    @VisibleForTesting
+    public ResultPartitionID createResultPartitionId(
             final IntermediateResultPartitionID resultPartitionId) {
         final SchedulingResultPartition schedulingResultPartition =
                 getSchedulingTopology().getResultPartition(resultPartitionId);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 25c6187bfbfca..46769a45f8906 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -127,7 +127,7 @@ public class Execution
     private final ExecutionVertex vertex;
 
     /** The unique ID marking the specific execution instant of the task. */
-    private final ExecutionAttemptID attemptId;
+    private ExecutionAttemptID attemptId;
 
     /**
      * The timestamps when state transitions occurred, indexed by {@link ExecutionState#ordinal()}.
@@ -453,6 +453,40 @@ public CompletableFuture<Void> registerProducedPartitions(TaskManagerLocation lo
                 });
     }
 
+    private void recoverAttempt(ExecutionAttemptID newId) {
+        if (!this.attemptId.equals(newId)) {
+            getVertex().getExecutionGraphAccessor().deregisterExecution(this);
+            this.attemptId = newId;
+            getVertex().getExecutionGraphAccessor().registerExecution(this);
+        }
+    }
+
+    /**
+     * Recover the execution attempt status after JM failover.
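+     *
+     * <p>Called while replaying an {@code ExecutionVertexFinishedEvent}: it re-registers the
+     * attempt under its previous ID, marks it FINISHED, finishes its partitions, and restores
+     * accumulators and IO metrics, mirroring a task that finished under the current JobMaster.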
+     */
+    public void recoverExecution(
+            ExecutionAttemptID attemptId,
+            TaskManagerLocation location,
+            Map<String, Accumulator<?, ?>> userAccumulators,
+            IOMetrics metrics) {
+        recoverAttempt(attemptId);
+        taskManagerLocationFuture.complete(location);
+
+        try {
+            transitionState(this.state, FINISHED);
+            finishPartitionsAndUpdateConsumers();
+            updateAccumulatorsAndMetrics(userAccumulators, metrics);
+            releaseAssignedResource(null);
+            vertex.getExecutionGraphAccessor().deregisterExecution(this);
+        } finally {
+            vertex.executionFinished(this);
+        }
+    }
+
+    public void recoverProducedPartitions(
+            Map<IntermediateResultPartitionID, ResultPartitionDeploymentDescriptor>
+                    producedPartitions) {
+        this.producedPartitions = checkNotNull(producedPartitions);
+    }
+
     private static CompletableFuture<
                     Map<IntermediateResultPartitionID, ResultPartitionDeploymentDescriptor>>
             registerProducedPartitions(
@@ -469,7 +503,6 @@ public CompletableFuture<Void> registerProducedPartitions(TaskManagerLocation lo
 
         for (IntermediateResultPartition partition : partitions) {
             PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition);
-            int maxParallelism = getPartitionMaxParallelism(partition);
             CompletableFuture<ShuffleDescriptor> shuffleDescriptorFuture =
                     vertex.getExecutionGraphAccessor()
                             .getShuffleMaster()
@@ -479,10 +512,8 @@ public CompletableFuture<Void> registerProducedPartitions(TaskManagerLocation lo
 
             CompletableFuture<ResultPartitionDeploymentDescriptor> partitionRegistration =
                     shuffleDescriptorFuture.thenApply(
                             shuffleDescriptor ->
-                                    new ResultPartitionDeploymentDescriptor(
-                                            partitionDescriptor,
-                                            shuffleDescriptor,
-                                            maxParallelism));
+                                    createResultPartitionDeploymentDescriptor(
+                                            partitionDescriptor, partition, shuffleDescriptor));
             partitionRegistrations.add(partitionRegistration);
         }
@@ -503,6 +534,21 @@ private static int getPartitionMaxParallelism(IntermediateResultPartition partit
         return partition.getIntermediateResult().getConsumersMaxParallelism();
     }
 
+    public static ResultPartitionDeploymentDescriptor createResultPartitionDeploymentDescriptor(
+            IntermediateResultPartition partition, ShuffleDescriptor shuffleDescriptor) {
+        PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition);
+        return createResultPartitionDeploymentDescriptor(
+                partitionDescriptor, partition, shuffleDescriptor);
+    }
+
+    private static ResultPartitionDeploymentDescriptor createResultPartitionDeploymentDescriptor(
+            PartitionDescriptor partitionDescriptor,
+            IntermediateResultPartition partition,
+            ShuffleDescriptor shuffleDescriptor) {
+        return new ResultPartitionDeploymentDescriptor(
+                partitionDescriptor, shuffleDescriptor, getPartitionMaxParallelism(partition));
+    }
+
     /**
      * Deploys the execution to the previously assigned resource.
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java index 2fa43818ce2e6..440493c45ddd2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java @@ -92,13 +92,13 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio private final ScheduledExecutor delayExecutor; - private final SchedulingStrategy schedulingStrategy; + protected final SchedulingStrategy schedulingStrategy; private final ExecutionOperations executionOperations; private final Set verticesWaitingForRestart; - private final ShuffleMaster shuffleMaster; + protected final ShuffleMaster shuffleMaster; private final Map reservedAllocationRefCounters; @@ -109,6 +109,8 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio protected final ExecutionDeployer executionDeployer; + protected final FailoverStrategy failoverStrategy; + protected DefaultScheduler( final Logger log, final JobGraph jobGraph, @@ -162,7 +164,7 @@ protected DefaultScheduler( this.reservedAllocationRefCounters = new HashMap<>(); this.reservedAllocationByExecutionVertex = new HashMap<>(); - final FailoverStrategy failoverStrategy = + this.failoverStrategy = failoverStrategyFactory.create( getSchedulingTopology(), getResultPartitionAvailabilityChecker()); log.info( @@ -301,7 +303,7 @@ private Throwable maybeTranslateToClusterDatasetException( cause, Collections.singletonList(failedPartitionId.getIntermediateDataSetID())); } - private void notifyCoordinatorsAboutTaskFailure( + protected void notifyCoordinatorsAboutTaskFailure( final Execution execution, @Nullable final Throwable error) { final ExecutionJobVertex jobVertex = execution.getVertex().getJobVertex(); final int subtaskIndex = execution.getParallelSubtaskIndex(); @@ -323,7 +325,7 @@ public void handleGlobalFailure(final Throwable error) { maybeRestartTasks(failureHandlingResult); } - private void maybeRestartTasks(final FailureHandlingResult failureHandlingResult) { + protected void maybeRestartTasks(final FailureHandlingResult failureHandlingResult) { if (failureHandlingResult.canRestart()) { restartTasksWithDelay(failureHandlingResult); } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java index 6c71c7c06fcae..e9629716412ab 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SchedulerBase.java @@ -756,8 +756,7 @@ protected final void archiveFromFailureHandlingResult( } @Override - public final boolean updateTaskExecutionState( - final TaskExecutionStateTransition taskExecutionState) { + public boolean updateTaskExecutionState(final TaskExecutionStateTransition taskExecutionState) { final ExecutionAttemptID attemptId = taskExecutionState.getID(); final Execution execution = executionGraph.getRegisteredExecutions().get(attemptId); @@ -1142,7 +1141,7 @@ public void notifyEndOfData(ExecutionAttemptID executionAttemptID) { // ------------------------------------------------------------------------ @VisibleForTesting - JobID getJobId() { + protected JobID getJobId() { return jobGraph.getJobID(); } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index 7154a1b5ff7e0..d9578e7be1093 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -35,6 +35,7 @@
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.execution.SuppressRestartsException;
 import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
@@ -42,12 +43,15 @@
 import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy;
 import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils;
 import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
 import org.apache.flink.runtime.executiongraph.failover.FailureHandlingResult;
 import org.apache.flink.runtime.executiongraph.failover.RestartBackoffTimeStrategy;
+import org.apache.flink.runtime.executiongraph.failover.RestartPipelinedRegionFailoverStrategy;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
@@ -73,6 +77,7 @@
 import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory;
 import org.apache.flink.runtime.shuffle.ShuffleMaster;
 import org.apache.flink.runtime.source.coordinator.SourceCoordinator;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.concurrent.FutureUtils;
 import org.apache.flink.util.concurrent.ScheduledExecutor;
 
@@ -81,11 +86,14 @@
 import javax.annotation.Nullable;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 import java.util.function.BiConsumer;
@@ -120,6 +128,16 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
 
     private final SpeculativeExecutionHandler speculativeExecutionHandler;
 
+    /**
+     * A set of JobVertex IDs associated with JobVertices whose operator coordinators did not
+     * successfully recover. If any execution within one of these job vertices needs to be
+     * restarted in the future, all other executions within the same job vertex must also be
+     * restarted to ensure the consistency and correctness of the state.
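+     *
+     * <p>For example, if a source vertex's coordinator state could not be restored and one of its
+     * subtasks later fails, restarting only that subtask would leave the other subtasks holding
+     * splits assigned by coordinator state that no longer exists; restarting the whole vertex
+     * avoids this inconsistency.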
+ */ + private Set jobVerticesWithUnRecoveredCoordinators = new HashSet<>(); + + private final BatchJobRecoveryHandler jobRecoveryHandler; + public AdaptiveBatchScheduler( final Logger log, final JobGraph jobGraph, @@ -148,7 +166,8 @@ public AdaptiveBatchScheduler( final int defaultMaxParallelism, final BlocklistOperations blocklistOperations, final HybridPartitionDataConsumeConstraint hybridPartitionDataConsumeConstraint, - final Map forwardGroupsByJobVertexId) + final Map forwardGroupsByJobVertexId, + final BatchJobRecoveryHandler jobRecoveryHandler) throws Exception { super( @@ -195,6 +214,8 @@ public AdaptiveBatchScheduler( speculativeExecutionHandler = createSpeculativeExecutionHandler( log, jobMasterConfiguration, executionVertexVersioner, blocklistOperations); + + this.jobRecoveryHandler = jobRecoveryHandler; } private SpeculativeExecutionHandler createSpeculativeExecutionHandler( @@ -224,18 +245,104 @@ private SpeculativeExecutionHandler createSpeculativeExecutionHandler( protected void startSchedulingInternal() { speculativeExecutionHandler.init( getExecutionGraph(), getMainThreadExecutor(), jobManagerJobMetricGroup); + jobRecoveryHandler.initialize(new DefaultBatchJobRecoveryContext()); - tryComputeSourceParallelismThenRunAsync( - (Void value, Throwable throwable) -> { - if (getExecutionGraph().getState() == JobStatus.CREATED) { - initializeVerticesIfPossible(); - super.startSchedulingInternal(); - } - }); + if (jobRecoveryHandler.needRecover()) { + jobRecoveryHandler.startRecovering(); + } else { + tryComputeSourceParallelismThenRunAsync( + (Void value, Throwable throwable) -> { + if (getExecutionGraph().getState() == JobStatus.CREATED) { + initializeVerticesIfPossible(); + super.startSchedulingInternal(); + } + }); + } + } + + /** + * Modifies the vertices which need to be restarted. If any task needing restarting belongs to + * job vertices with unrecovered operator coordinators, all tasks within those job vertices need + * to be restarted once. 
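+ * <p>The expanded job vertices are removed from {@code jobVerticesWithUnRecoveredCoordinators}
+ * afterwards, so this widening happens at most once per job vertex.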
+ */ + @Override + protected void maybeRestartTasks(final FailureHandlingResult failureHandlingResult) { + FailureHandlingResult wrappedResult = failureHandlingResult; + if (failureHandlingResult.canRestart()) { + Set originalNeedToRestartVertices = + failureHandlingResult.getVerticesToRestart(); + + Set extraNeedToRestartJobVertices = + originalNeedToRestartVertices.stream() + .map(ExecutionVertexID::getJobVertexId) + .filter(jobVerticesWithUnRecoveredCoordinators::contains) + .collect(Collectors.toSet()); + + jobVerticesWithUnRecoveredCoordinators.removeAll(extraNeedToRestartJobVertices); + + Set needToRestartVertices = + extraNeedToRestartJobVertices.stream() + .flatMap( + jobVertexId -> { + ExecutionJobVertex jobVertex = + getExecutionJobVertex(jobVertexId); + return Arrays.stream(jobVertex.getTaskVertices()) + .map(ExecutionVertex::getID); + }) + .collect(Collectors.toSet()); + needToRestartVertices.addAll(originalNeedToRestartVertices); + + wrappedResult = + FailureHandlingResult.restartable( + failureHandlingResult.getFailedExecution().orElse(null), + failureHandlingResult.getError(), + failureHandlingResult.getTimestamp(), + failureHandlingResult.getFailureLabels(), + needToRestartVertices, + failureHandlingResult.getRestartDelayMS(), + failureHandlingResult.isGlobalFailure(), + failureHandlingResult.isRootCause()); + } + + super.maybeRestartTasks(wrappedResult); + } + + @VisibleForTesting + boolean isRecovering() { + return jobRecoveryHandler.isRecovering(); + } + + @Override + protected void resetForNewExecutions(Collection vertices) { + super.resetForNewExecutions(vertices); + if (!isRecovering()) { + jobRecoveryHandler.onExecutionVertexReset(vertices); + } + } + + private void initializeJobVertex( + ExecutionJobVertex jobVertex, + int parallelism, + Map jobVertexInputInfos, + long createTimestamp) + throws JobException { + if (!jobVertex.isParallelismDecided()) { + changeJobVertexParallelism(jobVertex, parallelism); + } else { + checkState(parallelism == jobVertex.getParallelism()); + } + checkState(canInitialize(jobVertex)); + getExecutionGraph().initializeJobVertex(jobVertex, createTimestamp, jobVertexInputInfos); + if (!isRecovering()) { + jobRecoveryHandler.onExecutionJobVertexInitialization( + jobVertex.getJobVertex().getID(), parallelism, jobVertexInputInfos); + } } @Override public CompletableFuture closeAsync() { + // stop job event manager. + jobRecoveryHandler.stop(requestJobStatus().isGloballyTerminalState()); speculativeExecutionHandler.stopSlowTaskDetector(); return super.closeAsync(); } @@ -243,6 +350,9 @@ public CompletableFuture closeAsync() { @Override protected void onTaskFinished(final Execution execution, final IOMetrics ioMetrics) { speculativeExecutionHandler.notifyTaskFinished(execution, this::cancelPendingExecutions); + if (!isRecovering()) { + jobRecoveryHandler.onExecutionFinished(execution.getVertex().getID()); + } checkNotNull(ioMetrics); updateResultPartitionBytesMetrics(ioMetrics.getResultPartitionBytes()); @@ -475,11 +585,17 @@ public void initializeVerticesIfPossible() { // Note that in current implementation, the decider will not load balance // (evenly distribute data) for job vertices whose parallelism has already been - // decided, so we can call the - // ExecutionGraph#initializeJobVertex(ExecutionJobVertex, long) to initialize. + // decided, so we can call the initializeJobVertex method, specifying the + // user-defined parallelism as its argument. 
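+                    // Going through initializeJobVertex (rather than the ExecutionGraph directly)
+                    // also lets the recovery handler record an ExecutionJobVertexInitializedEvent
+                    // for a possible JM failover.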
// TODO: In the future, if we want to load balance for job vertices whose // parallelism has already been decided, we need to refactor the logic here. - getExecutionGraph().initializeJobVertex(jobVertex, createTimestamp); + initializeJobVertex( + jobVertex, + jobVertex.getParallelism(), + VertexInputInfoComputationUtils.computeVertexInputInfos( + jobVertex, + getExecutionGraph().getAllIntermediateResults()::get), + createTimestamp); newlyInitializedJobVertices.add(jobVertex); } else { Optional> consumedResultsInfo = @@ -488,14 +604,11 @@ public void initializeVerticesIfPossible() { ParallelismAndInputInfos parallelismAndInputInfos = tryDecideParallelismAndInputInfos( jobVertex, consumedResultsInfo.get()); - changeJobVertexParallelism( - jobVertex, parallelismAndInputInfos.getParallelism()); - checkState(canInitialize(jobVertex)); - getExecutionGraph() - .initializeJobVertex( - jobVertex, - createTimestamp, - parallelismAndInputInfos.getJobVertexInputInfos()); + initializeJobVertex( + jobVertex, + parallelismAndInputInfos.getParallelism(), + parallelismAndInputInfos.getJobVertexInputInfos(), + createTimestamp); newlyInitializedJobVertices.add(jobVertex); } } @@ -766,4 +879,98 @@ BlockingResultInfo getBlockingResultInfo(IntermediateDataSetID resultId) { SpeculativeExecutionHandler getSpeculativeExecutionHandler() { return speculativeExecutionHandler; } + + private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext { + + private final FailoverStrategy restartStrategyOnResultConsumable = + new RestartPipelinedRegionFailoverStrategy.Factory() + .create(getSchedulingTopology(), getResultPartitionAvailabilityChecker()); + + private final FailoverStrategy restartStrategyNotOnResultConsumable = + new RestartPipelinedRegionFailoverStrategy.Factory() + .create(getSchedulingTopology(), ignored -> true); + + @Override + public ExecutionGraph getExecutionGraph() { + return AdaptiveBatchScheduler.this.getExecutionGraph(); + } + + @Override + public ShuffleMaster getShuffleMaster() { + return shuffleMaster; + } + + @Override + public Set getTasksNeedingRestart( + ExecutionVertexID vertexId, boolean considerResultConsumable) { + if (considerResultConsumable) { + return restartStrategyOnResultConsumable.getTasksNeedingRestart(vertexId, null); + } else { + return restartStrategyNotOnResultConsumable.getTasksNeedingRestart(vertexId, null); + } + } + + @Override + public ComponentMainThreadExecutor getMainThreadExecutor() { + return AdaptiveBatchScheduler.this.getMainThreadExecutor(); + } + + @Override + public void resetVerticesInRecovering(Set verticesToReset) + throws Exception { + for (ExecutionVertexID executionVertexID : verticesToReset) { + notifyCoordinatorsAboutTaskFailure( + getExecutionVertex(executionVertexID).getCurrentExecutionAttempt(), null); + } + resetForNewExecutions(verticesToReset); + restoreState(verticesToReset, false); + } + + @Override + public void updateResultPartitionBytesMetrics( + Map resultPartitionBytes) { + AdaptiveBatchScheduler.this.updateResultPartitionBytesMetrics(resultPartitionBytes); + } + + @Override + public void initializeJobVertex( + ExecutionJobVertex jobVertex, + int parallelism, + Map jobVertexInputInfos, + long createTimestamp) + throws JobException { + AdaptiveBatchScheduler.this.initializeJobVertex( + jobVertex, parallelism, jobVertexInputInfos, createTimestamp); + } + + @Override + public void updateTopology(final List newlyInitializedJobVertices) { + AdaptiveBatchScheduler.this.updateTopology(newlyInitializedJobVertices); + } + + 
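+        // Invoked by the recovery handler once event replay and partition reconciliation have
+        // succeeded; scheduling resumes only after source parallelisms have been recomputed.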
@Override + public void onRecoveringFinished(Set jobVerticesWithUnRecoveredCoordinators) { + AdaptiveBatchScheduler.this.jobVerticesWithUnRecoveredCoordinators = + new HashSet<>(jobVerticesWithUnRecoveredCoordinators); + tryComputeSourceParallelismThenRunAsync( + (Void value, Throwable throwable) -> + schedulingStrategy.scheduleAllVerticesIfPossible()); + } + + @Override + public void onRecoveringFailed() { + // call #initializeVerticesIfPossible to avoid an empty execution graph + initializeVerticesIfPossible(); + handleGlobalFailure( + new FlinkRuntimeException("Recover failed from JM failover, fail global.")); + } + + @Override + public void failJob( + Throwable cause, + long timestamp, + CompletableFuture> failureLabels) { + AdaptiveBatchScheduler.this.failJob(cause, timestamp, failureLabels); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java index 947851daa415c..e87cd6b660b51 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java @@ -50,6 +50,8 @@ import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup; import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; import org.apache.flink.runtime.jobmaster.ExecutionDeploymentTracker; +import org.apache.flink.runtime.jobmaster.event.FileSystemJobEventStore; +import org.apache.flink.runtime.jobmaster.event.JobEventManager; import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProvider; import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProviderImpl; import org.apache.flink.runtime.jobmaster.slotpool.SlotPool; @@ -147,6 +149,21 @@ public SchedulerNG createInstance( jobGraph.getName(), jobGraph.getJobID()); + final boolean isJobRecoveryEnabled = + jobMasterConfiguration.getBoolean(BatchExecutionOptions.JOB_RECOVERY_ENABLED) + && shuffleMaster.supportsBatchSnapshot(); + + BatchJobRecoveryHandler jobRecoveryHandler; + if (isJobRecoveryEnabled) { + FileSystemJobEventStore jobEventStore = + new FileSystemJobEventStore(jobGraph.getJobID(), jobMasterConfiguration); + JobEventManager jobEventManager = new JobEventManager(jobEventStore); + jobRecoveryHandler = + new DefaultBatchJobRecoveryHandler(jobEventManager, jobMasterConfiguration); + } else { + jobRecoveryHandler = new DummyBatchJobRecoveryHandler(); + } + return createScheduler( log, jobGraph, @@ -173,7 +190,8 @@ public SchedulerNG createInstance( new ScheduledExecutorServiceAdapter(futureExecutor), DefaultVertexParallelismAndInputInfosDecider.from( getDefaultMaxParallelism(jobMasterConfiguration, executionConfig), - jobMasterConfiguration)); + jobMasterConfiguration), + jobRecoveryHandler); } @VisibleForTesting @@ -201,7 +219,8 @@ public static AdaptiveBatchScheduler createScheduler( ExecutionSlotAllocatorFactory allocatorFactory, RestartBackoffTimeStrategy restartBackoffTimeStrategy, ScheduledExecutor delayExecutor, - VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider) + VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider, + BatchJobRecoveryHandler jobRecoveryHandler) throws Exception { checkState( @@ -271,7 +290,8 @@ public static AdaptiveBatchScheduler createScheduler( defaultMaxParallelism, 
                 blocklistOperations,
                 hybridPartitionDataConsumeConstraint,
-                forwardGroupsByJobVertexId);
+                forwardGroupsByJobVertexId,
+                jobRecoveryHandler);
     }
 
     public static InputConsumableDecider.Factory loadInputConsumableDeciderFactory(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryContext.java
new file mode 100644
index 0000000000000..8235f3a3ca3b7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryContext.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+/** Context for batch job recovery. */
+public interface BatchJobRecoveryContext {
+
+    /**
+     * Provides the {@code ExecutionGraph} associated with the job.
+     *
+     * @return The execution graph.
+     */
+    ExecutionGraph getExecutionGraph();
+
+    /**
+     * Provides the {@code ShuffleMaster} associated with the job.
+     *
+     * @return The shuffle master.
+     */
+    ShuffleMaster<?> getShuffleMaster();
+
+    /**
+     * Provides the main thread executor.
+     *
+     * @return The main thread executor.
+     */
+    ComponentMainThreadExecutor getMainThreadExecutor();
+
+    /**
+     * Retrieves the set of vertices that need to be restarted. If result consumption is
+     * considered ({@code considerResultConsumable} is true), the set includes all finished
+     * downstream vertices as well as the upstream vertices whose result partitions are missing.
+     * Otherwise, only the finished downstream vertices are included.
+     *
+     * @param vertexId The ID of the vertex from which to compute the restart set.
+     * @param considerResultConsumable Indicates whether to consider result partition consumption
+     *     while computing the vertices needing restart.
+     * @return A set of vertex IDs that need to be restarted.
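+     *     For example, when a missing partition forces a finished producer to restart, passing
+     *     {@code true} also pulls the producer's finished downstream consumers into the set.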
+     */
+    Set<ExecutionVertexID> getTasksNeedingRestart(
+            ExecutionVertexID vertexId, boolean considerResultConsumable);
+
+    /**
+     * Resets the vertices specified by their IDs during the recovery process.
+     *
+     * @param verticesToReset The set of vertices that require resetting.
+     */
+    void resetVerticesInRecovering(Set<ExecutionVertexID> verticesToReset) throws Exception;
+
+    /**
+     * Updates the metrics related to the result partition sizes.
+     *
+     * @param resultPartitionBytes Mapping of partition IDs to their respective result partition
+     *     bytes.
+     */
+    void updateResultPartitionBytesMetrics(
+            Map<IntermediateResultPartitionID, ResultPartitionBytes> resultPartitionBytes);
+
+    /**
+     * Initializes a given job vertex with the specified parallelism and input information.
+     *
+     * @param jobVertex The job vertex to initialize.
+     * @param parallelism The parallelism to set for the job vertex.
+     * @param jobVertexInputInfos The input information for the job vertex.
+     * @param createTimestamp The timestamp marking the creation of the job vertex.
+     */
+    void initializeJobVertex(
+            ExecutionJobVertex jobVertex,
+            int parallelism,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos,
+            long createTimestamp)
+            throws JobException;
+
+    /**
+     * Updates the job topology with the newly initialized job vertices.
+     *
+     * @param newlyInitializedJobVertices List of job vertices that have been initialized.
+     */
+    void updateTopology(List<ExecutionJobVertex> newlyInitializedJobVertices);
+
+    /**
+     * Notifies that the recovery has finished.
+     *
+     * @param jobVerticesWithUnRecoveredCoordinators A set of job vertex IDs associated with job
+     *     vertices whose operator coordinators did not successfully recover their state. If any
+     *     execution within these job vertices needs to be restarted in the future, all other
+     *     executions within the same job vertex must also be restarted to ensure the consistency
+     *     and correctness of the state.
+     */
+    void onRecoveringFinished(Set<JobVertexID> jobVerticesWithUnRecoveredCoordinators);
+
+    /** Notifies that the recovery has failed. */
+    void onRecoveringFailed();
+
+    /** Triggers a job failure. */
+    void failJob(
+            Throwable cause, long timestamp, CompletableFuture<Map<String, String>> failureLabels);
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryHandler.java
new file mode 100644
index 0000000000000..f0499f4788f06
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryHandler.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+
+import java.util.Collection;
+import java.util.Map;
+
+/** Interface for handling batch job recovery. */
+public interface BatchJobRecoveryHandler {
+
+    /** Initializes the recovery handler with the batch job recovery context. */
+    void initialize(BatchJobRecoveryContext batchJobRecoveryContext);
+
+    /** Starts the recovery process. */
+    void startRecovering();
+
+    /**
+     * Stops the job recovery handler and optionally cleans up.
+     *
+     * @param cleanUp whether to clean up the recorded job events.
+     */
+    void stop(boolean cleanUp);
+
+    /** Determines whether the job needs to undergo recovery. */
+    boolean needRecover();
+
+    /** Determines whether the job is currently recovering. */
+    boolean isRecovering();
+
+    /**
+     * Handles the reset event for a collection of execution vertices and records the event for
+     * use during future batch job recovery.
+     *
+     * @param vertices a collection of execution vertex IDs that have been reset.
+     */
+    void onExecutionVertexReset(Collection<ExecutionVertexID> vertices);
+
+    /**
+     * Records the job vertex initialization event for use during future batch job recovery.
+     *
+     * @param jobVertexId the ID of the job vertex being initialized.
+     * @param parallelism the parallelism of the job vertex.
+     * @param jobVertexInputInfos a map of intermediate dataset IDs to job vertex input info.
+     */
+    void onExecutionJobVertexInitialization(
+            JobVertexID jobVertexId,
+            int parallelism,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos);
+
+    /**
+     * Records the execution vertex finished event for use during future batch job recovery.
+     *
+     * @param executionVertexId the ID of the execution vertex that finished.
+     */
+    void onExecutionFinished(ExecutionVertexID executionVertexId);
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultBatchJobRecoveryHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultBatchJobRecoveryHandler.java
new file mode 100644
index 0000000000000..d2fd19cbcda88
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultBatchJobRecoveryHandler.java
@@ -0,0 +1,846 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.api.common.JobStatus; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.BatchExecutionOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.failure.FailureEnricherUtils; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexInitializedEvent; +import org.apache.flink.runtime.jobmaster.event.ExecutionVertexFinishedEvent; +import org.apache.flink.runtime.jobmaster.event.ExecutionVertexResetEvent; +import org.apache.flink.runtime.jobmaster.event.JobEvent; +import org.apache.flink.runtime.jobmaster.event.JobEventManager; +import org.apache.flink.runtime.jobmaster.event.JobEventReplayHandler; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinatorHolder; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.shuffle.DefaultShuffleMasterSnapshotContext; +import org.apache.flink.runtime.shuffle.PartitionWithMetrics; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshot; +import org.apache.flink.util.clock.Clock; +import org.apache.flink.util.clock.SystemClock; + +import org.apache.flink.shaded.guava31.com.google.common.collect.Sets; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.operators.coordination.OperatorCoordinator.NO_CHECKPOINT; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation of {@link BatchJobRecoveryHandler} and {@link JobEventReplayHandler}. 
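+ *
+ * <p>During normal execution it appends job events (vertex initialization, resets, finished
+ * executions) to the {@link JobEventManager}; after a JobMaster failover those events are
+ * replayed to rebuild the scheduler state. A rough sketch of how a scheduler drives it (the
+ * wiring below is illustrative, not part of this class):
+ *
+ * <pre>{@code
+ * handler.initialize(context);
+ * if (handler.needRecover()) {
+ *     handler.startRecovering(); // replay events, then reconcile partitions
+ * }
+ * // normal scheduling; progress is recorded via the on*() callbacks
+ * handler.stop(jobReachedGloballyTerminalState); // clean up events iff terminal
+ * }</pre>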
+ */
+public class DefaultBatchJobRecoveryHandler
+        implements BatchJobRecoveryHandler, JobEventReplayHandler {
+
+    private final Logger log = LoggerFactory.getLogger(getClass());
+
+    private final JobEventManager jobEventManager;
+
+    private BatchJobRecoveryContext context;
+
+    /** The timestamp (via {@link Clock#relativeTimeMillis()}) of the last snapshot. */
+    private long lastSnapshotRelativeTime;
+
+    private final Set<JobVertexID> needToSnapshotJobVertices = new HashSet<>();
+
+    private static final ResourceID UNKNOWN_PRODUCER = ResourceID.generate();
+
+    private final long snapshotMinPauseMillis;
+
+    private Clock clock;
+
+    private final Map<ExecutionVertexID, ExecutionVertexFinishedEvent>
+            executionVertexFinishedEventMap = new LinkedHashMap<>();
+
+    private final List<ExecutionJobVertexInitializedEvent> jobVertexInitializedEvents =
+            new ArrayList<>();
+
+    /**
+     * A set of JobVertex IDs associated with JobVertices whose operator coordinators did not
+     * successfully recover. If any execution within one of these job vertices needs to be
+     * restarted in the future, all other executions within the same job vertex must also be
+     * restarted to ensure the consistency and correctness of the state.
+     */
+    private final Set<JobVertexID> jobVerticesWithUnRecoveredCoordinators = new HashSet<>();
+
+    private final Duration previousWorkerRecoveryTimeout;
+
+    public DefaultBatchJobRecoveryHandler(
+            JobEventManager jobEventManager, Configuration jobMasterConfiguration) {
+        this.jobEventManager = jobEventManager;
+
+        this.previousWorkerRecoveryTimeout =
+                jobMasterConfiguration.get(
+                        BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT);
+        this.snapshotMinPauseMillis =
+                jobMasterConfiguration
+                        .get(BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE)
+                        .toMillis();
+    }
+
+    @Override
+    public void initialize(BatchJobRecoveryContext context) {
+        this.context = checkNotNull(context);
+        this.clock = SystemClock.getInstance();
+
+        try {
+            jobEventManager.start();
+        } catch (Throwable throwable) {
+            context.failJob(
+                    throwable,
+                    System.currentTimeMillis(),
+                    FailureEnricherUtils.EMPTY_FAILURE_LABELS);
+        }
+    }
+
+    @Override
+    public void stop(boolean cleanUp) {
+        jobEventManager.stop(cleanUp);
+    }
+
+    @Override
+    public void startRecovering() {
+        context.getMainThreadExecutor().assertRunningInMainThread();
+
+        startRecoveringInternal();
+
+        // notify the shuffle master that recovery has started and try to fetch partitions
+        context.getShuffleMaster()
+                .notifyPartitionRecoveryStarted(context.getExecutionGraph().getJobID());
+
+        if (!jobEventManager.replay(this)) {
+            log.warn(
+                    "Failed to replay the job event log for {}; starting the job as a new one.",
+                    context.getExecutionGraph().getJobID());
+            recoverFailed();
+            return;
+        }
+        log.info("Replayed all job events successfully.");
+
+        recoverPartitions()
+                .whenComplete(
+                        (ignored, throwable) -> {
+                            if (throwable != null) {
+                                recoverFailed();
+                                return;
+                            }
+                            try {
+                                recoverFinished();
+                            } catch (Exception exception) {
+                                recoverFailed();
+                            }
+                        });
+    }
+
+    @Override
+    public boolean needRecover() {
+        try {
+            return jobEventManager.hasJobEvents();
+        } catch (Throwable throwable) {
+            context.failJob(
+                    throwable,
+                    System.currentTimeMillis(),
+                    FailureEnricherUtils.EMPTY_FAILURE_LABELS);
+            return false;
+        }
+    }
+
+    @Override
+    public boolean isRecovering() {
+        return context.getExecutionGraph().getState() == JobStatus.RECONCILING;
+    }
+
+    private void restoreShuffleMaster(List<ShuffleMasterSnapshot> snapshots) {
+        checkState(context.getShuffleMaster().supportsBatchSnapshot());
+        context.getShuffleMaster().restoreState(snapshots);
+    }
+
+    private void startRecoveringInternal() {
+        log.info("Try to recover status from the previously failed job master.");
+        context.getExecutionGraph().transitionState(JobStatus.CREATED, JobStatus.RECONCILING);
+    }
+
+    private void restoreOperatorCoordinators(
+            Map<OperatorID, byte[]> snapshots, Map<OperatorID, JobVertexID> operatorToJobVertex)
+            throws Exception {
+        for (Map.Entry<OperatorID, byte[]> entry : snapshots.entrySet()) {
+            OperatorID operatorId = entry.getKey();
+            JobVertexID jobVertexId = checkNotNull(operatorToJobVertex.get(operatorId));
+            ExecutionJobVertex jobVertex = getExecutionJobVertex(jobVertexId);
+            log.info("Restore operator coordinators of {} from job event.", jobVertex.getName());
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                if (holder.coordinator().supportsBatchSnapshot()) {
+                    byte[] snapshot = snapshots.get(holder.operatorId());
+                    holder.resetToCheckpoint(NO_CHECKPOINT, snapshot);
+                }
+            }
+        }
+
+        determineVerticesForResetAfterRestoreOpCoordinator();
+    }
+
+    @Override
+    public void startReplay() {
+        // do nothing.
+    }
+
+    @Override
+    public void replayOneEvent(JobEvent jobEvent) {
+        if (jobEvent instanceof ExecutionVertexFinishedEvent) {
+            ExecutionVertexFinishedEvent event = (ExecutionVertexFinishedEvent) jobEvent;
+            executionVertexFinishedEventMap.put(event.getExecutionVertexId(), event);
+        } else if (jobEvent instanceof ExecutionVertexResetEvent) {
+            ExecutionVertexResetEvent event = (ExecutionVertexResetEvent) jobEvent;
+            for (ExecutionVertexID executionVertexId : event.getExecutionVertexIds()) {
+                executionVertexFinishedEventMap.remove(executionVertexId);
+            }
+        } else if (jobEvent instanceof ExecutionJobVertexInitializedEvent) {
+            jobVertexInitializedEvents.add((ExecutionJobVertexInitializedEvent) jobEvent);
+        } else {
+            throw new IllegalStateException("Unsupported job event " + jobEvent);
+        }
+    }
+
+    @Override
+    public void finalizeReplay() throws Exception {
+        // recover job vertex initialization info and update topology
+        long currentTimeMillis = System.currentTimeMillis();
+        final List<ExecutionJobVertex> initializedJobVertices = new ArrayList<>();
+        for (ExecutionJobVertexInitializedEvent event : jobVertexInitializedEvents) {
+            final ExecutionJobVertex jobVertex = getExecutionJobVertex(event.getJobVertexId());
+            context.initializeJobVertex(
+                    jobVertex,
+                    event.getParallelism(),
+                    event.getJobVertexInputInfos(),
+                    currentTimeMillis);
+            initializedJobVertices.add(jobVertex);
+        }
+        context.updateTopology(initializedJobVertices);
+
+        // Operator coordinator and shuffle master snapshots are taken and persisted externally
+        // only periodically, so any events in the final batch that do not have an associated
+        // snapshot are redundant and can be disregarded.
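+        // For example, if the replayed log contains [v1 (snapshot), v2, v3 (snapshot), v4, v5],
+        // only the events for v1..v3 are kept; v4 and v5 are dropped and simply re-executed.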
+        LinkedList<ExecutionVertexFinishedEvent> finishedEvents =
+                new LinkedList<>(executionVertexFinishedEventMap.values());
+        while (!finishedEvents.isEmpty()
+                && !finishedEvents.getLast().hasOperatorCoordinatorAndShuffleMasterSnapshots()) {
+            finishedEvents.removeLast();
+        }
+
+        if (finishedEvents.isEmpty()) {
+            return;
+        }
+
+        // find the last operator coordinator state for each operator coordinator
+        Map<OperatorID, byte[]> operatorCoordinatorSnapshots = new HashMap<>();
+
+        List<ShuffleMasterSnapshot> shuffleMasterSnapshots = new ArrayList<>();
+
+        // transition states of all vertices
+        for (ExecutionVertexFinishedEvent event : finishedEvents) {
+            JobVertexID jobVertexId = event.getExecutionVertexId().getJobVertexId();
+            ExecutionJobVertex jobVertex = context.getExecutionGraph().getJobVertex(jobVertexId);
+            checkState(jobVertex.isInitialized());
+
+            int subTaskIndex = event.getExecutionVertexId().getSubtaskIndex();
+            Execution execution =
+                    jobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();
+            // recover execution info.
+            execution.recoverExecution(
+                    event.getExecutionAttemptId(),
+                    event.getTaskManagerLocation(),
+                    event.getUserAccumulators(),
+                    event.getIOMetrics());
+
+            // recover operator coordinator
+            for (Map.Entry<OperatorID, CompletableFuture<byte[]>> entry :
+                    event.getOperatorCoordinatorSnapshotFutures().entrySet()) {
+                checkState(entry.getValue().isDone());
+                operatorCoordinatorSnapshots.put(entry.getKey(), entry.getValue().get());
+            }
+
+            // recover shuffle master
+            if (event.getShuffleMasterSnapshotFuture() != null) {
+                checkState(event.getShuffleMasterSnapshotFuture().isDone());
+
+                ShuffleMasterSnapshot shuffleMasterSnapshot =
+                        event.getShuffleMasterSnapshotFuture().get();
+                if (shuffleMasterSnapshot.isIncremental()) {
+                    shuffleMasterSnapshots.add(shuffleMasterSnapshot);
+                } else {
+                    // a full snapshot supersedes earlier ones; start a fresh mutable list so
+                    // later incremental snapshots can still be appended
+                    shuffleMasterSnapshots = new ArrayList<>();
+                    shuffleMasterSnapshots.add(shuffleMasterSnapshot);
+                }
+            }
+        }
+
+        // restore operator coordinator state if needed.
+        final Map<OperatorID, JobVertexID> operatorToJobVertex = new HashMap<>();
+        for (ExecutionJobVertex jobVertex : context.getExecutionGraph().getAllVertices().values()) {
+            if (!jobVertex.isInitialized()) {
+                continue;
+            }
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                operatorToJobVertex.put(holder.operatorId(), jobVertex.getJobVertexId());
+            }
+        }
+
+        try {
+            restoreOperatorCoordinators(operatorCoordinatorSnapshots, operatorToJobVertex);
+        } catch (Exception exception) {
+            log.warn("Failed to restore the operator coordinators.", exception);
+            throw exception;
+        }
+
+        // restore shuffle master
+        restoreShuffleMaster(shuffleMasterSnapshots);
+    }
+
+    @Override
+    public void onExecutionVertexReset(Collection<ExecutionVertexID> vertices) {
+        // write execution vertex reset event.
+        checkState(!isRecovering());
+        jobEventManager.writeEvent(new ExecutionVertexResetEvent(new ArrayList<>(vertices)), false);
+    }
+
+    @Override
+    public void onExecutionJobVertexInitialization(
+            JobVertexID jobVertexId,
+            int parallelism,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) {
+        // write execution job vertex initialized event.
+        checkState(!isRecovering());
+        jobEventManager.writeEvent(
+                new ExecutionJobVertexInitializedEvent(
+                        jobVertexId, parallelism, jobVertexInputInfos),
+                false);
+    }
+
+    @Override
+    public void onExecutionFinished(ExecutionVertexID executionVertexId) {
+        checkState(!isRecovering());
+
+        Execution execution = getExecutionVertex(executionVertexId).getCurrentExecutionAttempt();
+
+        // check whether the job vertex is finished.
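+        // A finished job vertex always forces a snapshot so that its terminal state survives the
+        // trimming during replay; otherwise snapshots are rate-limited by snapshotMinPauseMillis.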
+        ExecutionJobVertex jobVertex = execution.getVertex().getJobVertex();
+        boolean jobVertexFinished = jobVertex.getAggregateState() == ExecutionState.FINISHED;
+
+        // snapshot operator coordinators and shuffle master if needed.
+        needToSnapshotJobVertices.add(executionVertexId.getJobVertexId());
+        final Map<OperatorID, CompletableFuture<byte[]>> operatorCoordinatorSnapshotFutures =
+                new HashMap<>();
+        CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture = null;
+        long currentRelativeTime = clock.relativeTimeMillis();
+        if (jobVertexFinished
+                || (currentRelativeTime - lastSnapshotRelativeTime >= snapshotMinPauseMillis)) {
+            // operator coordinator
+            operatorCoordinatorSnapshotFutures.putAll(snapshotOperatorCoordinators());
+            lastSnapshotRelativeTime = currentRelativeTime;
+            needToSnapshotJobVertices.clear();
+
+            // shuffle master
+            shuffleMasterSnapshotFuture = snapshotShuffleMaster();
+        }
+
+        // write job event.
+        jobEventManager.writeEvent(
+                new ExecutionVertexFinishedEvent(
+                        execution.getAttemptId(),
+                        execution.getAssignedResourceLocation(),
+                        operatorCoordinatorSnapshotFutures,
+                        shuffleMasterSnapshotFuture,
+                        execution.getIOMetrics(),
+                        execution.getUserAccumulators()),
+                jobVertexFinished);
+    }
+
+    private Map<OperatorID, CompletableFuture<byte[]>> snapshotOperatorCoordinators() {
+
+        final Map<OperatorID, CompletableFuture<byte[]>> snapshotFutures = new HashMap<>();
+
+        for (JobVertexID jobVertexId : needToSnapshotJobVertices) {
+            final ExecutionJobVertex jobVertex = checkNotNull(getExecutionJobVertex(jobVertexId));
+
+            log.info(
+                    "Snapshot operator coordinators of {} to job event, checkpointId {}.",
+                    jobVertex.getName(),
+                    NO_CHECKPOINT);
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                if (holder.coordinator().supportsBatchSnapshot()) {
+                    final CompletableFuture<byte[]> checkpointFuture = new CompletableFuture<>();
+                    holder.checkpointCoordinator(NO_CHECKPOINT, checkpointFuture);
+                    snapshotFutures.put(holder.operatorId(), checkpointFuture);
+                }
+            }
+        }
+        return snapshotFutures;
+    }
+
+    private CompletableFuture<ShuffleMasterSnapshot> snapshotShuffleMaster() {
+
+        checkState(context.getShuffleMaster().supportsBatchSnapshot());
+        CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture =
+                new CompletableFuture<>();
+        context.getShuffleMaster()
+                .snapshotState(
+                        shuffleMasterSnapshotFuture, new DefaultShuffleMasterSnapshotContext());
+        return shuffleMasterSnapshotFuture;
+    }
+
+    private void determineVerticesForResetAfterRestoreOpCoordinator() throws Exception {
+        Set<ExecutionVertexID> verticesToReset = new HashSet<>();
+
+        for (ExecutionJobVertex jobVertex : context.getExecutionGraph().getAllVertices().values()) {
+            if (!jobVertex.isInitialized() || jobVertex.getOperatorCoordinators().isEmpty()) {
+                continue;
+            }
+
+            boolean allSupportsBatchSnapshot =
+                    jobVertex.getOperatorCoordinators().stream()
+                            .allMatch(holder -> holder.coordinator().supportsBatchSnapshot());
+
+            Set<ExecutionVertexID> unfinishedTasks =
+                    Arrays.stream(jobVertex.getTaskVertices())
+                            .filter(vertex -> vertex.getExecutionState() != ExecutionState.FINISHED)
+                            .map(
+                                    executionVertex -> {
+                                        // transition to a terminal state to allow resetting it
+                                        executionVertex
+                                                .getCurrentExecutionAttempt()
+                                                .transitionState(ExecutionState.CANCELED);
+                                        return executionVertex.getID();
+                                    })
+                            .collect(Collectors.toSet());
+
+            if (allSupportsBatchSnapshot) {
+                log.info(
+                        "All operator coordinators of jobVertex {} support batch snapshot, "
+                                + "add {} unfinished tasks to reset.",
+                        jobVertex.getName(),
+                        unfinishedTasks.size());
+                verticesToReset.addAll(unfinishedTasks);
+            } else if (unfinishedTasks.isEmpty()) {
+                log.info(
+                        "JobVertex {} is finished, but not all of its operator coordinators "
+                                + "support batch snapshot. Therefore, if any single task within it "
+                                + "requires a restart in the future, all tasks associated with "
+                                + "this JobVertex need to be restarted as well.",
+                        jobVertex.getName());
+                jobVerticesWithUnRecoveredCoordinators.add(jobVertex.getJobVertexId());
+            } else {
+                log.info(
+                        "Restart all tasks of jobVertex {} because it is not finished and not "
+                                + "all of its operator coordinators support batch snapshot.",
+                        jobVertex.getName());
+                verticesToReset.addAll(
+                        Arrays.stream(jobVertex.getTaskVertices())
+                                .map(ExecutionVertex::getID)
+                                .collect(Collectors.toSet()));
+            }
+        }
+
+        resetVerticesInRecovering(verticesToReset, false);
+    }
+
+    private void resetVerticesInRecovering(
+            Set<ExecutionVertexID> nextVertices, boolean baseOnResultPartitionConsumable)
+            throws Exception {
+        checkState(isRecovering());
+
+        Set<ExecutionVertexID> verticesToRestart = new HashSet<>();
+        while (!nextVertices.isEmpty()) {
+            for (ExecutionVertexID executionVertexId : nextVertices) {
+                if (!verticesToRestart.contains(executionVertexId)) {
+                    verticesToRestart.addAll(
+                            context.getTasksNeedingRestart(
+                                    executionVertexId, baseOnResultPartitionConsumable));
+                }
+            }
+
+            Set<JobVertexID> extraNeedToRestartJobVertices =
+                    verticesToRestart.stream()
+                            .map(ExecutionVertexID::getJobVertexId)
+                            .filter(jobVerticesWithUnRecoveredCoordinators::contains)
+                            .collect(Collectors.toSet());
+            jobVerticesWithUnRecoveredCoordinators.removeAll(extraNeedToRestartJobVertices);
+
+            nextVertices =
+                    extraNeedToRestartJobVertices.stream()
+                            .flatMap(
+                                    jobVertexId -> {
+                                        ExecutionJobVertex jobVertex =
+                                                getExecutionJobVertex(jobVertexId);
+                                        return Arrays.stream(jobVertex.getTaskVertices())
+                                                .map(ExecutionVertex::getID);
+                                    })
+                            .collect(Collectors.toSet());
+        }
+
+        // we only reset tasks which are not CREATED.
+        Set<ExecutionVertexID> verticesToReset =
+                verticesToRestart.stream()
+                        .filter(
+                                executionVertexID ->
+                                        getExecutionVertex(executionVertexID).getExecutionState()
+                                                != ExecutionState.CREATED)
+                        .collect(Collectors.toSet());
+
+        context.resetVerticesInRecovering(verticesToReset);
+    }
+
+    private void recoverFailed() {
+        String message =
+                String.format(
+                        "Job %s failed to recover from a JM failover; failing the job globally.",
+                        context.getExecutionGraph().getJobID());
+        log.warn(message);
+        context.getExecutionGraph().transitionState(JobStatus.RECONCILING, JobStatus.RUNNING);
+
+        // clear job events and restart the job event manager.
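+        // stop(true) also deletes the persisted job events, so the subsequent start() begins
+        // with a clean event log for the fresh (non-recovered) run.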
+ jobEventManager.stop(true); + try { + jobEventManager.start(); + } catch (Throwable throwable) { + context.failJob( + throwable, + System.currentTimeMillis(), + FailureEnricherUtils.EMPTY_FAILURE_LABELS); + return; + } + + context.onRecoveringFailed(); + } + + private void recoverFinished() { + log.info( + "Job {} successfully recovered from JM failover", + context.getExecutionGraph().getJobID()); + + context.getExecutionGraph().transitionState(JobStatus.RECONCILING, JobStatus.RUNNING); + checkExecutionGraphState(); + context.onRecoveringFinished(jobVerticesWithUnRecoveredCoordinators); + } + + private void checkExecutionGraphState() { + for (ExecutionVertex executionVertex : + context.getExecutionGraph().getAllExecutionVertices()) { + ExecutionState state = executionVertex.getExecutionState(); + checkState(state == ExecutionState.CREATED || state == ExecutionState.FINISHED); + } + } + + private CompletableFuture recoverPartitions() { + context.getMainThreadExecutor().assertRunningInMainThread(); + + CompletableFuture>> + reconcilePartitionsFuture = reconcilePartitions(); + + return reconcilePartitionsFuture.thenAccept( + tuple2 -> { + ReconcileResult reconcileResult = tuple2.f0; + Collection partitionWithMetrics = tuple2.f1; + + log.info( + "Partitions to be released: {}, missed partitions: {}, partitions to be reserved: {}.", + reconcileResult.partitionsToRelease, + reconcileResult.partitionsMissing, + reconcileResult.partitionsToReserve); + + // release partitions which is no more needed. + ((InternalExecutionGraphAccessor) context.getExecutionGraph()) + .getPartitionTracker() + .stopTrackingAndReleasePartitions(reconcileResult.partitionsToRelease); + + // start tracking all partitions should be reserved + Map + availablePartitionBytes = new HashMap<>(); + partitionWithMetrics.stream() + .filter( + partitionAndMetric -> + reconcileResult.partitionsToReserve.contains( + partitionAndMetric + .getPartition() + .getResultPartitionID())) + .forEach( + partitionAndMetric -> { + ShuffleDescriptor shuffleDescriptor = + partitionAndMetric.getPartition(); + + // we cannot get the producer id when using remote shuffle + ResourceID producerTaskExecutorId = UNKNOWN_PRODUCER; + if (shuffleDescriptor + .storesLocalResourcesOn() + .isPresent()) { + producerTaskExecutorId = + shuffleDescriptor + .storesLocalResourcesOn() + .get(); + } + IntermediateResultPartition partition = + context.getExecutionGraph() + .getResultPartitionOrThrow( + shuffleDescriptor + .getResultPartitionID() + .getPartitionId()); + ((InternalExecutionGraphAccessor) + context.getExecutionGraph()) + .getPartitionTracker() + .startTrackingPartition( + producerTaskExecutorId, + Execution + .createResultPartitionDeploymentDescriptor( + partition, + shuffleDescriptor)); + + availablePartitionBytes.put( + shuffleDescriptor + .getResultPartitionID() + .getPartitionId(), + partitionAndMetric + .getPartitionMetrics() + .getPartitionBytes()); + }); + + // recover the produced partitions for executions + Map< + ExecutionVertexID, + Map< + IntermediateResultPartitionID, + ResultPartitionDeploymentDescriptor>> + allDescriptors = new HashMap<>(); + ((InternalExecutionGraphAccessor) context.getExecutionGraph()) + .getPartitionTracker() + .getAllTrackedNonClusterPartitions() + .forEach( + descriptor -> { + ExecutionVertexID vertexId = + descriptor + .getShuffleDescriptor() + .getResultPartitionID() + .getProducerId() + .getExecutionVertexId(); + if (!allDescriptors.containsKey(vertexId)) { + allDescriptors.put(vertexId, new HashMap<>()); + } + + 
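+                                        // group each recovered descriptor under its producer
+                                        // vertex so every Execution gets back exactly the
+                                        // partitions it produced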
allDescriptors + .get(vertexId) + .put(descriptor.getPartitionId(), descriptor); + }); + + allDescriptors.forEach( + (vertexId, descriptors) -> + getExecutionVertex(vertexId) + .getCurrentExecutionAttempt() + .recoverProducedPartitions(descriptors)); + + // recover result partition bytes + context.updateResultPartitionBytesMetrics(availablePartitionBytes); + + // restart all producers of missing partitions + Set missingPartitionVertices = + reconcileResult.partitionsMissing.stream() + .map(ResultPartitionID::getPartitionId) + .map(this::getProducer) + .map(ExecutionVertex::getID) + .collect(Collectors.toSet()); + + try { + resetVerticesInRecovering(missingPartitionVertices, true); + } catch (Exception e) { + throw new CompletionException(e); + } + }); + } + + private CompletableFuture>> + reconcilePartitions() { + List partitions = + context.getExecutionGraph().getAllIntermediateResults().values().stream() + .flatMap(result -> Arrays.stream(result.getPartitions())) + .collect(Collectors.toList()); + + Set partitionsToReserve = new HashSet<>(); + Set partitionsToRelease = new HashSet<>(); + for (IntermediateResultPartition partition : partitions) { + PartitionReservationStatus reserveStatus = getPartitionReservationStatus(partition); + + if (reserveStatus.equals(PartitionReservationStatus.RESERVE)) { + partitionsToReserve.add(createResultPartitionId(partition.getPartitionId())); + } else if (reserveStatus.equals(PartitionReservationStatus.RELEASE)) { + partitionsToRelease.add(createResultPartitionId(partition.getPartitionId())); + } + } + + CompletableFuture> fetchPartitionsFuture = + context.getShuffleMaster() + .getPartitionWithMetrics( + context.getExecutionGraph().getJobID(), + previousWorkerRecoveryTimeout, + partitionsToReserve); + + return fetchPartitionsFuture.thenApplyAsync( + partitionWithMetrics -> { + Set actualPartitions = + partitionWithMetrics.stream() + .map(PartitionWithMetrics::getPartition) + .map(ShuffleDescriptor::getResultPartitionID) + .collect(Collectors.toSet()); + + Set actualpartitionsToRelease = + Sets.intersection(partitionsToRelease, actualPartitions); + Set actualpartitionsMissing = + Sets.difference(partitionsToReserve, actualPartitions); + Set actualpartitionsToReserve = + Sets.intersection(partitionsToReserve, actualPartitions); + + return Tuple2.of( + new ReconcileResult( + actualpartitionsToRelease, + actualpartitionsMissing, + actualpartitionsToReserve), + partitionWithMetrics); + }, + context.getMainThreadExecutor()); + } + + private ResultPartitionID createResultPartitionId(IntermediateResultPartitionID partitionId) { + final Execution producer = getProducer(partitionId).getPartitionProducer(); + return new ResultPartitionID(partitionId, producer.getAttemptId()); + } + + private ExecutionVertex getProducer(IntermediateResultPartitionID partitionId) { + return context.getExecutionGraph().getResultPartitionOrThrow(partitionId).getProducer(); + } + + private PartitionReservationStatus getPartitionReservationStatus( + IntermediateResultPartition partition) { + // 1. Check if the producer of this partition is finished. + ExecutionVertex producer = getProducer(partition.getPartitionId()); + boolean isProducerFinished = producer.getExecutionState() == ExecutionState.FINISHED; + if (!isProducerFinished) { + return PartitionReservationStatus.RELEASE; + } + + // 2. Check if not all the consumer vertices for this partition are initialized. 
+ boolean allConsumersInitialized = + partition.getIntermediateResult().getConsumerVertices().stream() + .allMatch( + jobVertexId -> getExecutionJobVertex(jobVertexId).isInitialized()); + + if (!allConsumersInitialized) { + return PartitionReservationStatus.RESERVE; + } + + // 3. If all downstream vertices are finished, we need reserve the partitions. Otherwise, we + // could reserve them if fetched from shuffle master. + return getConsumers(partition.getPartitionId()).stream() + .anyMatch(vertex -> vertex.getExecutionState() != ExecutionState.FINISHED) + ? PartitionReservationStatus.RESERVE + : PartitionReservationStatus.OPTIONAL; + } + + /** Enum that specifies the reservation status of a partition. */ + private enum PartitionReservationStatus { + // Indicates the partition should be released. + RELEASE, + + // Indicates the partition should be reserved. + RESERVE, + + // Indicates the partition's reservation is preferred but not mandatory. + OPTIONAL + } + + private List getConsumers(IntermediateResultPartitionID partitionId) { + List consumerVertexGroups = + context.getExecutionGraph() + .getResultPartitionOrThrow(partitionId) + .getConsumerVertexGroups(); + List executionVertices = new ArrayList<>(); + for (ConsumerVertexGroup group : consumerVertexGroups) { + for (ExecutionVertexID executionVertexID : group) { + executionVertices.add(getExecutionVertex(executionVertexID)); + } + } + return executionVertices; + } + + private ExecutionVertex getExecutionVertex(final ExecutionVertexID executionVertexId) { + return context.getExecutionGraph() + .getAllVertices() + .get(executionVertexId.getJobVertexId()) + .getTaskVertices()[executionVertexId.getSubtaskIndex()]; + } + + private ExecutionJobVertex getExecutionJobVertex(final JobVertexID jobVertexId) { + return context.getExecutionGraph().getAllVertices().get(jobVertexId); + } + + private static class ReconcileResult { + private final Set partitionsToRelease; + private final Set partitionsMissing; + private final Set partitionsToReserve; + + ReconcileResult( + Set partitionsToRelease, + Set partitionsMissing, + Set partitionsToReserve) { + this.partitionsToRelease = checkNotNull(partitionsToRelease); + this.partitionsMissing = checkNotNull(partitionsMissing); + this.partitionsToReserve = checkNotNull(partitionsToReserve); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyBatchJobRecoveryHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyBatchJobRecoveryHandler.java new file mode 100644 index 0000000000000..d66f01169badb --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyBatchJobRecoveryHandler.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; + +import java.util.Collection; +import java.util.Map; + +/** A dummy implementation of the {@link BatchJobRecoveryHandler}. */ +public class DummyBatchJobRecoveryHandler implements BatchJobRecoveryHandler { + @Override + public void initialize(BatchJobRecoveryContext batchJobRecoveryContext) {} + + @Override + public void startRecovering() {} + + @Override + public void stop(boolean cleanUp) {} + + @Override + public boolean needRecover() { + return false; + } + + @Override + public boolean isRecovering() { + return false; + } + + @Override + public void onExecutionVertexReset(Collection vertices) {} + + @Override + public void onExecutionJobVertexInitialization( + JobVertexID jobVertexId, + int parallelism, + Map jobVertexInputInfos) {} + + @Override + public void onExecutionFinished(ExecutionVertexID executionVertexId) {} +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingStrategy.java index 86ebce5df844b..5efdfd7367164 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingStrategy.java @@ -55,4 +55,12 @@ public interface SchedulingStrategy { * @param resultPartitionId The id of the result partition */ void onPartitionConsumable(IntermediateResultPartitionID resultPartitionId); + + /** + * Schedules all vertices and excludes any vertices that are already finished or whose inputs + * are not yet ready. 
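+     *
+     * <p>The default implementation throws {@link UnsupportedOperationException}. Strategies that
+     * support batch job recovery (see {@link VertexwiseSchedulingStrategy}) override this to
+     * re-schedule every vertex that has not yet reached the FINISHED state.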
+ */ + default void scheduleAllVerticesIfPossible() { + throw new UnsupportedOperationException(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java index b2d25b1f82f0c..7be6922b9359d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java @@ -146,6 +146,18 @@ private void maybeScheduleVertices(final Set vertices) { scheduledVertices.addAll(verticesToSchedule); } + @Override + public void scheduleAllVerticesIfPossible() { + newVertices.clear(); + Set verticesToSchedule = + IterableUtils.toStream(schedulingTopology.getVertices()) + .filter(vertex -> !vertex.getState().equals(ExecutionState.FINISHED)) + .map(SchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + + maybeScheduleVertices(verticesToSchedule); + } + private Set addToScheduleAndGetVertices( Set currentVertices, Set verticesToSchedule) { Set nextVertices = new HashSet<>(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java index 21c692e09a2ff..e64eb0424ca9a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java @@ -70,6 +70,7 @@ Licensed to the Apache Software Foundation (ASF) under one import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import static org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.readAndVerifyCoordinatorSerdeVersion; @@ -546,7 +547,7 @@ private void runInEventLoop( // --------------------------------------------------- @VisibleForTesting - SplitEnumerator getEnumerator() { + public SplitEnumerator getEnumerator() { return enumerator; } @@ -555,6 +556,11 @@ SourceCoordinatorContext getContext() { return context; } + @VisibleForTesting + public ExecutorService getCoordinatorExecutor() { + return context.getCoordinatorExecutor(); + } + // --------------------- Serde ----------------------- /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java index 934c222af9288..fa49577caad1d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java @@ -229,6 +229,11 @@ void sendEventToSourceOperator(int subtaskId, OperatorEvent event) { String.format("Failed to send event %s to subtask %d", event, subtaskId)); } + @VisibleForTesting + ScheduledExecutorService getCoordinatorExecutor() { + return coordinatorExecutor; + } + void sendEventToSourceOperatorIfTaskReady(int subtaskId, OperatorEvent event) { checkAndLazyInitialize(); checkSubtaskIndex(subtaskId); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java index 1183fe327be66..60d7eeaa307b8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java @@ -48,7 +48,9 @@ import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler; import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchSchedulerFactory; +import org.apache.flink.runtime.scheduler.adaptivebatch.BatchJobRecoveryHandler; import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo; +import org.apache.flink.runtime.scheduler.adaptivebatch.DummyBatchJobRecoveryHandler; import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismAndInputInfosDecider; import org.apache.flink.runtime.scheduler.strategy.AllFinishedInputConsumableDecider; import org.apache.flink.runtime.scheduler.strategy.InputConsumableDecider; @@ -118,6 +120,7 @@ public class DefaultSchedulerBuilder { HybridPartitionDataConsumeConstraint.UNFINISHED_PRODUCERS; private InputConsumableDecider.Factory inputConsumableDeciderFactory = AllFinishedInputConsumableDecider.Factory.INSTANCE; + private BatchJobRecoveryHandler jobRecoveryHandler = new DummyBatchJobRecoveryHandler(); public DefaultSchedulerBuilder( JobGraph jobGraph, @@ -291,6 +294,12 @@ public DefaultSchedulerBuilder setInputConsumableDeciderFactory( return this; } + public DefaultSchedulerBuilder setJobRecoveryHandler( + BatchJobRecoveryHandler jobRecoveryHandler) { + this.jobRecoveryHandler = jobRecoveryHandler; + return this; + } + public DefaultScheduler build() throws Exception { return new DefaultScheduler( log, @@ -365,7 +374,8 @@ public AdaptiveBatchScheduler buildAdaptiveBatchJobScheduler(boolean enableSpecu executionSlotAllocatorFactory, restartBackoffTimeStrategy, delayExecutor, - vertexParallelismAndInputInfosDecider); + vertexParallelismAndInputInfosDecider, + jobRecoveryHandler); } private ExecutionGraphFactory createExecutionGraphFactory(boolean isDynamicGraph) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java index 6b69eb993f73a..e1a0a4714f9e4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java @@ -463,7 +463,7 @@ public static void transitionExecutionsState( taskExecutionState = createFailedTaskExecutionState(execution.getAttemptId(), throwable); } else { - throw new UnsupportedOperationException("Unsupported state " + state); + taskExecutionState = new TaskExecutionState(execution.getAttemptId(), state); } scheduler.updateTaskExecutionState(taskExecutionState); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java new file mode 100644 index 0000000000000..64e3e609bb74a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java @@ -0,0 +1,1144 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.eventtime.WatermarkAlignmentParams; +import org.apache.flink.api.connector.source.Boundedness; +import org.apache.flink.api.connector.source.mocks.MockSource; +import org.apache.flink.api.connector.source.mocks.MockSourceSplit; +import org.apache.flink.api.connector.source.mocks.MockSplitEnumerator; +import org.apache.flink.configuration.BatchExecutionOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph; +import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.executiongraph.TestingComponentMainThreadExecutor; +import org.apache.flink.runtime.executiongraph.failover.FixedDelayRestartBackoffTimeStrategy; +import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker; +import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTrackerImpl; +import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.jobmaster.event.ExecutionVertexFinishedEvent; +import org.apache.flink.runtime.jobmaster.event.FileSystemJobEventStore; +import org.apache.flink.runtime.jobmaster.event.JobEvent; +import org.apache.flink.runtime.jobmaster.event.JobEventManager; +import org.apache.flink.runtime.jobmaster.event.JobEventStore; +import org.apache.flink.runtime.jobmaster.utils.TestingJobMasterGateway; +import org.apache.flink.runtime.jobmaster.utils.TestingJobMasterGatewayBuilder; +import 
org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinatorHolder; +import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.TestingOperatorCoordinator; +import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder; +import org.apache.flink.runtime.scheduler.SchedulerBase; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; +import org.apache.flink.runtime.scheduler.strategy.SchedulingTopology; +import org.apache.flink.runtime.shuffle.DefaultShuffleMetrics; +import org.apache.flink.runtime.shuffle.JobShuffleContextImpl; +import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; +import org.apache.flink.runtime.shuffle.NettyShuffleMaster; +import org.apache.flink.runtime.shuffle.PartitionWithMetrics; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleMaster; +import org.apache.flink.runtime.shuffle.ShuffleMetrics; +import org.apache.flink.runtime.source.coordinator.SourceCoordinator; +import org.apache.flink.runtime.source.coordinator.SourceCoordinatorProvider; +import org.apache.flink.runtime.source.event.ReaderRegistrationEvent; +import org.apache.flink.runtime.testtasks.NoOpInvokable; +import org.apache.flink.runtime.testutils.CommonTestUtils; +import org.apache.flink.testutils.TestingUtils; +import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameter; +import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; +import org.apache.flink.testutils.junit.utils.TempDirUtils; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.SerializedValue; +import org.apache.flink.util.concurrent.ManuallyTriggeredScheduledExecutor; +import org.apache.flink.util.concurrent.ScheduledExecutor; +import org.apache.flink.util.concurrent.ScheduledExecutorServiceAdapter; +import org.apache.flink.util.function.ThrowingRunnable; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.io.TempDir; + +import javax.annotation.Nonnull; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.waitUntilExecutionVertexState; +import static org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder.createCustomParallelismDecider; +import static org.apache.flink.util.Preconditions.checkState; +import static 
org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +/** Test for batch job recovery. */ +@ExtendWith(ParameterizedTestExtension.class) +public class BatchJobRecoveryTest { + + private final Duration previousWorkerRecoveryTimeout = Duration.ofSeconds(1); + + @TempDir private java.nio.file.Path temporaryFolder; + + // ---- Mocks for the underlying Operator Coordinator Context --- + protected EventReceivingTasks receivingTasks; + + @RegisterExtension + static final TestExecutorExtension EXECUTOR_RESOURCE = + TestingUtils.defaultExecutorExtension(); + + @RegisterExtension + static final TestingComponentMainThreadExecutor.Extension MAIN_EXECUTOR_RESOURCE = + new TestingComponentMainThreadExecutor.Extension(); + + private final TestingComponentMainThreadExecutor mainThreadExecutor = + MAIN_EXECUTOR_RESOURCE.getComponentMainThreadTestExecutor(); + + private ScheduledExecutor delayedExecutor = + new ScheduledExecutorServiceAdapter(EXECUTOR_RESOURCE.getExecutor()); + + private static final OperatorID OPERATOR_ID = new OperatorID(1234L, 5678L); + private static final int NUM_SPLITS = 10; + private static final int SOURCE_PARALLELISM = 5; + private static final int MIDDLE_PARALLELISM = 5; + private static final int DECIDED_SINK_PARALLELISM = 2; + private static final JobVertexID SOURCE_ID = new JobVertexID(); + private static final JobVertexID MIDDLE_ID = new JobVertexID(); + private static final JobVertexID SINK_ID = new JobVertexID(); + private static final JobID JOB_ID = new JobID(); + + private SourceCoordinatorProvider provider; + private FileSystemJobEventStore jobEventStore; + private List persistedJobEventList; + + private byte[] serializedJobGraph; + + private final Collection allPartitionWithMetrics = new ArrayList<>(); + + @Parameter public boolean enableSpeculativeExecution; + + @Parameters(name = "enableSpeculativeExecution={0}") + public static Collection parameters() { + return Arrays.asList(false, true); + } + + @BeforeEach + void setUp() throws IOException { + final Path rootPath = new Path(TempDirUtils.newFolder(temporaryFolder).getAbsolutePath()); + delayedExecutor = new ScheduledExecutorServiceAdapter(EXECUTOR_RESOURCE.getExecutor()); + receivingTasks = EventReceivingTasks.createForRunningTasks(); + persistedJobEventList = new ArrayList<>(); + jobEventStore = + new TestingFileSystemJobEventStore( + rootPath, new Configuration(), persistedJobEventList); + + provider = + new SourceCoordinatorProvider<>( + "AdaptiveBatchSchedulerTest", + OPERATOR_ID, + new MockSource(Boundedness.BOUNDED, NUM_SPLITS), + 1, + WatermarkAlignmentParams.WATERMARK_ALIGNMENT_DISABLED, + null); + + this.serializedJobGraph = serializeJobGraph(createDefaultJobGraph()); + allPartitionWithMetrics.clear(); + } + + @AfterEach + void after() { + jobEventStore.stop(true); + } + + // This case will use job graph with the following topology: + // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=-1) + // + // This case will experience the following stages: + // 1. All source tasks are finished and all middle tasks are running + // 2. JM failover + // 3. After the failover, all source tasks are expected to be recovered to + // finished, and their produced partitions should also be restored. And middle vertex is + // redeployed. 
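+    // Note: recovery works by replaying the ExecutionVertexFinishedEvents that the first
+    // scheduler persisted through the JobEventStore; the test therefore waits for those events
+    // to be persisted before it simulates the failover with a fresh scheduler instance.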
+    @TestTemplate
+    void testRecoverFromJMFailover() throws Exception {
+        AdaptiveBatchScheduler scheduler = createScheduler(deserializeJobGraph(serializedJobGraph));
+
+        runInMainThread(scheduler::startScheduling);
+
+        runInMainThread(
+                () -> {
+                    // transition all sources to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID);
+                });
+        runInMainThread(
+                () -> {
+                    // transition all middle tasks to RUNNING state
+                    transitionExecutionsState(scheduler, ExecutionState.INITIALIZING, MIDDLE_ID);
+                    transitionExecutionsState(scheduler, ExecutionState.RUNNING, MIDDLE_ID);
+                });
+        List<ExecutionAttemptID> sourceExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID));
+        List<ExecutionAttemptID> middleExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(MIDDLE_ID));
+
+        Map<IntermediateResultPartitionID, Integer> subpartitionNums = new HashMap<>();
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, scheduler.getExecutionGraph())) {
+
+            IntermediateResultPartition partition =
+                    vertex.getProducedPartitions().values().iterator().next();
+            subpartitionNums.put(partition.getPartitionId(), partition.getNumberOfSubpartitions());
+        }
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(5);
+        runInMainThread(() -> jobEventStore.stop(false));
+
+        // register all produced partitions
+        registerPartitions(scheduler);
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler =
+                createScheduler(deserializeJobGraph(serializedJobGraph));
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        // check source vertices' states were recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph())) {
+            // check state.
+            assertThat(sourceExecutions)
+                    .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+
+            // check partition tracker was rebuilt.
+            JobMasterPartitionTracker partitionTracker =
+                    ((InternalExecutionGraphAccessor) newScheduler.getExecutionGraph())
+                            .getPartitionTracker();
+            List<ResultPartitionID> resultPartitionIds =
+                    vertex.getProducedPartitions().keySet().stream()
+                            .map(
+                                    ((DefaultExecutionGraph) newScheduler.getExecutionGraph())
+                                            ::createResultPartitionId)
+                            .collect(Collectors.toList());
+            for (ResultPartitionID partitionID : resultPartitionIds) {
+                assertThat(partitionTracker.isPartitionTracked(partitionID)).isTrue();
+            }
+
+            // check partitions are recovered
+            IntermediateResultPartition partition =
+                    vertex.getProducedPartitions().values().iterator().next();
+            assertThat(partition.getNumberOfSubpartitions())
+                    .isEqualTo(subpartitionNums.get(partition.getPartitionId()));
+        }
+
+        // check middle vertices' states were not recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(MIDDLE_ID, newScheduler.getExecutionGraph())) {
+            assertThat(middleExecutions)
+                    .doesNotContain(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.DEPLOYING);
+        }
+    }
+
+    // This case will use job graph with the following topology:
+    // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=-1)
+    //
+    // This case will undergo the following stages:
+    // 1. All source tasks are finished, as well as middle task 0, while other middle tasks are
+    // still running.
+    // The middle vertex contains an operator coordinator that does not support batch snapshot.
+    // 2. The partition belonging to source task0 is released because middle task0 has finished.
+    // 3. JM failover.
+    // 4. After the failover, middle task0 and source task0 are expected to be reset. Other source
+    // tasks should be restored to finished and their produced partitions should also be restored.
+    @TestTemplate
+    void testJobVertexUnFinishedAndOperatorCoordinatorNotSupportBatchSnapshot() throws Exception {
+        JobGraph jobGraph = deserializeJobGraph(serializedJobGraph);
+        JobVertex jobVertex = jobGraph.findVertexByID(MIDDLE_ID);
+        jobVertex.addOperatorCoordinator(
+                new SerializedValue<>(
+                        new TestingOperatorCoordinator.Provider(
+                                jobVertex.getOperatorIDs().get(0).getGeneratedOperatorID())));
+        AdaptiveBatchScheduler scheduler =
+                createScheduler(
+                        jobGraph,
+                        Duration.ZERO /* make sure every finished event can be flushed on time. */);
+
+        runInMainThread(scheduler::startScheduling);
+
+        runInMainThread(
+                () -> {
+                    // transition all sources to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID);
+                });
+        runInMainThread(
+                () -> {
+                    // transition the first middle task to finished.
+                    ExecutionVertex firstMiddle =
+                            getExecutionVertex(MIDDLE_ID, 0, scheduler.getExecutionGraph());
+                    AdaptiveBatchSchedulerTest.transitionExecutionsState(
+                            scheduler,
+                            ExecutionState.FINISHED,
+                            Collections.singletonList(firstMiddle.getCurrentExecutionAttempt()),
+                            null);
+                });
+        List<ExecutionAttemptID> sourceExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID));
+        List<ExecutionAttemptID> middleExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(MIDDLE_ID));
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(6);
+        runInMainThread(() -> jobEventStore.stop(false));
+
+        // register partitions; the partition of source task 0 is lost, so source task 0 will have
+        // to be restarted if middle task 0 needs to be restarted.
+        int subtaskIndex = 0;
+        registerPartitions(
+                scheduler,
+                Collections.emptySet(),
+                Collections.singleton(
+                        scheduler
+                                .getExecutionJobVertex(SOURCE_ID)
+                                .getTaskVertices()[subtaskIndex]
+                                .getID()));
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler = createScheduler(jobGraph);
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph())) {
+            // check source task0 was reset.
+            if (vertex.getParallelSubtaskIndex() == subtaskIndex) {
+                assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.DEPLOYING);
+                continue;
+            }
+
+            // check the states of the other source tasks were recovered.
+            assertThat(sourceExecutions)
+                    .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+
+            // check partition tracker was rebuilt.
+            JobMasterPartitionTracker partitionTracker =
+                    ((InternalExecutionGraphAccessor) newScheduler.getExecutionGraph())
+                            .getPartitionTracker();
+            List<ResultPartitionID> resultPartitionIds =
+                    vertex.getProducedPartitions().keySet().stream()
+                            .map(
+                                    ((DefaultExecutionGraph) newScheduler.getExecutionGraph())
+                                            ::createResultPartitionId)
+                            .collect(Collectors.toList());
+            for (ResultPartitionID partitionID : resultPartitionIds) {
+                assertThat(partitionTracker.isPartitionTracked(partitionID)).isTrue();
+            }
+        }
+
+        for (ExecutionVertex vertex :
+                getExecutionVertices(MIDDLE_ID, newScheduler.getExecutionGraph())) {
+            assertThat(middleExecutions)
+                    .doesNotContain(vertex.getCurrentExecutionAttempt().getAttemptId());
+
+            // check middle task0 is CREATED because it is still waiting for source task0 to finish.
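+            // (middle task0 cannot be re-deployed yet: source task0's partition was reported as
+            // lost above, so its data must be re-produced first.)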
+            if (vertex.getParallelSubtaskIndex() == subtaskIndex) {
+                assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.CREATED);
+                continue;
+            }
+
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.DEPLOYING);
+        }
+    }
+
+    // This case will use job graph with the following topology:
+    // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=-1)
+    //
+    // This case will undergo the following stages:
+    // 1. All source tasks are finished.
+    // The source vertex contains an operator coordinator that does not support batch snapshot.
+    // 2. JM failover.
+    // 3. After the failover, all source tasks are expected to be recovered to finished, and their
+    // produced partitions should also be restored.
+    // 4. Transition all middle tasks to running.
+    // 5. Mark the partition consumed by middle task0 as missing.
+    // 6. All source tasks should be restarted.
+    @TestTemplate
+    void testJobVertexFinishedAndOperatorCoordinatorNotSupportBatchSnapshotAndPartitionNotFound()
+            throws Exception {
+        JobGraph jobGraph = deserializeJobGraph(serializedJobGraph);
+        JobVertex jobVertex = jobGraph.findVertexByID(SOURCE_ID);
+        jobVertex.addOperatorCoordinator(
+                new SerializedValue<>(
+                        new TestingOperatorCoordinator.Provider(
+                                jobVertex.getOperatorIDs().get(0).getGeneratedOperatorID())));
+        AdaptiveBatchScheduler scheduler = createScheduler(jobGraph);
+
+        runInMainThread(scheduler::startScheduling);
+
+        runInMainThread(
+                () -> {
+                    // transition all sources to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID);
+                });
+        List<ExecutionAttemptID> sourceExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID));
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(5);
+        runInMainThread(
+                () -> {
+                    jobEventStore.stop(false);
+                });
+
+        // register all produced partitions
+        registerPartitions(scheduler);
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler = createScheduler(jobGraph);
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        // check source vertices' states were recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph())) {
+            // check state.
+            assertThat(sourceExecutions)
+                    .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+
+            // check partition tracker was rebuilt.
+            JobMasterPartitionTracker partitionTracker =
+                    ((InternalExecutionGraphAccessor) newScheduler.getExecutionGraph())
+                            .getPartitionTracker();
+            List<ResultPartitionID> resultPartitionIds =
+                    vertex.getProducedPartitions().keySet().stream()
+                            .map(
+                                    ((DefaultExecutionGraph) newScheduler.getExecutionGraph())
+                                            ::createResultPartitionId)
+                            .collect(Collectors.toList());
+            for (ResultPartitionID partitionID : resultPartitionIds) {
+                assertThat(partitionTracker.isPartitionTracked(partitionID)).isTrue();
+            }
+        }
+
+        runInMainThread(
+                () -> {
+                    // transition all middle tasks to running
+                    transitionExecutionsState(scheduler, ExecutionState.RUNNING, MIDDLE_ID);
+                });
+
+        // trigger partition not found
+        ExecutionVertex firstMiddleTask =
+                getExecutionVertex(MIDDLE_ID, 0, newScheduler.getExecutionGraph());
+        triggerFailedByDataConsumptionException(newScheduler, firstMiddleTask);
+
+        waitUntilExecutionVertexState(
+                getExecutionVertex(SOURCE_ID, 0, newScheduler.getExecutionGraph()),
+                ExecutionState.DEPLOYING,
+                15000L);
+
+        // verify all source tasks were restarted
+        for (int i = 0; i < 5; i++) {
+            assertThat(
+                            getExecutionVertex(SOURCE_ID, i, newScheduler.getExecutionGraph())
+                                    .getExecutionState())
+                    .isNotEqualTo(ExecutionState.FINISHED);
+        }
+    }
+
+    // This case will use job graph with the following topology:
+    // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=-1)
+    //
+    // This case will undergo the following stages:
+    // 1. All source tasks are finished. Source task0 loses its partitions.
+    // 2. JM failover.
+    // 3. After the failover, source task0 is expected to be reset. Other source tasks are
+    // recovered to finished, and their produced partitions should also be restored.
+    @TestTemplate
+    void testRecoverFromJMFailoverAndPartitionsUnavailable() throws Exception {
+        AdaptiveBatchScheduler scheduler = createScheduler(deserializeJobGraph(serializedJobGraph));
+
+        runInMainThread(scheduler::startScheduling);
+
+        runInMainThread(
+                () -> {
+                    // transition all sources to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID);
+                });
+        List<ExecutionAttemptID> sourceExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID));
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(5);
+        runInMainThread(() -> jobEventStore.stop(false));
+
+        int losePartitionsTaskIndex = 0;
+
+        // register partitions; the partition of source task 0 is lost, so source task 0 will have
+        // to be restarted.
+        registerPartitions(
+                scheduler,
+                Collections.emptySet(),
+                Collections.singleton(
+                        getExecutionVertex(
+                                        SOURCE_ID,
+                                        losePartitionsTaskIndex,
+                                        scheduler.getExecutionGraph())
+                                .getID()));
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler =
+                createScheduler(deserializeJobGraph(serializedJobGraph));
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        // check source task0 is reset and the other source tasks are finished
+        List<ExecutionVertex> sourceTasks =
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph());
+        for (int i = 0; i < sourceTasks.size(); i++) {
+            ExecutionVertex vertex = sourceTasks.get(i);
+            if (i == losePartitionsTaskIndex) {
+                assertThat(sourceExecutions)
+                        .doesNotContain(vertex.getCurrentExecutionAttempt().getAttemptId());
+                assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.DEPLOYING);
+            } else {
+                assertThat(sourceExecutions)
+                        .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+                assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+            }
+        }
+    }
+
+    // This case will use job graph with the following topology:
+    // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=2, decided at runtime)
+    @TestTemplate
+    void testRecoverDecidedParallelismFromTheSameJobGraphInstance() throws Exception {
+        JobGraph jobGraph = deserializeJobGraph(serializedJobGraph);
+
+        AdaptiveBatchScheduler scheduler = createScheduler(jobGraph);
+
+        runInMainThread(scheduler::startScheduling);
+
+        runInMainThread(
+                () -> {
+                    // transition all sources to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID);
+                });
+        runInMainThread(
+                () -> {
+                    // transition all middle tasks to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, MIDDLE_ID);
+                });
+        runInMainThread(
+                () -> {
+                    // transition all sinks to finished.
+                    transitionExecutionsState(scheduler, ExecutionState.FINISHED, SINK_ID);
+                });
+
+        List<ExecutionAttemptID> sourceExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID));
+        List<ExecutionAttemptID> middleExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(MIDDLE_ID));
+        List<ExecutionAttemptID> sinkExecutions =
+                getCurrentAttemptIds(scheduler.getExecutionJobVertex(SINK_ID));
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(12);
+        runInMainThread(() -> jobEventStore.stop(false));
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler = createScheduler(jobGraph);
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        // check source vertices' states were recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph())) {
+            assertThat(sourceExecutions)
+                    .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+        }
+        // check middle vertices' states were recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(MIDDLE_ID, newScheduler.getExecutionGraph())) {
+            assertThat(middleExecutions)
+                    .contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+        }
+
+        // check sink's parallelism was recovered.
+        assertThat(newScheduler.getExecutionJobVertex(SINK_ID).getParallelism())
+                .isEqualTo(DECIDED_SINK_PARALLELISM);
+        // check sink vertices' states were recovered.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SINK_ID, newScheduler.getExecutionGraph())) {
+            assertThat(sinkExecutions).contains(vertex.getCurrentExecutionAttempt().getAttemptId());
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.FINISHED);
+        }
+    }
+
+    // This case will use job graph with the following topology:
+    // Source (p=5) -- POINTWISE --> Middle (p=5) -- ALLTOALL --> Sink (p=-1)
+    //
+    // This test case verifies that sourceCoordinator's split assignments are restored after a JM
+    // failover, unless sources are restarted (triggered by 'partition not found' exceptions),
+    // to prevent any loss of assigned splits.
+    @TestTemplate
+    void testPartitionNotFoundTwiceAfterJMFailover() throws Exception {
+        AdaptiveBatchScheduler scheduler = createScheduler(deserializeJobGraph(serializedJobGraph));
+
+        runInMainThread(scheduler::startScheduling);
+
+        // assign all splits
+        runInMainThread(
+                () -> {
+                    final SourceCoordinator<MockSourceSplit, Set<MockSourceSplit>>
+                            sourceCoordinator =
+                                    getInternalSourceCoordinator(
+                                            scheduler.getExecutionGraph(), SOURCE_ID);
+                    assignSplitsForAllSubTask(
+                            sourceCoordinator,
+                            getCurrentAttemptIds(scheduler.getExecutionJobVertex(SOURCE_ID)));
+                    // no unassigned splits now.
+                    checkUnassignedSplits(sourceCoordinator, 0);
+                });
+
+        // transition all sources to finished.
+        runInMainThread(
+                () -> transitionExecutionsState(scheduler, ExecutionState.FINISHED, SOURCE_ID));
+
+        waitUntilWriteExecutionVertexFinishedEventPersisted(5);
+        runInMainThread(
+                () -> {
+                    jobEventStore.stop(false);
+                });
+
+        // register all produced partitions
+        registerPartitions(scheduler);
+
+        // start a new scheduler and try to recover.
+        AdaptiveBatchScheduler newScheduler =
+                createScheduler(deserializeJobGraph(serializedJobGraph));
+        startSchedulingAndWaitRecoverFinish(newScheduler);
+
+        final SourceCoordinator<MockSourceSplit, Set<MockSourceSplit>> sourceCoordinator =
+                getInternalSourceCoordinator(newScheduler.getExecutionGraph(), SOURCE_ID);
+        // no unassigned splits now.
+        runInMainThread(() -> checkUnassignedSplits(sourceCoordinator, 0));
+
+        // =============================
+        // FIRST TIME
+        // =============================
+        // make subtask 0 of the first middle vertex fail with a data consumption exception.
+        ExecutionVertex firstMiddle0 =
+                getExecutionVertex(MIDDLE_ID, 0, newScheduler.getExecutionGraph());
+        triggerFailedByDataConsumptionException(newScheduler, firstMiddle0);
+        // wait until reset done.
+        waitUntilExecutionVertexState(firstMiddle0, ExecutionState.CREATED, 15000L);
+        // Check whether the splits have been returned.
+        runInMainThread(() -> checkUnassignedSplits(sourceCoordinator, 2));
+
+        // =============================
+        // SECOND TIME
+        // =============================
+        // assign splits to the restarted source vertex.
+        runInMainThread(
+                () -> {
+                    assignSplits(
+                            sourceCoordinator,
+                            getExecutionVertex(SOURCE_ID, 0, newScheduler.getExecutionGraph())
+                                    .getCurrentExecutionAttempt()
+                                    .getAttemptId());
+                    // no unassigned splits now.
+                    checkUnassignedSplits(sourceCoordinator, 0);
+                });
+
+        // transition all sources to finished.
+        runInMainThread(
+                () -> transitionExecutionsState(newScheduler, ExecutionState.FINISHED, SOURCE_ID));
+
+        // make subtask 1 of the first middle vertex fail with a data consumption exception.
+        ExecutionVertex firstMiddle1 =
+                getExecutionVertex(MIDDLE_ID, 1, newScheduler.getExecutionGraph());
+        triggerFailedByDataConsumptionException(newScheduler, firstMiddle1);
+        // wait until reset done.
+        waitUntilExecutionVertexState(firstMiddle1, ExecutionState.CREATED, 15000L);
+
+        // Check whether the splits have been returned.
+        runInMainThread(() -> checkUnassignedSplits(sourceCoordinator, 2));
+    }
+
+    @TestTemplate
+    void testReplayEventFailed() throws Exception {
+        final JobEventStore failingJobEventStore =
+                new JobEventStore() {
+                    @Override
+                    public void start() {}
+
+                    @Override
+                    public void stop(boolean clear) {}
+
+                    @Override
+                    public void writeEvent(JobEvent event, boolean cutBlock) {}
+
+                    @Override
+                    public JobEvent readEvent() throws Exception {
+                        throw new Exception();
+                    }
+
+                    @Override
+                    public boolean isEmpty() {
+                        return false;
+                    }
+                };
+
+        final ManuallyTriggeredScheduledExecutor taskRestartExecutor =
+                new ManuallyTriggeredScheduledExecutor();
+        delayedExecutor = taskRestartExecutor;
+
+        final AdaptiveBatchScheduler newScheduler =
+                createScheduler(
+                        deserializeJobGraph(serializedJobGraph),
+                        failingJobEventStore,
+                        BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM
+                                .defaultValue(),
+                        BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE.defaultValue());
+        runInMainThread(newScheduler::startScheduling);
+
+        // trigger scheduled restarting and drain the main thread actions
+        taskRestartExecutor.triggerScheduledTasks();
+        runInMainThread(() -> {});
+
+        assertThat(
+                        ExceptionUtils.findThrowableWithMessage(
+                                newScheduler.getExecutionGraph().getFailureCause(),
+                                "Recover failed from JM failover"))
+                .isPresent();
+
+        // source should be scheduled.
+        for (ExecutionVertex vertex :
+                getExecutionVertices(SOURCE_ID, newScheduler.getExecutionGraph())) {
+            assertThat(vertex.getCurrentExecutionAttempt().getAttemptNumber()).isEqualTo(1);
+            assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.DEPLOYING);
+        }
+    }
+
+    private void waitUntilWriteExecutionVertexFinishedEventPersisted(int count) throws Exception {
+        CommonTestUtils.waitUntilCondition(
+                () ->
+                        new ArrayList<>(persistedJobEventList)
+                                        .stream()
+                                        .filter(
+                                                jobEvent ->
+                                                        jobEvent
+                                                                instanceof
+                                                                ExecutionVertexFinishedEvent)
+                                        .count()
+                                == count);
+    }
+
+    private void triggerFailedByDataConsumptionException(
+            SchedulerBase scheduler, ExecutionVertex executionVertex) {
+        // make the execution vertex fail with a data consumption exception.
+        runInMainThread(
+                () -> {
+                    // pick one of its consumed IntermediateResultPartitions.
+                    IntermediateResultPartitionID partitionId =
+                            getConsumedResultPartitions(
+                                            scheduler.getExecutionGraph().getSchedulingTopology(),
+                                            executionVertex.getID())
+                                    .get(0);
+                    // trigger the failure.
+                    AdaptiveBatchSchedulerTest.transitionExecutionsState(
+                            scheduler,
+                            ExecutionState.FAILED,
+                            Collections.singletonList(executionVertex.getCurrentExecutionAttempt()),
+                            new PartitionNotFoundException(
+                                    ((DefaultExecutionGraph) scheduler.getExecutionGraph())
+                                            .createResultPartitionId(partitionId)));
+                });
+    }
+
+    private void assignSplits(
+            SourceCoordinator<MockSourceSplit, Set<MockSourceSplit>> sourceCoordinator,
+            ExecutionAttemptID attemptId) {
+        int subtask = attemptId.getSubtaskIndex();
+        int attemptNumber = attemptId.getAttemptNumber();
+        sourceCoordinator.executionAttemptReady(
+                subtask,
+                attemptNumber,
+                receivingTasks.createGatewayForSubtask(subtask, attemptNumber));
+
+        // Each source subtask is assigned 2 splits.
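+        // (the ReaderRegistrationEvent below is what makes the enumerator hand out the splits;
+        // 10 splits over 5 subtasks yields the 2 splits per subtask mentioned above.)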
+ sourceCoordinator.handleEventFromOperator( + subtask, + attemptNumber, + new ReaderRegistrationEvent(subtask, "location_" + subtask)); + } + + private void assignSplitsForAllSubTask( + SourceCoordinator sourceCoordinator, List attemptIds) { + attemptIds.forEach(attemptId -> assignSplits(sourceCoordinator, attemptId)); + } + + private void checkUnassignedSplits(SourceCoordinator sourceCoordinator, int expected) { + final MockSplitEnumerator newSplitEnumerator = + (MockSplitEnumerator) sourceCoordinator.getEnumerator(); + + // check splits were returned. + runInCoordinatorThread( + sourceCoordinator, + () -> assertThat(newSplitEnumerator.getUnassignedSplits()).hasSize(expected)); + } + + private void runInCoordinatorThread( + SourceCoordinator sourceCoordinator, Runnable runnable) { + try { + sourceCoordinator.getCoordinatorExecutor().submit(runnable).get(); + } catch (Exception e) { + fail("Test failed due to " + e); + } + } + + private void runInMainThread(@Nonnull ThrowingRunnable throwingRunnable) { + mainThreadExecutor.execute(throwingRunnable); + } + + private void registerPartitions(AdaptiveBatchScheduler scheduler) { + registerPartitions(scheduler, Collections.emptySet(), Collections.emptySet()); + } + + private void registerPartitions( + AdaptiveBatchScheduler scheduler, + Set unavailablePartitionsJobVertices, + Set unavailablePartitionsExecutionVertices) { + // register partitions + ExecutionGraph executionGraph = scheduler.getExecutionGraph(); + + List list = + executionGraph.getAllIntermediateResults().values().stream() + .flatMap(result -> Arrays.stream(result.getPartitions())) + .filter( + partition -> { + ExecutionVertex producer = + executionGraph + .getResultPartitionOrThrow( + partition.getPartitionId()) + .getProducer(); + return !unavailablePartitionsJobVertices.contains( + producer.getJobvertexId()) + && !unavailablePartitionsExecutionVertices.contains( + producer.getID()) + && producer.getExecutionState() + == ExecutionState.FINISHED; + }) + .map( + partition -> { + BlockingResultInfo resultInfo = + scheduler.getBlockingResultInfo( + partition.getIntermediateResult().getId()); + IntermediateResultPartitionID partitionId = + partition.getPartitionId(); + final Execution producer = + executionGraph + .getResultPartitionOrThrow(partitionId) + .getProducer() + .getPartitionProducer(); + + ResultPartitionID resultPartitionID = + new ResultPartitionID( + partitionId, producer.getAttemptId()); + + DefaultShuffleMetrics metrics = + new DefaultShuffleMetrics( + resultInfo == null + ? 
new ResultPartitionBytes(new long[0])
+                                                    : new ResultPartitionBytes(
+                                                            new long
+                                                                    [resultInfo
+                                                                            .getNumSubpartitions(
+                                                                                    0)]));
+                            return new TestPartitionWithMetrics(resultPartitionID, metrics);
+                        })
+                        .collect(Collectors.toList());
+
+        allPartitionWithMetrics.addAll(list);
+    }
+
+    private void startSchedulingAndWaitRecoverFinish(AdaptiveBatchScheduler scheduler)
+            throws Exception {
+        runInMainThread(scheduler::startScheduling);
+
+        // wait recover start
+        CommonTestUtils.waitUntilCondition(scheduler::isRecovering);
+
+        // wait recover finish
+        CommonTestUtils.waitUntilCondition(() -> !scheduler.isRecovering());
+    }
+
+    private static SourceCoordinator<MockSourceSplit, Set<MockSourceSplit>>
+            getInternalSourceCoordinator(
+                    final ExecutionGraph executionGraph, final JobVertexID sourceID)
+                    throws Exception {
+        ExecutionJobVertex sourceJobVertex = executionGraph.getJobVertex(sourceID);
+        OperatorCoordinatorHolder operatorCoordinatorHolder =
+                new ArrayList<>(sourceJobVertex.getOperatorCoordinators()).get(0);
+        final RecreateOnResetOperatorCoordinator coordinator =
+                (RecreateOnResetOperatorCoordinator) operatorCoordinatorHolder.coordinator();
+        return (SourceCoordinator<MockSourceSplit, Set<MockSourceSplit>>)
+                coordinator.getInternalCoordinator();
+    }
+
+    private static List<IntermediateResultPartitionID> getConsumedResultPartitions(
+            final SchedulingTopology schedulingTopology,
+            final ExecutionVertexID executionVertexId) {
+        return StreamSupport.stream(
+                        schedulingTopology
+                                .getVertex(executionVertexId)
+                                .getConsumedResults()
+                                .spliterator(),
+                        false)
+                .map(SchedulingResultPartition::getId)
+                .collect(Collectors.toList());
+    }
+
+    /** Transition the state of all executions in the job vertex. */
+    public static void transitionExecutionsState(
+            final SchedulerBase scheduler,
+            final ExecutionState state,
+            final JobVertexID jobVertexID) {
+        AdaptiveBatchSchedulerTest.transitionExecutionsState(
+                scheduler, state, scheduler.getExecutionJobVertex(jobVertexID).getJobVertex());
+    }
+
+    /**
+     * Create job vertices and connect them as the following JobGraph:
+     *
+     * <pre>
+     *  	source -|-> middle -|-> sink
+     * </pre>
+     *
+     * <p>Parallelism of source and middle is 5.
+     *
+     * <p>Edge (source --> middle) is BLOCKING and POINTWISE. Edge (middle --> sink) is BLOCKING
+     * and ALL_TO_ALL.
+     *
+     * <p>
Source has an operator coordinator. + */ + private JobGraph createDefaultJobGraph() throws IOException { + List jobVertices = new ArrayList<>(); + + final JobVertex source = new JobVertex("source", SOURCE_ID); + source.setInvokableClass(NoOpInvokable.class); + source.addOperatorCoordinator(new SerializedValue<>(provider)); + source.setParallelism(SOURCE_PARALLELISM); + jobVertices.add(source); + + final JobVertex middle = new JobVertex("middle", MIDDLE_ID); + middle.setInvokableClass(NoOpInvokable.class); + middle.setParallelism(MIDDLE_PARALLELISM); + jobVertices.add(middle); + + final JobVertex sink = new JobVertex("sink", SINK_ID); + sink.setInvokableClass(NoOpInvokable.class); + jobVertices.add(sink); + + middle.connectNewDataSetAsInput( + source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); + sink.connectNewDataSetAsInput( + middle, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); + + return new JobGraph(JOB_ID, "TestJob", jobVertices.toArray(new JobVertex[0])); + } + + private static ExecutionVertex getExecutionVertex( + final JobVertexID jobVertexId, int subtask, final ExecutionGraph executionGraph) { + return getExecutionVertices(jobVertexId, executionGraph).get(subtask); + } + + private static List getExecutionVertices( + final JobVertexID jobVertexId, final ExecutionGraph executionGraph) { + checkState(executionGraph.getJobVertex(jobVertexId).isInitialized()); + return Arrays.asList(executionGraph.getJobVertex(jobVertexId).getTaskVertices()); + } + + private static List getCurrentAttemptIds( + final ExecutionJobVertex jobVertex) { + checkState(jobVertex.isInitialized()); + return Arrays.stream(jobVertex.getTaskVertices()) + .map(executionVertex -> executionVertex.getCurrentExecutionAttempt().getAttemptId()) + .collect(Collectors.toList()); + } + + private AdaptiveBatchScheduler createScheduler(final JobGraph jobGraph) throws Exception { + return createScheduler( + jobGraph, + jobEventStore, + BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM.defaultValue(), + BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE.defaultValue()); + } + + private AdaptiveBatchScheduler createScheduler( + final JobGraph jobGraph, final Duration jobRecoverySnapshotMinPause) throws Exception { + return createScheduler( + jobGraph, + jobEventStore, + BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM.defaultValue(), + jobRecoverySnapshotMinPause); + } + + private AdaptiveBatchScheduler createScheduler( + final JobGraph jobGraph, + final JobEventStore jobEventStore, + int defaultMaxParallelism, + Duration jobRecoverySnapshotMinPause) + throws Exception { + + final ShuffleMaster shuffleMaster = + new NettyShuffleMaster(new Configuration()); + TestingJobMasterGateway jobMasterGateway = + new TestingJobMasterGatewayBuilder() + .setGetPartitionWithMetricsFunction( + (timeout, set) -> + CompletableFuture.completedFuture(allPartitionWithMetrics)) + .build(); + shuffleMaster.registerJob(new JobShuffleContextImpl(jobGraph.getJobID(), jobMasterGateway)); + final JobMasterPartitionTracker partitionTracker = + new JobMasterPartitionTrackerImpl( + jobGraph.getJobID(), shuffleMaster, ignored -> Optional.empty()); + + Configuration jobMasterConfig = new Configuration(); + jobMasterConfig.set( + BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE, jobRecoverySnapshotMinPause); + jobMasterConfig.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true); + jobMasterConfig.set( + BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT, + 
previousWorkerRecoveryTimeout); + + DefaultSchedulerBuilder schedulerBuilder = + new DefaultSchedulerBuilder( + jobGraph, + mainThreadExecutor.getMainThreadExecutor(), + EXECUTOR_RESOURCE.getExecutor()) + .setRestartBackoffTimeStrategy( + new FixedDelayRestartBackoffTimeStrategy + .FixedDelayRestartBackoffTimeStrategyFactory(10, 0) + .create()) + .setShuffleMaster(shuffleMaster) + .setJobMasterConfiguration(jobMasterConfig) + .setPartitionTracker(partitionTracker) + .setDelayExecutor(delayedExecutor) + .setJobRecoveryHandler( + new DefaultBatchJobRecoveryHandler( + new JobEventManager(jobEventStore), jobMasterConfig)) + .setVertexParallelismAndInputInfosDecider( + createCustomParallelismDecider(DECIDED_SINK_PARALLELISM)) + .setDefaultMaxParallelism(defaultMaxParallelism); + + return schedulerBuilder.buildAdaptiveBatchJobScheduler(enableSpeculativeExecution); + } + + private byte[] serializeJobGraph(final JobGraph jobGraph) throws IOException { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + ObjectOutputStream oss = new ObjectOutputStream(byteArrayOutputStream); + oss.writeObject(jobGraph); + return byteArrayOutputStream.toByteArray(); + } + + private JobGraph deserializeJobGraph(final byte[] serializedJobGraph) throws Exception { + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(serializedJobGraph); + ObjectInputStream ois = new ObjectInputStream(byteArrayInputStream); + return (JobGraph) ois.readObject(); + } + + private static class TestingFileSystemJobEventStore extends FileSystemJobEventStore { + + private final List persistedJobEventList; + + public TestingFileSystemJobEventStore( + Path workingDir, Configuration configuration, List persistedJobEventList) + throws IOException { + super(workingDir, configuration); + this.persistedJobEventList = persistedJobEventList; + } + + @Override + protected void writeEventRunnable(JobEvent event, boolean cutBlock) { + super.writeEventRunnable(event, cutBlock); + persistedJobEventList.add(event); + } + } + + private static class TestPartitionWithMetrics implements PartitionWithMetrics { + + private final ResultPartitionID resultPartitionID; + private final ShuffleMetrics metrics; + + public TestPartitionWithMetrics( + ResultPartitionID resultPartitionID, ShuffleMetrics metrics) { + this.resultPartitionID = resultPartitionID; + this.metrics = metrics; + } + + @Override + public ShuffleMetrics getPartitionMetrics() { + return metrics; + } + + @Override + public ShuffleDescriptor getPartition() { + return new ShuffleDescriptor() { + @Override + public ResultPartitionID getResultPartitionID() { + return resultPartitionID; + } + + @Override + public Optional storesLocalResourcesOn() { + return Optional.empty(); + } + }; + } + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/JMFailoverITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/JMFailoverITCase.java new file mode 100644 index 0000000000000..d38cd8a4db980 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/JMFailoverITCase.java @@ -0,0 +1,811 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.scheduling; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.configuration.BatchExecutionOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.configuration.ExecutionOptions; +import org.apache.flink.configuration.HighAvailabilityOptions; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.RestartStrategyOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.connector.file.src.reader.TextLineInputFormat; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.AccessExecutionGraph; +import org.apache.flink.runtime.highavailability.HighAvailabilityServices; +import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServicesWithLeadershipControl; +import org.apache.flink.runtime.io.network.partition.PartitionedFile; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobType; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobmaster.JobResult; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniCluster; +import org.apache.flink.runtime.minicluster.TestingMiniClusterConfiguration; +import org.apache.flink.runtime.testutils.CommonTestUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.graph.GlobalStreamExchangeMode; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamGraph; +import org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.testutils.TestingUtils; +import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.NetUtils; +import 
org.apache.flink.util.function.SupplierWithException; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.io.TempDir; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.attribute.BasicFileAttributes; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.util.Preconditions.checkState; +import static org.assertj.core.api.Assertions.assertThat; + +/** ITCase for JM failover. */ +class JMFailoverITCase { + + // to speed up recovery + private final Duration previousWorkerRecoveryTimeout = Duration.ofSeconds(3); + + @RegisterExtension + static final TestExecutorExtension EXECUTOR_EXTENSION = + TestingUtils.defaultExecutorExtension(); + + private static final int DEFAULT_MAX_PARALLELISM = 4; + private static final int SOURCE_PARALLELISM = 8; + + private static final int NUMBER_KEYS = 10000; + private static final int NUMBER_OF_EACH_KEY = 4; + + private EmbeddedHaServicesWithLeadershipControl highAvailabilityServices; + + private String methodName; + + @TempDir java.nio.file.Path temporaryFolder; + + protected int numTaskManagers = 4; + + protected int numSlotsPerTaskManager = 4; + + protected Configuration flinkConfiguration = new Configuration(); + + protected MiniCluster flinkCluster; + + protected Supplier highAvailabilityServicesSupplier = null; + + @BeforeEach + void before(TestInfo testInfo) throws Exception { + flinkConfiguration = new Configuration(); + SourceTail.clear(); + StubMapFunction.clear(); + StubRecordSink.clear(); + testInfo.getTestMethod().ifPresent(method -> methodName = method.getName()); + } + + @AfterEach + void after() { + Throwable exception = null; + + try { + if (flinkCluster != null) { + flinkCluster.close(); + } + } catch (Throwable throwable) { + exception = throwable; + } + + if (exception != null) { + ExceptionUtils.rethrow(exception); + } + } + + public void setup() throws Exception { + SourceTail.clear(); + StubMapFunction.clear(); + StubRecordSink.clear(); + } + + @Test + void testRecoverFromJMFailover() throws Exception { + JobGraph jobGraph = prepareEnvAndGetJobGraph(); + + // blocking all sink + StubRecordSink.blockSubTasks(0, 1, 2, 3); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until sink is running. + tryWaitUntilCondition(() -> StubRecordSink.attemptIds.size() > 0); + + triggerJMFailover(jobId); + + // unblock all sink. + StubRecordSink.unblockSubTasks(0, 1, 2, 3); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + @Test + void testSourceNotAllFinished() throws Exception { + JobGraph jobGraph = prepareEnvAndGetJobGraph(); + + // blocking source 0 + SourceTail.blockSubTasks(0); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until source is running. 
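+        // SourceTail#setup() records an entry per subtask index in attemptIds, so a size
+        // of SOURCE_PARALLELISM means every source subtask has been deployed at least once.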
+ tryWaitUntilCondition(() -> SourceTail.attemptIds.size() == SOURCE_PARALLELISM); + + JobVertex source = jobGraph.getVerticesSortedTopologicallyFromSources().get(0); + while (true) { + AccessExecutionGraph executionGraph = flinkCluster.getExecutionGraph(jobId).get(); + long finishedTasks = + Arrays.stream(executionGraph.getJobVertex(source.getID()).getTaskVertices()) + .filter(task -> task.getExecutionState() == ExecutionState.FINISHED) + .count(); + if (finishedTasks == SOURCE_PARALLELISM - 1) { + break; + } + + Thread.sleep(100L); + } + + triggerJMFailover(jobId); + + // unblock source 0. + SourceTail.unblockSubTasks(0); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + @Test + void testTaskExecutorNotRegisterOnTime() throws Exception { + Configuration configuration = new Configuration(); + configuration.set( + BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT, Duration.ZERO); + JobGraph jobGraph = prepareEnvAndGetJobGraph(configuration); + + // blocking all sink + StubRecordSink.blockSubTasks(0, 1, 2, 3); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until sink is running. + tryWaitUntilCondition(() -> StubRecordSink.attemptIds.size() > 0); + + triggerJMFailover(jobId); + + // unblock all sink. + StubRecordSink.unblockSubTasks(0, 1, 2, 3); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + @Test + void testPartitionNotFoundTwice() throws Exception { + JobGraph jobGraph = prepareEnvAndGetJobGraph(); + + // blocking map 0 and map 1. + StubMapFunction.blockSubTasks(0, 1); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until map deploying, which indicates all source finished. + tryWaitUntilCondition(() -> StubMapFunction.attemptIds.size() > 0); + + triggerJMFailover(jobId); + + // trigger partition not found. + releaseResultPartitionOfSource(); + + // map 0 unblock. + StubMapFunction.unblockSubTasks(0); + + // wait until map 0 restart, which indicates all source finished again. + tryWaitUntilCondition(() -> StubMapFunction.attemptIds.get(0) == 1); + + // trigger partition not found. + releaseResultPartitionOfSource(); + + // map 1 unblock. + StubMapFunction.unblockSubTasks(1); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + @Test + void testPartitionNotFoundAndOperatorCoordinatorNotSupportBatchSnapshot() throws Exception { + JobGraph jobGraph = prepareEnvAndGetJobGraph(false); + + // blocking all map task + StubMapFunction2.blockSubTasks(0, 1, 2, 3); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until map deploying, which indicates all source finished. + tryWaitUntilCondition(() -> StubMapFunction2.attemptIds.size() > 0); + + triggerJMFailover(jobId); + + // trigger partition not found. + releaseResultPartitionOfSource(); + + // map tasks unblock. + StubMapFunction2.unblockSubTasks(0, 1, 2, 3); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + @Test + void testPartitionNotFoundAndOperatorCoordinatorSupportBatchSnapshot() throws Exception { + JobGraph jobGraph = prepareEnvAndGetJobGraph(); + + // blocking map 0. 
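+        // Blocking is cooperative: blockSubTasks() only flips a per-subtask flag; the
+        // subtask itself waits in setup() until unblockSubTasks() clears the flag.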
+ StubMapFunction.blockSubTasks(0); + + JobID jobId = flinkCluster.submitJob(jobGraph).get().getJobID(); + + // wait until map deploying, which indicates all source finished. + tryWaitUntilCondition(() -> StubMapFunction.attemptIds.size() > 0); + + triggerJMFailover(jobId); + + // trigger partition not found. + releaseResultPartitionOfSource(); + + // map 0 unblock. + StubMapFunction.unblockSubTasks(0); + + JobResult jobResult = flinkCluster.requestJobResult(jobId).get(); + assertThat(jobResult.getSerializedThrowable()).isEmpty(); + + checkCountResults(); + } + + private JobGraph prepareEnvAndGetJobGraph() throws Exception { + Configuration configuration = new Configuration(); + configuration.set( + BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT, + previousWorkerRecoveryTimeout); + return prepareEnvAndGetJobGraph(configuration, true); + } + + private JobGraph prepareEnvAndGetJobGraph(Configuration config) throws Exception { + return prepareEnvAndGetJobGraph(config, true); + } + + private JobGraph prepareEnvAndGetJobGraph(boolean operatorCoordinatorsSupportsBatchSnapshot) + throws Exception { + Configuration configuration = new Configuration(); + configuration.set( + BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT, + previousWorkerRecoveryTimeout); + return prepareEnvAndGetJobGraph(configuration, operatorCoordinatorsSupportsBatchSnapshot); + } + + private JobGraph prepareEnvAndGetJobGraph( + Configuration config, boolean operatorCoordinatorsSupportsBatchSnapshot) + throws Exception { + flinkCluster = + TestingMiniCluster.newBuilder(getMiniClusterConfiguration(config)) + .setHighAvailabilityServicesSupplier(highAvailabilityServicesSupplier) + .build(); + flinkCluster.start(); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(-1); + env.setRuntimeMode(RuntimeExecutionMode.BATCH); + + return operatorCoordinatorsSupportsBatchSnapshot + ? createJobGraph(env, methodName) + : createJobGraphWithUnsupportedBatchSnapshotOperatorCoordinator(env, methodName); + } + + private TestingMiniClusterConfiguration getMiniClusterConfiguration(Configuration config) + throws IOException { + // flink basic configuration. + NetUtils.Port jobManagerRpcPort = NetUtils.getAvailablePort(); + flinkConfiguration.set(ExecutionOptions.RUNTIME_MODE, RuntimeExecutionMode.BATCH); + flinkConfiguration.set(JobManagerOptions.PORT, jobManagerRpcPort.getPort()); + flinkConfiguration.set(JobManagerOptions.SLOT_REQUEST_TIMEOUT, 5000L); + flinkConfiguration.set(RestOptions.BIND_PORT, "0"); + flinkConfiguration.set(TaskManagerOptions.TOTAL_PROCESS_MEMORY, MemorySize.parse("1g")); + flinkConfiguration.set(TaskManagerOptions.NETWORK_MEMORY_FRACTION, 0.4F); + + // adaptive batch job scheduler config. + flinkConfiguration.set( + JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.AdaptiveBatch); + flinkConfiguration.set( + BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MAX_PARALLELISM, + DEFAULT_MAX_PARALLELISM); + flinkConfiguration.set( + BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK, + MemorySize.parse("256K")); + + // enable jm failover. + flinkConfiguration.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true); + flinkConfiguration.set( + BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE, Duration.ZERO); + + // region failover config. 
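+        // With region failover, an induced PartitionNotFoundException restarts only the
+        // affected pipelined region, retried under the fixed-delay strategy set below.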
+ flinkConfiguration.set(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, "region"); + flinkConfiguration.set(RestartStrategyOptions.RESTART_STRATEGY, "fixed-delay"); + flinkConfiguration.set(RestartStrategyOptions.RESTART_STRATEGY_FIXED_DELAY_ATTEMPTS, 10); + + // ha config, which helps to trigger jm failover. + flinkConfiguration.set(HighAvailabilityOptions.HA_STORAGE_PATH, temporaryFolder.toString()); + highAvailabilityServices = + new EmbeddedHaServicesWithLeadershipControl(EXECUTOR_EXTENSION.getExecutor()); + highAvailabilityServicesSupplier = () -> highAvailabilityServices; + + // shuffle dir, to help trigger partitionNotFoundException + flinkConfiguration.set(CoreOptions.TMP_DIRS, temporaryFolder.toString()); + + // add user defined config + flinkConfiguration.addAll(config); + + return TestingMiniClusterConfiguration.newBuilder() + .setConfiguration(flinkConfiguration) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build(); + } + + private void triggerJMFailover(JobID jobId) throws Exception { + highAvailabilityServices.revokeJobMasterLeadership(jobId).get(); + highAvailabilityServices.grantJobMasterLeadership(jobId); + } + + private static void checkCountResults() { + Map countResults = StubRecordSink.countResults; + assertThat(countResults.size()).isEqualTo(NUMBER_KEYS); + + Map expectedResult = + IntStream.range(0, NUMBER_KEYS) + .boxed() + .collect(Collectors.toMap(Function.identity(), i -> NUMBER_OF_EACH_KEY)); + assertThat(countResults).isEqualTo(expectedResult); + } + + private void releaseResultPartitionOfSource() { + deleteOldestFileInShuffleNettyDirectory( + new File(flinkConfiguration.get(CoreOptions.TMP_DIRS))); + } + + private JobGraph createJobGraph(StreamExecutionEnvironment env, String jobName) { + TupleTypeInfo> typeInfo = + new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); + + env.fromSequence(0, NUMBER_KEYS * NUMBER_OF_EACH_KEY - 1) + .setParallelism(SOURCE_PARALLELISM) + .slotSharingGroup("group1") + .transform("SourceTail", TypeInformation.of(Long.class), new SourceTail()) + .setParallelism(SOURCE_PARALLELISM) + .slotSharingGroup("group1") + .transform("Map", typeInfo, new StubMapFunction()) + .slotSharingGroup("group2") + .keyBy(tuple2 -> tuple2.f0) + .sum(1) + .slotSharingGroup("group3") + .transform("Sink", TypeInformation.of(Void.class), new StubRecordSink()) + .slotSharingGroup("group4"); + + StreamGraph streamGraph = env.getStreamGraph(); + streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobType(JobType.BATCH); + streamGraph.setJobName(jobName); + return StreamingJobGraphGenerator.createJobGraph(streamGraph); + } + + private JobGraph createJobGraphWithUnsupportedBatchSnapshotOperatorCoordinator( + StreamExecutionEnvironment env, String jobName) throws Exception { + + TupleTypeInfo> typeInfo = + new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); + + File file = new File(temporaryFolder.getParent().toFile(), "data.tmp-" + UUID.randomUUID()); + prepareTestData(file); + + FileSource source = + FileSource.forRecordStreamFormat( + new TextLineInputFormat(), new Path(file.getPath())) + .build(); + + env.fromSource(source, WatermarkStrategy.noWatermarks(), "source") + .setParallelism(SOURCE_PARALLELISM) + .slotSharingGroup("group1") + .transform("Map", typeInfo, new StubMapFunction2()) + .slotSharingGroup("group2") + .keyBy(tuple2 -> tuple2.f0) + .sum(1) + .slotSharingGroup("group3") + .transform("Sink", 
TypeInformation.of(Void.class), new StubRecordSink()) + .slotSharingGroup("group4"); + + StreamGraph streamGraph = env.getStreamGraph(); + streamGraph.setGlobalStreamExchangeMode(GlobalStreamExchangeMode.ALL_EDGES_BLOCKING); + streamGraph.setJobType(JobType.BATCH); + streamGraph.setJobName(jobName); + return StreamingJobGraphGenerator.createJobGraph(streamGraph); + } + + private static void setSubtaskBlocked( + List indices, boolean block, Map subtaskBlocked) { + indices.forEach(index -> subtaskBlocked.put(index, block)); + } + + /** + * A stub which helps to: + * + *
<p>1. Get source tasks' information, such as {@link ResultPartitionID}.
+ *
+ *
<p>2. Manually control the execution of source tasks; helps to block and unblock the
+ * execution of source tasks.
+ *
+ *
<p>
This operator should be chained with source operator. + */ + private static class SourceTail extends AbstractStreamOperator + implements OneInputStreamOperator { + + public static Map subtaskBlocked = new ConcurrentHashMap<>(); + public static Map resultPartitions = new ConcurrentHashMap<>(); + public static Map attemptIds = new ConcurrentHashMap<>(); + + public SourceTail() { + super(); + // chain with source. + setChainingStrategy(ChainingStrategy.ALWAYS); + } + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output> output) { + super.setup(containingTask, config, output); + + int subIdx = getRuntimeContext().getIndexOfThisSubtask(); + + // attempt id ++ + attemptIds.compute( + subIdx, + (ignored, value) -> { + if (value == null) { + value = 0; + } else { + value += 1; + } + return value; + }); + + // record result partition id. + Environment environment = getContainingTask().getEnvironment(); + checkState(environment.getAllWriters().length == 1); + resultPartitions.put(subIdx, environment.getAllWriters()[0].getPartitionId()); + + // wait until unblocked. + if (subtaskBlocked.containsKey(subIdx) && subtaskBlocked.get(subIdx)) { + tryWaitUntilCondition(() -> !subtaskBlocked.get(subIdx)); + } + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + output.collect(streamRecord); + } + + public static void clear() { + subtaskBlocked.clear(); + attemptIds.clear(); + resultPartitions.clear(); + } + + public static void blockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), true, subtaskBlocked); + } + + public static void unblockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), false, subtaskBlocked); + } + } + + /** + * A special map function which can get tasks' information (Such as {@link ResultPartitionID}) + * and manually control the task's execution. + */ + private static class StubMapFunction extends AbstractStreamOperator> + implements OneInputStreamOperator> { + + public static Map subtaskBlocked = new ConcurrentHashMap<>(); + public static Map attemptIds = new ConcurrentHashMap<>(); + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output>> output) { + super.setup(containingTask, config, output); + + int subIdx = getRuntimeContext().getIndexOfThisSubtask(); + + // attempt id ++ + attemptIds.compute( + subIdx, + (ignored, value) -> { + if (value == null) { + value = 0; + } else { + value += 1; + } + return value; + }); + + // wait until unblocked. + if (subtaskBlocked.containsKey(subIdx) && subtaskBlocked.get(subIdx)) { + tryWaitUntilCondition(() -> !subtaskBlocked.get(subIdx)); + } + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + int number = streamRecord.getValue().intValue(); + output.collect(new StreamRecord<>(new Tuple2<>(number % NUMBER_KEYS, 1))); + } + + public static void clear() { + subtaskBlocked.clear(); + attemptIds.clear(); + } + + public static void blockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), true, subtaskBlocked); + } + + public static void unblockSubTasks(Integer... 
subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), false, subtaskBlocked); + } + } + + private static class StubMapFunction2 extends AbstractStreamOperator> + implements OneInputStreamOperator> { + + public static Map subtaskBlocked = new ConcurrentHashMap<>(); + public static Map attemptIds = new ConcurrentHashMap<>(); + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output>> output) { + super.setup(containingTask, config, output); + + int subIdx = getRuntimeContext().getIndexOfThisSubtask(); + + // attempt id ++ + attemptIds.compute( + subIdx, + (ignored, value) -> { + if (value == null) { + value = 0; + } else { + value += 1; + } + return value; + }); + + // wait until unblocked. + if (subtaskBlocked.containsKey(subIdx) && subtaskBlocked.get(subIdx)) { + tryWaitUntilCondition(() -> !subtaskBlocked.get(subIdx)); + } + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + int number = Integer.parseInt(streamRecord.getValue()); + + output.collect(new StreamRecord<>(new Tuple2<>(number % NUMBER_KEYS, 1))); + } + + public static void clear() { + subtaskBlocked.clear(); + attemptIds.clear(); + } + + public static void blockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), true, subtaskBlocked); + } + + public static void unblockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), false, subtaskBlocked); + } + } + + /** A special sink function which can control the task's execution. */ + private static class StubRecordSink extends AbstractStreamOperator + implements OneInputStreamOperator, Void> { + + public static Map subtaskBlocked = new ConcurrentHashMap<>(); + public static Map attemptIds = new ConcurrentHashMap<>(); + public static Map countResults = new ConcurrentHashMap<>(); + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output> output) { + super.setup(containingTask, config, output); + + int subIdx = getRuntimeContext().getIndexOfThisSubtask(); + + // attempt id ++ + attemptIds.compute( + subIdx, + (ignored, value) -> { + if (value == null) { + value = 0; + } else { + value += 1; + } + return value; + }); + + // wait until unblocked. + if (subtaskBlocked.containsKey(subIdx) && subtaskBlocked.get(subIdx)) { + tryWaitUntilCondition(() -> !subtaskBlocked.get(subIdx)); + } + } + + @Override + public void processElement(StreamRecord> streamRecord) + throws Exception { + Tuple2 value = streamRecord.getValue(); + countResults.put(value.f0, value.f1); + } + + public static void clear() { + subtaskBlocked.clear(); + attemptIds.clear(); + countResults.clear(); + } + + public static void blockSubTasks(Integer... subIndices) { + setSubtaskBlocked(Arrays.asList(subIndices), true, subtaskBlocked); + } + + public static void unblockSubTasks(Integer... 
subIndices) {
+            setSubtaskBlocked(Arrays.asList(subIndices), false, subtaskBlocked);
+        }
+    }
+
+    private static void tryWaitUntilCondition(
+            SupplierWithException<Boolean, Exception> condition) {
+        try {
+            CommonTestUtils.waitUntilCondition(condition);
+        } catch (Exception exception) {
+            // Best-effort wait: interruptions are expected when a blocked subtask is
+            // cancelled during JM failover, so exceptions are deliberately swallowed.
+        }
+    }
+
+    private File prepareTestData(File datafile) throws IOException {
+        try (FileWriter writer = new FileWriter(datafile)) {
+            for (int i = 0; i < NUMBER_KEYS * NUMBER_OF_EACH_KEY; i++) {
+                writer.write(i + "\n");
+            }
+        }
+        return datafile;
+    }
+
+    private void deleteOldestFileInShuffleNettyDirectory(File directory) {
+        if (directory == null || !directory.exists() || !directory.isDirectory()) {
+            return;
+        }
+
+        File[] matchingDirectories =
+                directory.listFiles(
+                        file ->
+                                file.isDirectory()
+                                        && file.getName().startsWith("flink-netty-shuffle"));
+
+        if (matchingDirectories == null) {
+            return;
+        }
+
+        List<File> files = new ArrayList<>();
+        for (File subdirectory : matchingDirectories) {
+            File[] candidates = subdirectory.listFiles();
+            if (candidates == null) {
+                continue;
+            }
+            Arrays.stream(candidates)
+                    .filter(file -> file.getName().endsWith(PartitionedFile.DATA_FILE_SUFFIX))
+                    .forEach(files::add);
+        }
+
+        if (!files.isEmpty()) {
+            files.sort(Comparator.comparing(this::getFileCreationTime));
+            files.get(0).delete();
+        }
+    }
+
+    private long getFileCreationTime(File file) {
+        try {
+            BasicFileAttributes attrs =
+                    Files.readAttributes(file.toPath(), BasicFileAttributes.class);
+            return attrs.creationTime().toMillis();
+        } catch (NoSuchFileException e) {
+            // The TaskExecutor deletes unfinished partition files asynchronously during
+            // JM failover; sort such vanished files last.
+            return Long.MAX_VALUE;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}