From 4b39e4ad43874cb5caedc6ab16240281445134f2 Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Tue, 19 Oct 2021 17:10:34 +0200 Subject: [PATCH] [FLINK-25817] Let TaskLocalStateStoreImpl persist TaskStateSnapshots This commit lets the TaskLocalStateStoreImpl persist the TaskStateSnapshots into the directory of the local state checkpoint. This makes it possible to recover the TaskStateSnapshots in case of a process crash. If the TaskStateSnapshot cannot be read, then the whole local checkpointing directory will be deleted to avoid corrupted files. --- .../TaskExecutorLocalStateStoresManager.java | 6 +- .../state/TaskLocalStateStoreImpl.java | 115 +++++++++++++++--- .../state/TaskLocalStateStoreImplTest.java | 106 ++++++++++++---- 3 files changed, 191 insertions(+), 36 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java index 987f6004db38a..80367e00bd65a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java @@ -78,6 +78,11 @@ public TaskExecutorLocalStateStoresManager( @Nonnull Executor discardExecutor) throws IOException { + LOG.debug( + "Start {} with local state root directories {}.", + getClass().getSimpleName(), + localStateRootDirectories); + this.taskStateStoresByAllocationID = new HashMap<>(); this.localRecoveryEnabled = localRecoveryEnabled; this.localStateRootDirectories = localStateRootDirectories; @@ -193,7 +198,6 @@ public TaskLocalStateStore localStateStoreForSubtask( } public void releaseLocalStateForAllocationId(@Nonnull AllocationID allocationID) { - if (LOG.isDebugEnabled()) { LOG.debug("Releasing local state under allocation id {}.", allocationID); } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java index af4695f105700..87f4e0c6724a6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java @@ -25,6 +25,8 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +37,11 @@ import javax.annotation.concurrent.GuardedBy; import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collection; @@ -60,6 +66,8 @@ public class TaskLocalStateStoreImpl implements OwnedTaskLocalStateStore { /** Dummy value to use instead of null to satisfy {@link ConcurrentHashMap}. */ @VisibleForTesting static final TaskStateSnapshot NULL_DUMMY = new TaskStateSnapshot(0, false); + public static final String TASK_STATE_SNAPSHOT_FILENAME = "_task_state_snapshot"; + /** JobID from the owning subtask. */ @Nonnull private final JobID jobID; @@ -165,6 +173,7 @@ public void storeLocalState( } else { TaskStateSnapshot previous = storedTaskStateByCheckpointID.put(checkpointId, localState); + persistLocalStateMetadata(checkpointId, localState); if (previous != null) { toDiscard = new AbstractMap.SimpleEntry<>(checkpointId, previous); @@ -177,6 +186,45 @@ public void storeLocalState( } } + /** + * Writes a task state snapshot file that contains the serialized content of the local state. 
+ * + * @param checkpointId identifying the checkpoint + * @param localState task state snapshot that will be persisted + */ + private void persistLocalStateMetadata(long checkpointId, TaskStateSnapshot localState) { + final File taskStateSnapshotFile = getTaskStateSnapshotFile(checkpointId); + try (ObjectOutputStream oos = + new ObjectOutputStream(new FileOutputStream(taskStateSnapshotFile))) { + oos.writeObject(localState); + + LOG.debug( + "Successfully written local task state snapshot file {} for checkpoint {}.", + taskStateSnapshotFile, + checkpointId); + } catch (IOException e) { + ExceptionUtils.rethrow(e, "Could not write the local task state snapshot file."); + } + } + + @VisibleForTesting + File getTaskStateSnapshotFile(long checkpointId) { + final File checkpointDirectory = + localRecoveryConfig + .getLocalStateDirectoryProvider() + .orElseThrow( + () -> new IllegalStateException("Local recovery must be enabled.")) + .subtaskSpecificCheckpointDirectory(checkpointId); + + if (!checkpointDirectory.exists() && !checkpointDirectory.mkdirs()) { + throw new FlinkRuntimeException( + String.format( + "Could not create the checkpoint directory '%s'", checkpointDirectory)); + } + + return new File(checkpointDirectory, TASK_STATE_SNAPSHOT_FILENAME); + } + @Override @Nullable public TaskStateSnapshot retrieveLocalState(long checkpointID) { @@ -184,7 +232,7 @@ public TaskStateSnapshot retrieveLocalState(long checkpointID) { TaskStateSnapshot snapshot; synchronized (lock) { - snapshot = storedTaskStateByCheckpointID.get(checkpointID); + snapshot = loadTaskStateSnapshot(checkpointID); } if (snapshot != null) { @@ -216,6 +264,42 @@ public TaskStateSnapshot retrieveLocalState(long checkpointID) { return (snapshot != NULL_DUMMY) ? 
snapshot : null; } + @GuardedBy("lock") + @Nullable + private TaskStateSnapshot loadTaskStateSnapshot(long checkpointID) { + return storedTaskStateByCheckpointID.computeIfAbsent( + checkpointID, this::tryLoadTaskStateSnapshotFromDisk); + } + + @GuardedBy("lock") + @Nullable + private TaskStateSnapshot tryLoadTaskStateSnapshotFromDisk(long checkpointID) { + final File taskStateSnapshotFile = getTaskStateSnapshotFile(checkpointID); + + if (taskStateSnapshotFile.exists()) { + TaskStateSnapshot taskStateSnapshot = null; + try (ObjectInputStream ois = + new ObjectInputStream(new FileInputStream(taskStateSnapshotFile))) { + taskStateSnapshot = (TaskStateSnapshot) ois.readObject(); + + LOG.debug( + "Loaded task state snapshot for checkpoint {} successfully from disk.", + checkpointID); + } catch (IOException | ClassNotFoundException e) { + LOG.debug( + "Could not read task state snapshot file {} for checkpoint {}. Deleting the corresponding local state.", + taskStateSnapshotFile, + checkpointID); + + discardLocalStateForCheckpoint(checkpointID, Optional.empty()); + } + + return taskStateSnapshot; + } + + return null; + } + @Override @Nonnull public LocalRecoveryConfig getLocalRecoveryConfig() { @@ -307,14 +391,14 @@ private void asyncDiscardLocalStateForCollection( private void syncDiscardLocalStateForCollection( Collection> toDiscard) { for (Map.Entry entry : toDiscard) { - discardLocalStateForCheckpoint(entry.getKey(), entry.getValue()); + discardLocalStateForCheckpoint(entry.getKey(), Optional.of(entry.getValue())); } } /** * Helper method that discards state objects with an executor and reports exceptions to the log. 
*/ - private void discardLocalStateForCheckpoint(long checkpointID, TaskStateSnapshot o) { + private void discardLocalStateForCheckpoint(long checkpointID, Optional o) { if (LOG.isTraceEnabled()) { LOG.trace( @@ -333,17 +417,20 @@ private void discardLocalStateForCheckpoint(long checkpointID, TaskStateSnapshot subtaskIndex); } - try { - o.discardState(); - } catch (Exception discardEx) { - LOG.warn( - "Exception while discarding local task state snapshot of checkpoint {} in subtask ({} - {} - {}).", - checkpointID, - jobID, - jobVertexID, - subtaskIndex, - discardEx); - } + o.ifPresent( + taskStateSnapshot -> { + try { + taskStateSnapshot.discardState(); + } catch (Exception discardEx) { + LOG.warn( + "Exception while discarding local task state snapshot of checkpoint {} in subtask ({} - {} - {}).", + checkpointID, + jobID, + jobVertexID, + subtaskIndex, + discardEx); + } + }); Optional directoryProviderOptional = localRecoveryConfig.getLocalStateDirectoryProvider(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java index b71d575252fc5..8c906b51fa275 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java @@ -33,53 +33,67 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; +import javax.annotation.Nonnull; + import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; -import java.util.SortedMap; -import java.util.TreeMap; +import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** Test for the {@link 
TaskLocalStateStoreImpl}. */ public class TaskLocalStateStoreImplTest extends TestLogger { - private SortedMap internalSnapshotMap; - private Object internalLock; private TemporaryFolder temporaryFolder; private File[] allocationBaseDirs; private TaskLocalStateStoreImpl taskLocalStateStore; + private JobID jobID; + private AllocationID allocationID; + private JobVertexID jobVertexID; + private int subtaskIdx; @Before public void before() throws Exception { - JobID jobID = new JobID(); - AllocationID allocationID = new AllocationID(); - JobVertexID jobVertexID = new JobVertexID(); - int subtaskIdx = 0; + jobID = new JobID(); + allocationID = new AllocationID(); + jobVertexID = new JobVertexID(); + subtaskIdx = 0; this.temporaryFolder = new TemporaryFolder(); this.temporaryFolder.create(); this.allocationBaseDirs = new File[] {temporaryFolder.newFolder(), temporaryFolder.newFolder()}; - this.internalSnapshotMap = new TreeMap<>(); - this.internalLock = new Object(); + this.taskLocalStateStore = + createTaskLocalStateStoreImpl( + allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx); + } + + @Nonnull + private TaskLocalStateStoreImpl createTaskLocalStateStoreImpl( + File[] allocationBaseDirs, + JobID jobID, + AllocationID allocationID, + JobVertexID jobVertexID, + int subtaskIdx) { LocalRecoveryDirectoryProviderImpl directoryProvider = new LocalRecoveryDirectoryProviderImpl( allocationBaseDirs, jobID, jobVertexID, subtaskIdx); LocalRecoveryConfig localRecoveryConfig = new LocalRecoveryConfig(directoryProvider); - - this.taskLocalStateStore = - new TaskLocalStateStoreImpl( - jobID, - allocationID, - jobVertexID, - subtaskIdx, - localRecoveryConfig, - Executors.directExecutor(), - internalSnapshotMap, - internalLock); + return new TaskLocalStateStoreImpl( + jobID, + allocationID, + jobVertexID, + subtaskIdx, + localRecoveryConfig, + Executors.directExecutor()); } @After @@ -180,6 +194,56 @@ public void dispose() throws Exception { 
checkPrunedAndDiscarded(taskStateSnapshots, 0, chkCount); } + @Test + public void retrieveNullIfNoPersistedLocalState() { + assertThat(taskLocalStateStore.retrieveLocalState(0)).isNull(); + } + + @Test + public void retrievePersistedLocalStateFromDisc() { + final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot(); + final long checkpointId = 0L; + taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot); + + final TaskLocalStateStoreImpl newTaskLocalStateStore = + createTaskLocalStateStoreImpl( + allocationBaseDirs, jobID, allocationID, jobVertexID, 0); + + final TaskStateSnapshot retrievedTaskStateSnapshot = + newTaskLocalStateStore.retrieveLocalState(checkpointId); + + assertThat(retrievedTaskStateSnapshot).isEqualTo(taskStateSnapshot); + } + + @Nonnull + private TaskStateSnapshot createTaskStateSnapshot() { + final Map operatorSubtaskStates = new HashMap<>(); + operatorSubtaskStates.put(new OperatorID(), OperatorSubtaskState.builder().build()); + operatorSubtaskStates.put(new OperatorID(), OperatorSubtaskState.builder().build()); + final TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(operatorSubtaskStates); + return taskStateSnapshot; + } + + @Test + public void deletesLocalStateIfRetrievalFails() throws IOException { + final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot(); + final long checkpointId = 0L; + taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot); + + final File taskStateSnapshotFile = + taskLocalStateStore.getTaskStateSnapshotFile(checkpointId); + + Files.write( + taskStateSnapshotFile.toPath(), new byte[] {1, 2, 3, 4}, StandardOpenOption.WRITE); + + final TaskLocalStateStoreImpl newTaskLocalStateStore = + createTaskLocalStateStoreImpl( + allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx); + + assertThat(newTaskLocalStateStore.retrieveLocalState(checkpointId)).isNull(); + assertThat(taskStateSnapshotFile.getParentFile()).doesNotExist(); + } + private void 
checkStoredAsExpected(List history, int start, int end) { for (int i = start; i < end; ++i) { TestingTaskStateSnapshot expected = history.get(i);