diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index 88a759d068eac..1f32a8999c9f2 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -105,6 +105,7 @@ import java.util.List; import java.util.Map; import java.util.PriorityQueue; +import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; import java.util.UUID; @@ -170,8 +171,8 @@ public class RocksDBKeyedStateBackend extends AbstractKeyedStateBackend { /** True if incremental checkpointing is enabled */ private final boolean enableIncrementalCheckpointing; - /** The sst files materialized in pending checkpoints */ - private final SortedMap> materializedSstFiles = new TreeMap<>(); + /** The state handle ids of all sst files materialized in snapshots for previous checkpoints */ + private final SortedMap> materializedSstFiles = new TreeMap<>(); /** The identifier of the last completed checkpoint */ private long lastCompletedCheckpointId = -1; @@ -720,7 +721,7 @@ private static final class RocksDBIncrementalSnapshotOperation { private final long checkpointTimestamp; /** All sst files that were part of the last previously completed checkpoint */ - private Map baseSstFiles; + private Set baseSstFiles; /** The state meta data */ private final List> stateMetaInfoSnapshots = new ArrayList<>(); @@ -732,10 +733,7 @@ private static final class RocksDBIncrementalSnapshotOperation { private final CloseableRegistry closeableRegistry = new CloseableRegistry(); // new sst files since the last completed checkpoint - private final Map newSstFiles = new HashMap<>(); - - // old sst files which have been materialized in previous completed checkpoints - private final Map oldSstFiles = new HashMap<>(); + private final Map sstFiles = new HashMap<>(); // handles to the misc files in the current snapshot private final Map miscFiles = new HashMap<>(); @@ -830,7 +828,6 @@ void takeSnapshot() throws Exception { // use the last completed checkpoint as the comparison base. baseSstFiles = stateBackend.materializedSstFiles.get(stateBackend.lastCompletedCheckpointId); - // save meta data for (Map.Entry>> stateMetaInfoEntry : stateBackend.kvStateInformation.entrySet()) { @@ -867,18 +864,17 @@ KeyedStateHandle materializeSnapshot() throws Exception { final StateHandleID stateHandleID = new StateHandleID(fileName); if (fileName.endsWith(SST_FILE_SUFFIX)) { - StreamStateHandle fileHandle = - baseSstFiles == null ? null : baseSstFiles.get(fileName); + final boolean existsAlready = + baseSstFiles == null ? false : baseSstFiles.contains(stateHandleID); - if (fileHandle == null) { - fileHandle = materializeStateData(filePath); - newSstFiles.put(stateHandleID, fileHandle); - } else { + if (existsAlready) { // we introduce a placeholder state handle, that is replaced with the // original from the shared state registry (created from a previous checkpoint) - oldSstFiles.put( + sstFiles.put( stateHandleID, - new PlaceholderStreamStateHandle(fileHandle.getStateSize())); + new PlaceholderStreamStateHandle()); + } else { + sstFiles.put(stateHandleID, materializeStateData(filePath)); } } else { StreamStateHandle fileHandle = materializeStateData(filePath); @@ -887,22 +883,17 @@ KeyedStateHandle materializeSnapshot() throws Exception { } } - Map sstFiles = - new HashMap<>(newSstFiles.size() + oldSstFiles.size()); - sstFiles.putAll(newSstFiles); - sstFiles.putAll(oldSstFiles); synchronized (stateBackend.asyncSnapshotLock) { - stateBackend.materializedSstFiles.put(checkpointId, sstFiles); + stateBackend.materializedSstFiles.put(checkpointId, sstFiles.keySet()); } return new IncrementalKeyedStateHandle( stateBackend.operatorIdentifier, stateBackend.keyGroupRange, checkpointId, - newSstFiles, - oldSstFiles, + sstFiles, miscFiles, metaStateHandle); } @@ -933,7 +924,7 @@ void releaseResources(boolean canceled) { statesToDiscard.add(metaStateHandle); statesToDiscard.addAll(miscFiles.values()); - statesToDiscard.addAll(newSstFiles.values()); + statesToDiscard.addAll(sstFiles.values()); try { StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard); @@ -1308,15 +1299,12 @@ private void restoreInstance( UUID.randomUUID().toString()); try { - final Map newSstFiles = - restoreStateHandle.getCreatedSharedState(); - final Map oldSstFiles = - restoreStateHandle.getReferencedSharedState(); + final Map sstFiles = + restoreStateHandle.getSharedState(); final Map miscFiles = restoreStateHandle.getPrivateState(); - readAllStateData(newSstFiles, restoreInstancePath); - readAllStateData(oldSstFiles, restoreInstancePath); + readAllStateData(sstFiles, restoreInstancePath); readAllStateData(miscFiles, restoreInstancePath); // read meta data @@ -1409,8 +1397,7 @@ private void restoreInstance( throw new IOException("Could not create RocksDB data directory."); } - createFileHardLinksInRestorePath(newSstFiles, restoreInstancePath); - createFileHardLinksInRestorePath(oldSstFiles, restoreInstancePath); + createFileHardLinksInRestorePath(sstFiles, restoreInstancePath); createFileHardLinksInRestorePath(miscFiles, restoreInstancePath); List columnFamilyHandles = new ArrayList<>(); @@ -1437,10 +1424,7 @@ private void restoreInstance( // use the restore sst files as the base for succeeding checkpoints - Map sstFiles = new HashMap<>(); - sstFiles.putAll(newSstFiles); - sstFiles.putAll(oldSstFiles); - stateBackend.materializedSstFiles.put(restoreStateHandle.getCheckpointId(), sstFiles); + stateBackend.materializedSstFiles.put(restoreStateHandle.getCheckpointId(), sstFiles.keySet()); stateBackend.lastCompletedCheckpointId = restoreStateHandle.getCheckpointId(); } diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java index 9340455a3e744..89eb1d55cc65b 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java @@ -26,18 +26,22 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.core.testutils.OneShotLatch; -import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.IncrementalKeyedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.StateBackendTestBase; +import org.apache.flink.runtime.state.StateHandleID; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -58,7 +62,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.Queue; import java.util.concurrent.RunnableFuture; import static junit.framework.TestCase.assertNotNull; @@ -67,6 +75,7 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.internal.verification.VerificationModeFactory.times; import static org.powermock.api.mockito.PowerMockito.mock; @@ -351,6 +360,83 @@ public void testDisposeDeletesAllDirectories() throws Exception { assertEquals(1, allFilesInDbDir.size()); } + @Test + public void testSharedIncrementalStateDeRegistration() throws Exception { + if (enableIncrementalCheckpointing) { + AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + ValueStateDescriptor kvId = + new ValueStateDescriptor<>("id", String.class, null); + + kvId.initializeSerializerUnlessSet(new ExecutionConfig()); + + ValueState state = + backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + + + Queue previousStateHandles = new LinkedList<>(); + SharedStateRegistry sharedStateRegistry = spy(new SharedStateRegistry()); + for (int checkpointId = 0; checkpointId < 3; ++checkpointId) { + + reset(sharedStateRegistry); + + backend.setCurrentKey(checkpointId); + state.update("Hello-" + checkpointId); + + RunnableFuture snapshot = backend.snapshot( + checkpointId, + checkpointId, + createStreamFactory(), + CheckpointOptions.forFullCheckpoint()); + + snapshot.run(); + + IncrementalKeyedStateHandle stateHandle = (IncrementalKeyedStateHandle) snapshot.get(); + Map sharedState = + new HashMap<>(stateHandle.getSharedState()); + + stateHandle.registerSharedStates(sharedStateRegistry); + + for (Map.Entry e : sharedState.entrySet()) { + verify(sharedStateRegistry).registerReference( + stateHandle.createSharedStateRegistryKeyFromFileName(e.getKey()), + e.getValue()); + } + + previousStateHandles.add(stateHandle); + backend.notifyCheckpointComplete(checkpointId); + + //----------------------------------------------------------------- + + if (previousStateHandles.size() > 1) { + checkRemove(previousStateHandles.remove(), sharedStateRegistry); + } + } + + while (!previousStateHandles.isEmpty()) { + + reset(sharedStateRegistry); + + checkRemove(previousStateHandles.remove(), sharedStateRegistry); + } + + backend.close(); + backend.dispose(); + } + } + + private void checkRemove(IncrementalKeyedStateHandle remove, SharedStateRegistry registry) throws Exception { + for (StateHandleID id : remove.getSharedState().keySet()) { + verify(registry, times(0)).unregisterReference( + remove.createSharedStateRegistryKeyFromFileName(id)); + } + + remove.discardState(); + + for (StateHandleID id : remove.getSharedState().keySet()) { + verify(registry).unregisterReference( + remove.createSharedStateRegistryKeyFromFileName(id)); + } + } private void runStateUpdates() throws Exception{ for (int i = 50; i < 150; ++i) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java index 1ab5b41f0cdc2..b38208026e51e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java @@ -25,8 +25,6 @@ import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; -import org.apache.flink.util.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -177,13 +175,13 @@ public CheckpointProperties getProperties() { } public void discardOnFailedStoring() throws Exception { - new UnstoredDiscardStategy().discard(); + doDiscard(); } public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception { if (props.discardOnSubsumed()) { - new StoredDiscardStrategy(sharedStateRegistry).discard(); + doDiscard(); return true; } @@ -197,7 +195,7 @@ public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry shared jobStatus == JobStatus.FAILED && props.discardOnJobFailed() || jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) { - new StoredDiscardStrategy(sharedStateRegistry).discard(); + doDiscard(); return true; } else { if (externalPointer != null) { @@ -209,6 +207,42 @@ public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry shared } } + private void doDiscard() throws Exception { + + try { + // collect exceptions and continue cleanup + Exception exception = null; + + // drop the metadata, if we have some + if (externalizedMetadata != null) { + try { + externalizedMetadata.discardState(); + } catch (Exception e) { + exception = e; + } + } + + // discard private state objects + try { + StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values()); + } catch (Exception e) { + exception = ExceptionUtils.firstOrSuppressed(e, exception); + } + + if (exception != null) { + throw exception; + } + } finally { + operatorStates.clear(); + + // to be null-pointer safe, copy reference to stack + CompletedCheckpointStats.DiscardCallback discardCallback = this.discardCallback; + if (discardCallback != null) { + discardCallback.notifyDiscardedCheckpoint(); + } + } + } + public long getStateSize() { long result = 0L; @@ -252,7 +286,7 @@ void setDiscardCallback(@Nullable CompletedCheckpointStats.DiscardCallback disca /** * Register all shared states in the given registry. This is method is called - * when the completed checkpoint has been successfully added into the store. + * before the checkpoint is added into the store. * * @param sharedStateRegistry The registry where shared states are registered */ @@ -266,102 +300,4 @@ public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { public String toString() { return String.format("Checkpoint %d @ %d for %s", checkpointID, timestamp, job); } - - /** - * Base class for the discarding strategies of {@link CompletedCheckpoint}. - */ - private abstract class DiscardStrategy { - - protected Exception storedException; - - public DiscardStrategy() { - this.storedException = null; - } - - public void discard() throws Exception { - - try { - // collect exceptions and continue cleanup - storedException = null; - - doDiscardExternalizedMetaData(); - doDiscardSharedState(); - doDiscardPrivateState(); - doReportStoredExceptions(); - } finally { - clearTaskStatesAndNotifyDiscardCompleted(); - } - } - - protected void doDiscardExternalizedMetaData() { - // drop the metadata, if we have some - if (externalizedMetadata != null) { - try { - externalizedMetadata.discardState(); - } catch (Exception e) { - storedException = e; - } - } - } - - protected void doDiscardPrivateState() { - // discard private state objects - try { - StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values()); - } catch (Exception e) { - storedException = ExceptionUtils.firstOrSuppressed(e, storedException); - } - } - - protected abstract void doDiscardSharedState(); - - protected void doReportStoredExceptions() throws Exception { - if (storedException != null) { - throw storedException; - } - } - - protected void clearTaskStatesAndNotifyDiscardCompleted() { - operatorStates.clear(); - // to be null-pointer safe, copy reference to stack - CompletedCheckpointStats.DiscardCallback discardCallback = - CompletedCheckpoint.this.discardCallback; - - if (discardCallback != null) { - discardCallback.notifyDiscardedCheckpoint(); - } - } - } - - /** - * Discard all shared states created in the checkpoint. This strategy is applied - * when the completed checkpoint fails to be added into the store. - */ - private class UnstoredDiscardStategy extends CompletedCheckpoint.DiscardStrategy { - - @Override - protected void doDiscardSharedState() { - // nothing to do because we did not register any shared state yet. unregistered, new - // shared state is then still considered private state and deleted as part of - // doDiscardPrivateState(). - } - } - - /** - * Unregister all shared states from the given registry. This is strategy is - * applied when the completed checkpoint is subsumed or the job terminates. - */ - private class StoredDiscardStrategy extends CompletedCheckpoint.DiscardStrategy { - - SharedStateRegistry sharedStateRegistry; - - public StoredDiscardStrategy(SharedStateRegistry sharedStateRegistry) { - this.sharedStateRegistry = Preconditions.checkNotNull(sharedStateRegistry); - } - - @Override - protected void doDiscardSharedState() { - sharedStateRegistry.unregisterAll(operatorStates.values()); - } - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java index aa676e71f7ce0..b15302835bb38 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java @@ -125,13 +125,6 @@ public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { } } - @Override - public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { - for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) { - operatorSubtaskState.unregisterSharedStates(sharedStateRegistry); - } - } - @Override public long getStateSize() { long result = 0L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index 49ef863c37cc8..e2ae632a26b1b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -157,17 +157,6 @@ public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { } } - @Override - public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { - if (managedKeyedState != null) { - managedKeyedState.unregisterSharedStates(sharedStateRegistry); - } - - if (rawKeyedState != null) { - rawKeyedState.unregisterSharedStates(sharedStateRegistry); - } - } - @Override public long getStateSize() { return stateSize; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java index f5e1db351fd8c..233cfc86a9045 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java @@ -63,10 +63,10 @@ public void recover() throws Exception { @Override public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - checkpoints.addLast(checkpoint); - checkpoint.registerSharedStates(sharedStateRegistry); + checkpoints.addLast(checkpoint); + if (checkpoints.size() > maxNumberOfCheckpointsToRetain) { try { CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java index a77baf3bdb3b4..20d675b686b94 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java @@ -161,17 +161,6 @@ public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { } } - @Override - public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { - if (managedKeyedState != null) { - managedKeyedState.unregisterSharedStates(sharedStateRegistry); - } - - if (rawKeyedState != null) { - rawKeyedState.unregisterSharedStates(sharedStateRegistry); - } - } - @Override public long getStateSize() { return stateSize; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index aa5c516a6be29..ed847a43449d7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -140,13 +140,6 @@ public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { } } - @Override - public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) { - for (SubtaskState subtaskState : subtaskStates.values()) { - subtaskState.unregisterSharedStates(sharedStateRegistry); - } - } - @Override public long getStateSize() { long result = 0L; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java index 084d93e7ca24d..4c3c1fff78d77 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java @@ -35,7 +35,6 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.ConcurrentModificationException; -import java.util.Iterator; import java.util.List; import java.util.concurrent.Executor; @@ -79,7 +78,7 @@ public class ZooKeeperCompletedCheckpointStore extends AbstractCompletedCheckpoi private final int maxNumberOfCheckpointsToRetain; /** Local completed checkpoints. */ - private final ArrayDeque, String>> checkpointStateHandles; + private final ArrayDeque completedCheckpoints; /** * Creates a {@link ZooKeeperCompletedCheckpointStore} instance. @@ -122,7 +121,7 @@ public ZooKeeperCompletedCheckpointStore( this.checkpointsInZooKeeper = new ZooKeeperStateHandleStore<>(this.client, stateStorage, executor); - this.checkpointStateHandles = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); + this.completedCheckpoints = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1); LOG.info("Initialized in '{}'.", checkpointsPath); } @@ -146,7 +145,7 @@ public void recover() throws Exception { // Clear local handles in order to prevent duplicates on // recovery. The local handles should reflect the state // of ZooKeeper. - checkpointStateHandles.clear(); + completedCheckpoints.clear(); // Get all there is first List, String>> initialCheckpoints; @@ -170,6 +169,11 @@ public void recover() throws Exception { try { completedCheckpoint = retrieveCompletedCheckpoint(checkpointStateHandle); + if (completedCheckpoint != null) { + // Re-register all shared states in the checkpoint. + completedCheckpoint.registerSharedStates(sharedStateRegistry); + completedCheckpoints.add(completedCheckpoint); + } } catch (Exception e) { LOG.warn("Could not retrieve checkpoint. Removing it from the completed " + "checkpoint store.", e); @@ -177,11 +181,6 @@ public void recover() throws Exception { // remove the checkpoint with broken state handle removeBrokenStateHandle(checkpointStateHandle.f1, checkpointStateHandle.f0); } - - if (completedCheckpoint != null) { - completedCheckpoint.registerSharedStates(sharedStateRegistry); - checkpointStateHandles.add(checkpointStateHandle); - } } } @@ -195,20 +194,19 @@ public void addCheckpoint(final CompletedCheckpoint checkpoint) throws Exception checkNotNull(checkpoint, "Checkpoint"); final String path = checkpointIdToPath(checkpoint.getCheckpointID()); - final RetrievableStateHandle stateHandle; - // First add the new one. If it fails, we don't want to loose existing data. - stateHandle = checkpointsInZooKeeper.addAndLock(path, checkpoint); + // First, register all shared states in the checkpoint to consolidates placeholder. + checkpoint.registerSharedStates(sharedStateRegistry); - checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path)); + // Now add the new one. If it fails, we don't want to loose existing data. + checkpointsInZooKeeper.addAndLock(path, checkpoint); - // Register all shared states in the checkpoint - checkpoint.registerSharedStates(sharedStateRegistry); + completedCheckpoints.addLast(checkpoint); // Everything worked, let's remove a previous checkpoint if necessary. - while (checkpointStateHandles.size() > maxNumberOfCheckpointsToRetain) { + while (completedCheckpoints.size() > maxNumberOfCheckpointsToRetain) { try { - removeSubsumed(checkpointStateHandles.removeFirst().f1, sharedStateRegistry); + removeSubsumed(completedCheckpoints.removeFirst(), sharedStateRegistry); } catch (Exception e) { LOG.warn("Failed to subsume the old checkpoint", e); } @@ -219,60 +217,23 @@ public void addCheckpoint(final CompletedCheckpoint checkpoint) throws Exception @Override public CompletedCheckpoint getLatestCheckpoint() { - if (checkpointStateHandles.isEmpty()) { + if (completedCheckpoints.isEmpty()) { return null; } else { - while(!checkpointStateHandles.isEmpty()) { - Tuple2, String> checkpointStateHandle = checkpointStateHandles.peekLast(); - - try { - return retrieveCompletedCheckpoint(checkpointStateHandle); - } catch (Exception e) { - LOG.warn("Could not retrieve latest checkpoint. Removing it from " + - "the completed checkpoint store.", e); - - try { - // remove the checkpoint with broken state handle - Tuple2, String> checkpoint = checkpointStateHandles.pollLast(); - removeBrokenStateHandle(checkpoint.f1, checkpoint.f0); - } catch (Exception removeException) { - LOG.warn("Could not remove the latest checkpoint with a broken state handle.", removeException); - } - } - } - - return null; + return completedCheckpoints.peekLast(); } } @Override public List getAllCheckpoints() throws Exception { - List checkpoints = new ArrayList<>(checkpointStateHandles.size()); - - Iterator, String>> stateHandleIterator = checkpointStateHandles.iterator(); - - while (stateHandleIterator.hasNext()) { - Tuple2, String> stateHandlePath = stateHandleIterator.next(); - - try { - checkpoints.add(retrieveCompletedCheckpoint(stateHandlePath)); - } catch (Exception e) { - LOG.warn("Could not retrieve checkpoint. Removing it from the completed " + - "checkpoint store.", e); - - // remove the checkpoint with broken state handle - stateHandleIterator.remove(); - removeBrokenStateHandle(stateHandlePath.f1, stateHandlePath.f0); - } - } - + List checkpoints = new ArrayList<>(completedCheckpoints); return checkpoints; } @Override public int getNumberOfRetainedCheckpoints() { - return checkpointStateHandles.size(); + return completedCheckpoints.size(); } @Override @@ -285,15 +246,15 @@ public void shutdown(JobStatus jobStatus) throws Exception { if (jobStatus.isGloballyTerminalState()) { LOG.info("Shutting down"); - for (Tuple2, String> checkpoint : checkpointStateHandles) { + for (CompletedCheckpoint checkpoint : completedCheckpoints) { try { - removeShutdown(checkpoint.f1, jobStatus, sharedStateRegistry); + removeShutdown(checkpoint, jobStatus, sharedStateRegistry); } catch (Exception e) { LOG.error("Failed to discard checkpoint.", e); } } - checkpointStateHandles.clear(); + completedCheckpoints.clear(); String path = "/" + client.getNamespace(); @@ -303,7 +264,7 @@ public void shutdown(JobStatus jobStatus) throws Exception { LOG.info("Suspending"); // Clear the local handles, but don't remove any state - checkpointStateHandles.clear(); + completedCheckpoints.clear(); // Release the state handle locks in ZooKeeper such that they can be deleted checkpointsInZooKeeper.releaseAll(); @@ -313,21 +274,18 @@ public void shutdown(JobStatus jobStatus) throws Exception { // ------------------------------------------------------------------------ private void removeSubsumed( - final String pathInZooKeeper, + final CompletedCheckpoint completedCheckpoint, final SharedStateRegistry sharedStateRegistry) throws Exception { - - ZooKeeperStateHandleStore.RemoveCallback action = new ZooKeeperStateHandleStore.RemoveCallback() { - @Override - public void apply(@Nullable RetrievableStateHandle value) throws FlinkException { - if (value != null) { - final CompletedCheckpoint completedCheckpoint; - try { - completedCheckpoint = value.retrieveState(); - } catch (Exception e) { - throw new FlinkException("Could not retrieve the completed checkpoint from the given state handle.", e); - } - if (completedCheckpoint != null) { + if(completedCheckpoint == null) { + return; + } + + ZooKeeperStateHandleStore.RemoveCallback action = + new ZooKeeperStateHandleStore.RemoveCallback() { + @Override + public void apply(@Nullable RetrievableStateHandle value) throws FlinkException { + if (value != null) { try { completedCheckpoint.discardOnSubsume(sharedStateRegistry); } catch (Exception e) { @@ -335,46 +293,41 @@ public void apply(@Nullable RetrievableStateHandle value) t } } } - } - }; + }; - checkpointsInZooKeeper.releaseAndTryRemove(pathInZooKeeper, action); + checkpointsInZooKeeper.releaseAndTryRemove( + checkpointIdToPath(completedCheckpoint.getCheckpointID()), + action); } private void removeShutdown( - final String pathInZooKeeper, + final CompletedCheckpoint completedCheckpoint, final JobStatus jobStatus, final SharedStateRegistry sharedStateRegistry) throws Exception { + if(completedCheckpoint == null) { + return; + } + ZooKeeperStateHandleStore.RemoveCallback removeAction = new ZooKeeperStateHandleStore.RemoveCallback() { @Override public void apply(@Nullable RetrievableStateHandle value) throws FlinkException { - if (value != null) { - final CompletedCheckpoint completedCheckpoint; - - try { - completedCheckpoint = value.retrieveState(); - } catch (Exception e) { - throw new FlinkException("Could not retrieve the completed checkpoint from the given state handle.", e); - } - - if (completedCheckpoint != null) { - try { - completedCheckpoint.discardOnShutdown(jobStatus, sharedStateRegistry); - } catch (Exception e) { - throw new FlinkException("Could not discard the completed checkpoint on subsume.", e); - } - } + try { + completedCheckpoint.discardOnShutdown(jobStatus, sharedStateRegistry); + } catch (Exception e) { + throw new FlinkException("Could not discard the completed checkpoint on subsume.", e); } } }; - checkpointsInZooKeeper.releaseAndTryRemove(pathInZooKeeper, removeAction); + checkpointsInZooKeeper.releaseAndTryRemove( + checkpointIdToPath(completedCheckpoint.getCheckpointID()), + removeAction); } private void removeBrokenStateHandle( - final String pathInZooKeeper, - final RetrievableStateHandle retrievableStateHandle) throws Exception { + final String pathInZooKeeper, + final RetrievableStateHandle retrievableStateHandle) throws Exception { checkpointsInZooKeeper.releaseAndTryRemove(pathInZooKeeper, new ZooKeeperStateHandleStore.RemoveCallback() { @Override public void apply(@Nullable RetrievableStateHandle value) throws FlinkException { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java index b71418b41611b..da0022c926960 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java @@ -29,7 +29,6 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.PlaceholderStreamStateHandle; import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FileStateHandle; @@ -75,7 +74,6 @@ class SavepointV2Serializer implements SavepointSerializer { private static final byte KEY_GROUPS_HANDLE = 3; private static final byte PARTITIONABLE_OPERATOR_STATE_HANDLE = 4; private static final byte INCREMENTAL_KEY_GROUPS_HANDLE = 5; - private static final byte PLACEHOLDER_STREAM_STATE_HANDLE = 6; /** The singleton instance of the serializer */ public static final SavepointV2Serializer INSTANCE = new SavepointV2Serializer(); @@ -328,8 +326,7 @@ private static void serializeKeyedStateHandle( serializeStreamStateHandle(incrementalKeyedStateHandle.getMetaStateHandle(), dos); - serializeStreamStateHandleMap(incrementalKeyedStateHandle.getCreatedSharedState(), dos); - serializeStreamStateHandleMap(incrementalKeyedStateHandle.getReferencedSharedState(), dos); + serializeStreamStateHandleMap(incrementalKeyedStateHandle.getSharedState(), dos); serializeStreamStateHandleMap(incrementalKeyedStateHandle.getPrivateState(), dos); } else { throw new IllegalStateException("Unknown KeyedStateHandle type: " + stateHandle.getClass()); @@ -390,16 +387,14 @@ private static KeyedStateHandle deserializeKeyedStateHandle(DataInputStream dis) KeyGroupRange.of(startKeyGroup, startKeyGroup + numKeyGroups - 1); StreamStateHandle metaDataStateHandle = deserializeStreamStateHandle(dis); - Map createdStates = deserializeStreamStateHandleMap(dis); - Map referencedStates = deserializeStreamStateHandleMap(dis); + Map sharedStates = deserializeStreamStateHandleMap(dis); Map privateStates = deserializeStreamStateHandleMap(dis); return new IncrementalKeyedStateHandle( operatorId, keyGroupRange, checkpointId, - createdStates, - referencedStates, + sharedStates, privateStates, metaDataStateHandle); } else { @@ -485,10 +480,6 @@ private static void serializeStreamStateHandle( byte[] internalData = byteStreamStateHandle.getData(); dos.writeInt(internalData.length); dos.write(byteStreamStateHandle.getData()); - } else if (stateHandle instanceof PlaceholderStreamStateHandle) { - PlaceholderStreamStateHandle placeholder = (PlaceholderStreamStateHandle) stateHandle; - dos.writeByte(PLACEHOLDER_STREAM_STATE_HANDLE); - dos.writeLong(placeholder.getStateSize()); } else { throw new IOException("Unknown implementation of StreamStateHandle: " + stateHandle.getClass()); } @@ -510,8 +501,6 @@ private static StreamStateHandle deserializeStreamStateHandle(DataInputStream di byte[] data = new byte[numBytes]; dis.readFully(data); return new ByteStreamStateHandle(handleName, data); - } else if (PLACEHOLDER_STREAM_STATE_HANDLE == type) { - return new PlaceholderStreamStateHandle(dis.readLong()); } else { throw new IOException("Unknown implementation of StreamStateHandle, code: " + type); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java index 002b7c315b396..1bc6a0fae80be 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java @@ -34,7 +34,8 @@ * this handle and considered as private state until it is registered for the first time. Registration * transfers ownership to the {@link SharedStateRegistry}. * The composite state handle should only delete all private states in the - * {@link StateObject#discardState()} method. + * {@link StateObject#discardState()} method, the {@link SharedStateRegistry} is responsible for + * deleting shared states after they were registered. */ public interface CompositeStateHandle extends StateObject { @@ -45,18 +46,10 @@ public interface CompositeStateHandle extends StateObject { *

* After this is completed, newly created shared state is considered as published is no longer * owned by this handle. This means that it should no longer be deleted as part of calls to - * {@link #discardState()}. + * {@link #discardState()}. Instead, {@link #discardState()} will trigger an unregistration + * from the registry. * * @param stateRegistry The registry where shared states are registered. */ void registerSharedStates(SharedStateRegistry stateRegistry); - - /** - * Unregister both created and referenced shared states in the given - * {@link SharedStateRegistry}. This method is called when the checkpoint is - * subsumed or the job is shut down. - * - * @param stateRegistry The registry where shared states are registered. - */ - void unregisterSharedStates(SharedStateRegistry stateRegistry); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java index 706e219a028d0..770b5a91c11ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java @@ -28,18 +28,24 @@ /** * The handle to states of an incremental snapshot. *

- * The states contained in an incremental snapshot include + * The states contained in an incremental snapshot include: *

    - *
  • Created shared state which includes (the supposed to be) shared files produced since the last + *
  • Created shared state which includes shared files produced since the last * completed checkpoint. These files can be referenced by succeeding checkpoints if the * checkpoint succeeds to complete.
  • *
  • Referenced shared state which includes the shared files materialized in previous - * checkpoints.
  • + * checkpoints. Until we this is registered to a {@link SharedStateRegistry}, all referenced + * shared state handles are only placeholders, so that we do not send state handles twice + * from which we know that they already exist on the checkpoint coordinator. *
  • Private state which includes all other files, typically mutable, that cannot be shared by * other checkpoints.
  • *
  • Backend meta state which includes the information of existing states.
  • *
* + * When this should become a completed checkpoint on the checkpoint coordinator, it must first be + * registered with a {@link SharedStateRegistry}, so that all placeholder state handles to + * previously existing state are replaced with the originals. + * * IMPORTANT: This class currently overrides equals and hash code only for testing purposes. They * should not be called from production code. This means this class is also not suited to serve as * a key, e.g. in hash maps. @@ -66,14 +72,9 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle { private final long checkpointId; /** - * State that the incremental checkpoint created new - */ - private final Map createdSharedState; - - /** - * State that the incremental checkpoint references from previous checkpoints + * Shared state in the incremental checkpoint. This i */ - private final Map referencedSharedState; + private final Map sharedState; /** * Private state in the incremental checkpoint @@ -86,32 +87,30 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle { private final StreamStateHandle metaStateHandle; /** - * True if the state handle has already registered shared states. - *

- * Once the shared states are registered, it's the {@link SharedStateRegistry}'s - * responsibility to maintain the shared states. But in the cases where the - * state handle is discarded before performing the registration, the handle - * should delete all the shared states created by it. + * Once the shared states are registered, it is the {@link SharedStateRegistry}'s + * responsibility to cleanup those shared states. + * But in the cases where the state handle is discarded before performing the registration, + * the handle should delete all the shared states created by it. + * + * This variable is not null iff the handles was registered. */ - private boolean registered; + private transient SharedStateRegistry sharedStateRegistry; public IncrementalKeyedStateHandle( String operatorIdentifier, KeyGroupRange keyGroupRange, long checkpointId, - Map createdSharedState, - Map referencedSharedState, + Map sharedState, Map privateState, StreamStateHandle metaStateHandle) { this.operatorIdentifier = Preconditions.checkNotNull(operatorIdentifier); this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange); this.checkpointId = checkpointId; - this.createdSharedState = Preconditions.checkNotNull(createdSharedState); - this.referencedSharedState = Preconditions.checkNotNull(referencedSharedState); + this.sharedState = Preconditions.checkNotNull(sharedState); this.privateState = Preconditions.checkNotNull(privateState); this.metaStateHandle = Preconditions.checkNotNull(metaStateHandle); - this.registered = false; + this.sharedStateRegistry = null; } @Override @@ -123,12 +122,8 @@ public long getCheckpointId() { return checkpointId; } - public Map getCreatedSharedState() { - return createdSharedState; - } - - public Map getReferencedSharedState() { - return referencedSharedState; + public Map getSharedState() { + return sharedState; } public Map getPrivateState() { @@ -155,8 +150,6 @@ public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) { @Override public void discardState() throws Exception { - Preconditions.checkState(!registered, "Attempt to dispose a registered composite state with registered shared state. Must unregister first."); - try { metaStateHandle.discardState(); } catch (Exception e) { @@ -169,37 +162,35 @@ public void discardState() throws Exception { LOG.warn("Could not properly discard misc file states.", e); } - try { - StateUtil.bestEffortDiscardAllStateObjects(createdSharedState.values()); - } catch (Exception e) { - LOG.warn("Could not properly discard new sst file states.", e); + // If this was not registered, we can delete the shared state. We can simply apply this + // to all handles, because all handles that have not been created for the first time for this + // are only placeholders at this point (disposing them is a NOP). + if (sharedStateRegistry == null) { + try { + StateUtil.bestEffortDiscardAllStateObjects(sharedState.values()); + } catch (Exception e) { + LOG.warn("Could not properly discard new sst file states.", e); + } + } else { + // If this was registered, we only unregister all our referenced shared states + // from the registry. + for (StateHandleID stateHandleID : sharedState.keySet()) { + sharedStateRegistry.unregisterReference( + createSharedStateRegistryKeyFromFileName(stateHandleID)); + } } - } @Override public long getStateSize() { - long size = getPrivateStateSize(); - - for (StreamStateHandle oldSstFileHandle : referencedSharedState.values()) { - size += oldSstFileHandle.getStateSize(); - } - - return size; - } - - /** - * Returns the size of the state that is privately owned by this handle. - */ - public long getPrivateStateSize() { long size = StateUtil.getStateSize(metaStateHandle); - for (StreamStateHandle newSstFileHandle : createdSharedState.values()) { - size += newSstFileHandle.getStateSize(); + for (StreamStateHandle sharedStateHandle : sharedState.values()) { + size += sharedStateHandle.getStateSize(); } - for (StreamStateHandle miscFileHandle : privateState.values()) { - size += miscFileHandle.getStateSize(); + for (StreamStateHandle privateStateHandle : privateState.values()) { + size += privateStateHandle.getStateSize(); } return size; @@ -208,64 +199,38 @@ public long getPrivateStateSize() { @Override public void registerSharedStates(SharedStateRegistry stateRegistry) { - Preconditions.checkState(!registered, "The state handle has already registered its shared states."); + Preconditions.checkState(sharedStateRegistry == null, "The state handle has already registered its shared states."); + + sharedStateRegistry = Preconditions.checkNotNull(stateRegistry); - for (Map.Entry newSstFileEntry : createdSharedState.entrySet()) { + for (Map.Entry sharedStateHandle : sharedState.entrySet()) { SharedStateRegistryKey registryKey = - createSharedStateRegistryKeyFromFileName(newSstFileEntry.getKey()); + createSharedStateRegistryKeyFromFileName(sharedStateHandle.getKey()); SharedStateRegistry.Result result = - stateRegistry.registerNewReference(registryKey, newSstFileEntry.getValue()); - - // We update our reference with the result from the registry, to prevent the following - // problem: + stateRegistry.registerReference(registryKey, sharedStateHandle.getValue()); + + // This step consolidates our shared handles with the registry, which does two things: + // + // 1) Replace placeholder state handle with already registered, actual state handles. + // + // 2) Deduplicate re-uploads of incremental state due to missing confirmations about + // completed checkpoints. + // + // This prevents the following problem: // A previous checkpoint n has already registered the state. This can happen if a // following checkpoint (n + x) wants to reference the same state before the backend got // notified that checkpoint n completed. In this case, the shared registry did // deduplication and returns the previous reference. - newSstFileEntry.setValue(result.getReference()); - } - - for (Map.Entry oldSstFileName : referencedSharedState.entrySet()) { - SharedStateRegistryKey registryKey = - createSharedStateRegistryKeyFromFileName(oldSstFileName.getKey()); - - SharedStateRegistry.Result result = stateRegistry.obtainReference(registryKey); - - // Again we update our state handle with the result from the registry, thus replacing - // placeholder state handles with the originals. - oldSstFileName.setValue(result.getReference()); - } - - // Migrate state from unregistered to registered, so that it will not count as private state - // for #discardState() from now. - referencedSharedState.putAll(createdSharedState); - createdSharedState.clear(); - - registered = true; - } - - @Override - public void unregisterSharedStates(SharedStateRegistry stateRegistry) { - - Preconditions.checkState(registered, "The state handle has not registered its shared states yet."); - - for (Map.Entry newSstFileEntry : createdSharedState.entrySet()) { - SharedStateRegistryKey registryKey = - createSharedStateRegistryKeyFromFileName(newSstFileEntry.getKey()); - stateRegistry.releaseReference(registryKey); - } - - for (Map.Entry oldSstFileEntry : referencedSharedState.entrySet()) { - SharedStateRegistryKey registryKey = - createSharedStateRegistryKeyFromFileName(oldSstFileEntry.getKey()); - stateRegistry.releaseReference(registryKey); + sharedStateHandle.setValue(result.getReference()); } - - registered = false; } - private SharedStateRegistryKey createSharedStateRegistryKeyFromFileName(StateHandleID shId) { + /** + * Create a unique key to register one of our shared state handles. + */ + @VisibleForTesting + public SharedStateRegistryKey createSharedStateRegistryKeyFromFileName(StateHandleID shId) { return new SharedStateRegistryKey(operatorIdentifier + '-' + keyGroupRange, shId); } @@ -293,10 +258,7 @@ public boolean equals(Object o) { if (!getKeyGroupRange().equals(that.getKeyGroupRange())) { return false; } - if (!getCreatedSharedState().equals(that.getCreatedSharedState())) { - return false; - } - if (!getReferencedSharedState().equals(that.getReferencedSharedState())) { + if (!getSharedState().equals(that.getSharedState())) { return false; } if (!getPrivateState().equals(that.getPrivateState())) { @@ -314,8 +276,7 @@ public int hashCode() { int result = getOperatorIdentifier().hashCode(); result = 31 * result + getKeyGroupRange().hashCode(); result = 31 * result + (int) (getCheckpointId() ^ (getCheckpointId() >>> 32)); - result = 31 * result + getCreatedSharedState().hashCode(); - result = 31 * result + getReferencedSharedState().hashCode(); + result = 31 * result + getSharedState().hashCode(); result = 31 * result + getPrivateState().hashCode(); result = 31 * result + getMetaStateHandle().hashCode(); return result; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java index 82804601f042a..8e38ad4750d86 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java @@ -97,11 +97,6 @@ public void registerSharedStates(SharedStateRegistry stateRegistry) { // No shared states } - @Override - public void unregisterSharedStates(SharedStateRegistry stateRegistry) { - // No shared states - } - @Override public void discardState() throws Exception { stateHandle.discardState(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PlaceholderStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PlaceholderStreamStateHandle.java index 2136061a4e2fd..7c948a169cb3b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PlaceholderStreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PlaceholderStreamStateHandle.java @@ -18,29 +18,20 @@ package org.apache.flink.runtime.state; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; /** * A placeholder state handle for shared state that will replaced by an original that was - * created in a previous checkpoint. So we don't have to send the handle twice, e.g. in - * case of {@link ByteStreamStateHandle}. To be used in the referenced states of + * created in a previous checkpoint. So we don't have to send a state handle twice, e.g. in + * case of {@link ByteStreamStateHandle}. This class is used in the referenced states of * {@link IncrementalKeyedStateHandle}. - *

- * IMPORTANT: This class currently overrides equals and hash code only for testing purposes. They - * should not be called from production code. This means this class is also not suited to serve as - * a key, e.g. in hash maps. */ public class PlaceholderStreamStateHandle implements StreamStateHandle { private static final long serialVersionUID = 1L; - /** We remember the size of the original file for which this is a placeholder */ - private final long originalSize; - - public PlaceholderStreamStateHandle(long originalSize) { - this.originalSize = originalSize; + public PlaceholderStreamStateHandle() { } @Override @@ -56,33 +47,6 @@ public void discardState() throws Exception { @Override public long getStateSize() { - return originalSize; - } - - /** - * This method is should only be called in tests! This should never serve as key in a hash map. - */ - @VisibleForTesting - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - PlaceholderStreamStateHandle that = (PlaceholderStreamStateHandle) o; - - return originalSize == that.originalSize; - } - - /** - * This method is should only be called in tests! This should never serve as key in a hash map. - */ - @VisibleForTesting - @Override - public int hashCode() { - return (int) (originalSize ^ (originalSize >>> 32)); + return 0L; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java index f9161b0924960..a5e0f841b6cb5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java @@ -54,7 +54,7 @@ public SharedStateRegistry(Executor asyncDisposalExecutor) { } /** - * Register a reference to the given (supposedly new) shared state in the registry. + * Register a reference to the given shared state in the registry. * This does the following: We check if the state handle is actually new by the * registrationKey. If it is new, we register it with a reference count of 1. If there is * already a state handle registered under the given key, we dispose the given "new" state @@ -62,14 +62,14 @@ public SharedStateRegistry(Executor asyncDisposalExecutor) { * a replacement with the result. * *

IMPORTANT: caller should check the state handle returned by the result, because the - * registry is performing deduplication and could potentially return a handle that is supposed + * registry is performing de-duplication and could potentially return a handle that is supposed * to replace the one from the registration request. * * @param state the shared state for which we register a reference. * @return the result of this registration request, consisting of the state handle that is * registered under the key by the end of the oepration and its current reference count. */ - public Result registerNewReference(SharedStateRegistryKey registrationKey, StreamStateHandle state) { + public Result registerReference(SharedStateRegistryKey registrationKey, StreamStateHandle state) { Preconditions.checkNotNull(state); @@ -95,28 +95,6 @@ public Result registerNewReference(SharedStateRegistryKey registrationKey, Strea return new Result(entry); } - /** - * Obtains one reference to the given shared state in the registry. This increases the - * reference count by one. - * - * @param registrationKey the shared state for which we obtain a reference. - * @return the shared state for which we release a reference. - * @return the result of the request, consisting of the reference count after this operation - * and the state handle. - */ - public Result obtainReference(SharedStateRegistryKey registrationKey) { - - Preconditions.checkNotNull(registrationKey); - - synchronized (registeredStates) { - SharedStateRegistry.SharedStateEntry entry = - Preconditions.checkNotNull(registeredStates.get(registrationKey), - "Could not find a state for the given registration key!"); - entry.increaseReferenceCount(); - return new Result(entry); - } - } - /** * Releases one reference to the given shared state in the registry. This decreases the * reference count by one. Once the count reaches zero, the shared state is deleted. @@ -125,7 +103,7 @@ public Result obtainReference(SharedStateRegistryKey registrationKey) { * @return the result of the request, consisting of the reference count after this operation * and the state handle, or null if the state handle was deleted through this request. */ - public Result releaseReference(SharedStateRegistryKey registrationKey) { + public Result unregisterReference(SharedStateRegistryKey registrationKey) { Preconditions.checkNotNull(registrationKey); @@ -172,30 +150,18 @@ public void registerAll(Iterable stateHandles) { } } - /** - * Unregister all the shared states referenced by the given. - * - * @param stateHandles The shared states to unregister. - */ - public void unregisterAll(Iterable stateHandles) { - if (stateHandles == null) { - return; - } - - synchronized (registeredStates) { - for (CompositeStateHandle stateHandle : stateHandles) { - stateHandle.unregisterSharedStates(this); - } - } - } - private void scheduleAsyncDelete(StreamStateHandle streamStateHandle) { - if (streamStateHandle != null) { + // We do the small optimization to not issue discards for placeholders, which are NOPs. + if (streamStateHandle != null && !isPlaceholder(streamStateHandle)) { asyncDisposalExecutor.execute( new SharedStateRegistry.AsyncDisposalRunnable(streamStateHandle)); } } + private boolean isPlaceholder(StreamStateHandle stateHandle) { + return stateHandle instanceof PlaceholderStreamStateHandle; + } + /** * An entry in the registry, tracking the handle and the corresponding reference count. */ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java index 42703f8662722..9ba9d35ff8939 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java @@ -91,6 +91,13 @@ public int hashCode() { return 31 * handleName.hashCode(); } + @Override + public String toString() { + return "ByteStreamStateHandle{" + + "handleName='" + handleName + '\'' + + '}'; + } + /** * An input stream view on a byte array. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 9250634834f8c..3b44d9a72cfd7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -639,12 +639,6 @@ public void testTriggerAndConfirmSimpleCheckpoint() { assertEquals(checkpointIdNew, successNew.getCheckpointID()); assertTrue(successNew.getOperatorStates().isEmpty()); - // validate that the subtask states in old savepoint have unregister their shared states - { - verify(subtaskState1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - } - // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); @@ -925,9 +919,6 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { verify(subtaskState1_2, times(1)).discardState(); // validate that all subtask states in the second checkpoint are not discarded - verify(subtaskState2_1, never()).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2_2, never()).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2_3, never()).unregisterSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2_1, never()).discardState(); verify(subtaskState2_2, never()).discardState(); verify(subtaskState2_3, never()).discardState(); @@ -951,9 +942,6 @@ public void testSuccessfulCheckpointSubsumesUnsuccessful() { coord.shutdown(JobStatus.FINISHED); // validate that the states in the second checkpoint have been discarded - verify(subtaskState2_1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2_2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2_3, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); verify(subtaskState2_1, times(1)).discardState(); verify(subtaskState2_2, times(1)).discardState(); verify(subtaskState2_3, times(1)).discardState(); @@ -1562,10 +1550,6 @@ public void testTriggerAndConfirmSimpleSavepoint() throws Exception { verify(subtaskState1, never()).discardState(); verify(subtaskState2, never()).discardState(); - // Savepoints are not supposed to have any shared state. - verify(subtaskState1, never()).unregisterSharedStates(any(SharedStateRegistry.class)); - verify(subtaskState2, never()).unregisterSharedStates(any(SharedStateRegistry.class)); - // validate that the relevant tasks got a confirmation message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class)); @@ -2088,15 +2072,6 @@ public void testRestoreLatestCheckpointedState() throws Exception { // shutdown the store store.shutdown(JobStatus.SUSPENDED); - // All shared states should be unregistered once the store is shut down - for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { - for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { - for (OperatorSubtaskState subtaskState : taskState.getStates()) { - verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - } - } - } - // restore the store Map tasks = new HashMap<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 985c662874464..fb5d7c3c2ff3e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -23,8 +23,8 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.util.TestLogger; +import org.junit.Assert; import org.junit.Test; -import org.mockito.Mockito; import java.io.IOException; import java.util.Collection; @@ -37,11 +37,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; /** * Test for basic {@link CompletedCheckpointStore} contract. @@ -114,12 +109,6 @@ public void testAddCheckpointMoreThanMaxRetained() throws Exception { expected[i - 1].awaitDiscard(); assertTrue(expected[i - 1].isDiscarded()); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); - - for (OperatorState operatorState : taskStates) { - for (OperatorSubtaskState subtaskState : operatorState.getStates()) { - verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); - } - } } } @@ -209,7 +198,8 @@ protected TestCompletedCheckpoint createCheckpoint(int id) throws IOException { operatorGroupState.put(operatorID, operatorState); for (int i = 0; i < numberOfStates; i++) { - OperatorSubtaskState subtaskState = mock(OperatorSubtaskState.class); + OperatorSubtaskState subtaskState = + new TestOperatorSubtaskState(); operatorState.putState(i, subtaskState); } @@ -217,18 +207,10 @@ protected TestCompletedCheckpoint createCheckpoint(int id) throws IOException { return new TestCompletedCheckpoint(new JobID(), id, 0, operatorGroupState, props); } - protected void resetCheckpoint(Collection operatorStates) { - for (OperatorState operatorState : operatorStates) { - for (OperatorSubtaskState subtaskState : operatorState.getStates()) { - Mockito.reset(subtaskState); - } - } - } - protected void verifyCheckpointRegistered(Collection operatorStates, SharedStateRegistry registry) { for (OperatorState operatorState : operatorStates) { for (OperatorSubtaskState subtaskState : operatorState.getStates()) { - verify(subtaskState, times(1)).registerSharedStates(eq(registry)); + Assert.assertTrue(((TestOperatorSubtaskState)subtaskState).registered); } } } @@ -236,7 +218,7 @@ protected void verifyCheckpointRegistered(Collection operatorStat protected void verifyCheckpointDiscarded(Collection operatorStates) { for (OperatorState operatorState : operatorStates) { for (OperatorSubtaskState subtaskState : operatorState.getStates()) { - verify(subtaskState, times(1)).discardState(); + Assert.assertTrue(((TestOperatorSubtaskState)subtaskState).discarded); } } } @@ -333,4 +315,37 @@ public int hashCode() { } } + static class TestOperatorSubtaskState extends OperatorSubtaskState { + private static final long serialVersionUID = 522580433699164230L; + + boolean registered; + boolean discarded; + + public TestOperatorSubtaskState() { + super(null, null, null, null, null); + this.registered = false; + this.discarded = false; + } + + @Override + public void discardState() { + super.discardState(); + Assert.assertFalse(discarded); + discarded = true; + registered = false; + } + + @Override + public void registerSharedStates(SharedStateRegistry sharedStateRegistry) { + super.registerSharedStates(sharedStateRegistry); + Assert.assertFalse(discarded); + registered = true; + } + + public void reset() { + registered = false; + discarded = false; + } + } + } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index 589ff46863a35..0bbb961e037b1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -100,7 +100,6 @@ public void testCleanUpOnSubsume() throws Exception { checkpoint.discardOnSubsume(sharedStateRegistry); verify(state, times(1)).discardState(); - verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } /** @@ -138,7 +137,6 @@ public void testCleanUpOnShutdown() throws Exception { checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(0)).discardState(); assertEquals(true, file.exists()); - verify(state, times(0)).unregisterSharedStates(sharedStateRegistry); // Discard props = new CheckpointProperties(false, false, true, true, true, true, true); @@ -152,7 +150,6 @@ public void testCleanUpOnShutdown() throws Exception { checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); - verify(state, times(1)).unregisterSharedStates(sharedStateRegistry); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index 6df01a02b32ca..a96b5979a4159 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -197,7 +197,6 @@ public void testAbortDiscardsState() throws Exception { OperatorState state = mock(OperatorState.class); doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class)); - doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class)); String targetDir = tmpFolder.newFolder().getAbsolutePath(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 0d932892621e7..44c802bddad8a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -34,6 +34,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; @@ -100,11 +101,7 @@ public void testRecover() throws Exception { assertEquals(3, ZOOKEEPER.getClient().getChildren().forPath(CHECKPOINT_PATH).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); - resetCheckpoint(expected[0].getOperatorStates().values()); - resetCheckpoint(expected[1].getOperatorStates().values()); - resetCheckpoint(expected[2].getOperatorStates().values()); - - // Recover TODO!!! clear registry! + // Recover checkpoints.recover(); assertEquals(3, ZOOKEEPER.getClient().getChildren().forPath(CHECKPOINT_PATH).size()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index b63782d4b41c8..f985573b18326 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle.StateMetaInfo; -import org.apache.flink.runtime.state.PlaceholderStreamStateHandle; import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; @@ -273,18 +272,17 @@ public static void assertMasterStateEquality(MasterState a, MasterState b) { private CheckpointTestUtils() {} - private static IncrementalKeyedStateHandle createDummyIncrementalKeyedStateHandle(Random rnd) { + public static IncrementalKeyedStateHandle createDummyIncrementalKeyedStateHandle(Random rnd) { return new IncrementalKeyedStateHandle( createRandomUUID(rnd).toString(), new KeyGroupRange(1, 1), 42L, - createRandomOwnedHandleMap(rnd), - createRandomReferencedHandleMap(rnd), - createRandomOwnedHandleMap(rnd), + createRandomStateHandleMap(rnd), + createRandomStateHandleMap(rnd), createDummyStreamStateHandle(rnd)); } - private static Map createRandomOwnedHandleMap(Random rnd) { + public static Map createRandomStateHandleMap(Random rnd) { final int size = rnd.nextInt(4); Map result = new HashMap<>(size); for (int i = 0; i < size; ++i) { @@ -296,24 +294,13 @@ private static Map createRandomOwnedHandleMap( return result; } - private static Map createRandomReferencedHandleMap(Random rnd) { - final int size = rnd.nextInt(4); - Map result = new HashMap<>(size); - for (int i = 0; i < size; ++i) { - StateHandleID randomId = new StateHandleID(createRandomUUID(rnd).toString()); - result.put(randomId, new PlaceholderStreamStateHandle(rnd.nextInt(1024))); - } - - return result; - } - - private static KeyGroupsStateHandle createDummyKeyGroupStateHandle(Random rnd) { + public static KeyGroupsStateHandle createDummyKeyGroupStateHandle(Random rnd) { return new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{rnd.nextInt(1024)}), createDummyStreamStateHandle(rnd)); } - private static StreamStateHandle createDummyStreamStateHandle(Random rnd) { + public static StreamStateHandle createDummyStreamStateHandle(Random rnd) { return new TestByteStreamStateHandleDeepCompare( String.valueOf(createRandomUUID(rnd)), String.valueOf(createRandomUUID(rnd)).getBytes(ConfigConstants.DEFAULT_CHARSET)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java new file mode 100644 index 0000000000000..2a6975aa3d307 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.runtime.checkpoint.savepoint.CheckpointTestUtils; +import org.junit.Test; + +import java.util.Map; +import java.util.Random; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.spy; + +public class IncrementalKeyedStateHandleTest { + + /** + * This test checks, that for an unregistered {@link IncrementalKeyedStateHandle} all state + * (including shared) is discarded. + */ + @Test + public void testUnregisteredDiscarding() throws Exception { + IncrementalKeyedStateHandle stateHandle = create(new Random(42)); + + stateHandle.discardState(); + + for (StreamStateHandle handle : stateHandle.getPrivateState().values()) { + verify(handle).discardState(); + } + + for (StreamStateHandle handle : stateHandle.getSharedState().values()) { + verify(handle).discardState(); + } + + verify(stateHandle.getMetaStateHandle()).discardState(); + } + + /** + * This test checks, that for a registered {@link IncrementalKeyedStateHandle} discards respect + * all shared state and only discard it one all references are released. + */ + @Test + public void testSharedStateDeRegistration() throws Exception { + + Random rnd = new Random(42); + + SharedStateRegistry registry = spy(new SharedStateRegistry()); + + // Create two state handles with overlapping shared state + IncrementalKeyedStateHandle stateHandle1 = create(new Random(42)); + IncrementalKeyedStateHandle stateHandle2 = create(new Random(42)); + + // Both handles should not be registered and not discarded by now. + for (Map.Entry entry : + stateHandle1.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(entry.getKey()); + + verify(registry, times(0)).unregisterReference(registryKey); + verify(entry.getValue(), times(0)).discardState(); + } + + for (Map.Entry entry : + stateHandle2.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(entry.getKey()); + + verify(registry, times(0)).unregisterReference(registryKey); + verify(entry.getValue(), times(0)).discardState(); + } + + // Now we register both ... + stateHandle1.registerSharedStates(registry); + stateHandle2.registerSharedStates(registry); + + for (Map.Entry stateHandleEntry : + stateHandle1.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(stateHandleEntry.getKey()); + + verify(registry).registerReference( + registryKey, + stateHandleEntry.getValue()); + } + + for (Map.Entry stateHandleEntry : + stateHandle2.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(stateHandleEntry.getKey()); + + verify(registry).registerReference( + registryKey, + stateHandleEntry.getValue()); + } + + // We discard the first + stateHandle1.discardState(); + + // Should be unregistered, non-shared discarded, shared not discarded + for (Map.Entry entry : + stateHandle1.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(entry.getKey()); + + verify(registry, times(1)).unregisterReference(registryKey); + verify(entry.getValue(), times(0)).discardState(); + } + + for (StreamStateHandle handle : + stateHandle2.getSharedState().values()) { + + verify(handle, times(0)).discardState(); + } + + for (Map.Entry handleEntry : + stateHandle1.getPrivateState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(handleEntry.getKey()); + + verify(registry, times(0)).unregisterReference(registryKey); + verify(handleEntry.getValue(), times(1)).discardState(); + } + + for (Map.Entry handleEntry : + stateHandle2.getPrivateState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(handleEntry.getKey()); + + verify(registry, times(0)).unregisterReference(registryKey); + verify(handleEntry.getValue(), times(0)).discardState(); + } + + verify(stateHandle1.getMetaStateHandle(), times(1)).discardState(); + verify(stateHandle2.getMetaStateHandle(), times(0)).discardState(); + + // We discard the second + stateHandle2.discardState(); + + + // Now everything should be unregistered and discarded + for (Map.Entry entry : + stateHandle1.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(entry.getKey()); + + verify(registry, times(2)).unregisterReference(registryKey); + verify(entry.getValue()).discardState(); + } + + for (Map.Entry entry : + stateHandle2.getSharedState().entrySet()) { + + SharedStateRegistryKey registryKey = + stateHandle1.createSharedStateRegistryKeyFromFileName(entry.getKey()); + + verify(registry, times(2)).unregisterReference(registryKey); + verify(entry.getValue()).discardState(); + } + + verify(stateHandle1.getMetaStateHandle(), times(1)).discardState(); + verify(stateHandle2.getMetaStateHandle(), times(1)).discardState(); + } + + private static IncrementalKeyedStateHandle create(Random rnd) { + return new IncrementalKeyedStateHandle( + "test", + KeyGroupRange.of(0, 0), + 1L, + placeSpies(CheckpointTestUtils.createRandomStateHandleMap(rnd)), + placeSpies(CheckpointTestUtils.createRandomStateHandleMap(rnd)), + spy(CheckpointTestUtils.createDummyStreamStateHandle(rnd))); + } + + private static Map placeSpies( + Map map) { + + for (Map.Entry entry : map.entrySet()) { + entry.setValue(spy(entry.getValue())); + } + return map; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java index 03e2a1392be84..4104595910693 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java @@ -40,14 +40,14 @@ public void testRegistryNormal() { // register one state TestSharedState firstState = new TestSharedState("first"); - SharedStateRegistry.Result result = sharedStateRegistry.registerNewReference(firstState.getRegistrationKey(), firstState); + SharedStateRegistry.Result result = sharedStateRegistry.registerReference(firstState.getRegistrationKey(), firstState); assertEquals(1, result.getReferenceCount()); assertTrue(firstState == result.getReference()); assertFalse(firstState.isDiscarded()); // register another state TestSharedState secondState = new TestSharedState("second"); - result = sharedStateRegistry.registerNewReference(secondState.getRegistrationKey(), secondState); + result = sharedStateRegistry.registerReference(secondState.getRegistrationKey(), secondState); assertEquals(1, result.getReferenceCount()); assertTrue(secondState == result.getReference()); assertFalse(firstState.isDiscarded()); @@ -55,7 +55,7 @@ public void testRegistryNormal() { // attempt to register state under an existing key TestSharedState firstStatePrime = new TestSharedState(firstState.getRegistrationKey().getKeyString()); - result = sharedStateRegistry.registerNewReference(firstState.getRegistrationKey(), firstStatePrime); + result = sharedStateRegistry.registerReference(firstState.getRegistrationKey(), firstStatePrime); assertEquals(2, result.getReferenceCount()); assertFalse(firstStatePrime == result.getReference()); assertTrue(firstState == result.getReference()); @@ -63,19 +63,19 @@ public void testRegistryNormal() { assertFalse(firstState.isDiscarded()); // reference the first state again - result = sharedStateRegistry.obtainReference(firstState.getRegistrationKey()); + result = sharedStateRegistry.registerReference(firstState.getRegistrationKey(), firstState); assertEquals(3, result.getReferenceCount()); assertTrue(firstState == result.getReference()); assertFalse(firstState.isDiscarded()); // unregister the second state - result = sharedStateRegistry.releaseReference(secondState.getRegistrationKey()); + result = sharedStateRegistry.unregisterReference(secondState.getRegistrationKey()); assertEquals(0, result.getReferenceCount()); assertTrue(result.getReference() == null); assertTrue(secondState.isDiscarded()); // unregister the first state - result = sharedStateRegistry.releaseReference(firstState.getRegistrationKey()); + result = sharedStateRegistry.unregisterReference(firstState.getRegistrationKey()); assertEquals(2, result.getReferenceCount()); assertTrue(firstState == result.getReference()); assertFalse(firstState.isDiscarded()); @@ -87,7 +87,7 @@ public void testRegistryNormal() { @Test(expected = IllegalStateException.class) public void testUnregisterWithUnexistedKey() { SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); - sharedStateRegistry.releaseReference(new SharedStateRegistryKey("non-existent")); + sharedStateRegistry.unregisterReference(new SharedStateRegistryKey("non-existent")); } private static class TestSharedState implements StreamStateHandle { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index b1927f18ba393..8d4a38ee3a833 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -539,7 +539,6 @@ public void testKryoRegisteringRestoreResilienceWithDefaultSerializer() throws E snapshot2.registerSharedStates(sharedStateRegistry); - snapshot.unregisterSharedStates(sharedStateRegistry); snapshot.discardState(); backend.dispose(); @@ -631,7 +630,6 @@ public void testKryoRegisteringRestoreResilienceWithRegisteredSerializer() throw snapshot2.registerSharedStates(sharedStateRegistry); - snapshot.unregisterSharedStates(sharedStateRegistry); snapshot.discardState(); backend.dispose(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java index 11a03cceb4b32..2251e46af7c27 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ -52,10 +52,12 @@ public void recover() throws Exception { @Override public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception { - checkpoints.addLast(checkpoint); checkpoint.registerSharedStates(sharedStateRegistry); + checkpoints.addLast(checkpoint); + + if (checkpoints.size() > 1) { CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst(); checkpointToSubsume.discardOnSubsume(sharedStateRegistry); @@ -76,7 +78,6 @@ public void shutdown(JobStatus jobStatus) throws Exception { suspended.clear(); for (CompletedCheckpoint checkpoint : checkpoints) { - sharedStateRegistry.unregisterAll(checkpoint.getOperatorStates().values()); suspended.add(checkpoint); } diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java index fea2b794fd518..6ad770887ab18 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java @@ -29,6 +29,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; +import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.CheckpointListener; @@ -147,8 +148,14 @@ public void initStateBackend() throws IOException { } case ROCKSDB_INCREMENTAL: { String rocksDb = tempFolder.newFolder().getAbsolutePath(); + String backups = tempFolder.newFolder().getAbsolutePath(); + // we use the fs backend with small threshold here to test the behaviour with file + // references, not self contained byte handles RocksDBStateBackend rdb = - new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), true); + new RocksDBStateBackend( + new FsStateBackend( + new Path("file://" + backups).toUri(), 16), + true); rdb.setDbStoragePath(rocksDb); this.stateBackend = rdb; break; diff --git a/flink-tests/src/test/java/org/apache/flink/test/recovery/JobManagerHACheckpointRecoveryITCase.java b/flink-tests/src/test/java/org/apache/flink/test/recovery/JobManagerHACheckpointRecoveryITCase.java index 6c70b87e715ef..f9af603cfd7c3 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/recovery/JobManagerHACheckpointRecoveryITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/recovery/JobManagerHACheckpointRecoveryITCase.java @@ -20,9 +20,20 @@ import akka.actor.ActorRef; import akka.actor.ActorSystem; +import akka.actor.PoisonPill; import org.apache.commons.io.FileUtils; +import org.apache.curator.test.TestingServer; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.configuration.HighAvailabilityOptions; +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.akka.ListeningBehaviour; import org.apache.flink.runtime.clusterframework.types.ResourceID; @@ -35,6 +46,12 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.leaderelection.TestingListener; import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService; +import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.runtime.taskmanager.TaskManager; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testtasks.BlockingNoOpInvokable; @@ -43,18 +60,20 @@ import org.apache.flink.runtime.testutils.JobManagerProcess; import org.apache.flink.runtime.testutils.ZooKeeperTestUtils; import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment; -import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.checkpoint.ListCheckpointed; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.testutils.junit.RetryOnFailure; import org.apache.flink.testutils.junit.RetryRule; +import org.apache.flink.util.Collector; import org.apache.flink.util.TestLogger; import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Option; @@ -127,7 +146,7 @@ public void cleanUp() throws Exception { private static final int Parallelism = 8; - private static CountDownLatch CompletedCheckpointsLatch = new CountDownLatch(2); + private static CountDownLatch CompletedCheckpointsLatch = new CountDownLatch(4); private static AtomicLongArray RecoveredStates = new AtomicLongArray(Parallelism); @@ -137,182 +156,7 @@ public void cleanUp() throws Exception { private static long LastElement = -1; - /** - * Simple checkpointed streaming sum. - * - *

The sources (Parallelism) count until sequenceEnd. The sink (1) sums up all counts and - * returns it to the main thread via a static variable. We wait until some checkpoints are - * completed and sanity check that the sources recover with an updated state to make sure that - * this test actually tests something. - */ - @Test - @RetryOnFailure(times=1) - public void testCheckpointedStreamingSumProgram() throws Exception { - // Config - final int checkpointingInterval = 200; - final int sequenceEnd = 5000; - final long expectedSum = Parallelism * sequenceEnd * (sequenceEnd + 1) / 2; - - final StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); - env.setParallelism(Parallelism); - env.enableCheckpointing(checkpointingInterval); - - env - .addSource(new CheckpointedSequenceSource(sequenceEnd)) - .addSink(new CountingSink()) - .setParallelism(1); - - JobGraph jobGraph = env.getStreamGraph().getJobGraph(); - - Configuration config = ZooKeeperTestUtils.createZooKeeperHAConfig(ZooKeeper - .getConnectString(), FileStateBackendBasePath.getAbsoluteFile().toURI().toString()); - config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, Parallelism); - - ActorSystem testSystem = null; - final JobManagerProcess[] jobManagerProcess = new JobManagerProcess[2]; - LeaderRetrievalService leaderRetrievalService = null; - ActorSystem taskManagerSystem = null; - final HighAvailabilityServices highAvailabilityServices = HighAvailabilityServicesUtils.createHighAvailabilityServices( - config, - TestingUtils.defaultExecutor(), - HighAvailabilityServicesUtils.AddressResolution.NO_ADDRESS_RESOLUTION); - - try { - final Deadline deadline = TestTimeOut.fromNow(); - - // Test actor system - testSystem = AkkaUtils.createActorSystem(new Configuration(), - new Some<>(new Tuple2("localhost", 0))); - - // The job managers - jobManagerProcess[0] = new JobManagerProcess(0, config); - jobManagerProcess[1] = new JobManagerProcess(1, config); - - jobManagerProcess[0].startProcess(); - jobManagerProcess[1].startProcess(); - - // Leader listener - TestingListener leaderListener = new TestingListener(); - leaderRetrievalService = highAvailabilityServices.getJobManagerLeaderRetriever(HighAvailabilityServices.DEFAULT_JOB_ID); - leaderRetrievalService.start(leaderListener); - - // The task manager - taskManagerSystem = AkkaUtils.createActorSystem( - config, Option.apply(new Tuple2("localhost", 0))); - TaskManager.startTaskManagerComponentsAndActor( - config, - ResourceID.generate(), - taskManagerSystem, - highAvailabilityServices, - "localhost", - Option.empty(), - false, - TaskManager.class); - - { - // Initial submission - leaderListener.waitForNewLeader(deadline.timeLeft().toMillis()); - - String leaderAddress = leaderListener.getAddress(); - UUID leaderId = leaderListener.getLeaderSessionID(); - - // Get the leader ref - ActorRef leaderRef = AkkaUtils.getActorRef( - leaderAddress, testSystem, deadline.timeLeft()); - ActorGateway leader = new AkkaActorGateway(leaderRef, leaderId); - - // Submit the job in detached mode - leader.tell(new SubmitJob(jobGraph, ListeningBehaviour.DETACHED)); - - JobManagerActorTestUtils.waitForJobStatus( - jobGraph.getJobID(), JobStatus.RUNNING, leader, deadline.timeLeft()); - } - - // Who's the boss? - JobManagerProcess leadingJobManagerProcess; - if (jobManagerProcess[0].getJobManagerAkkaURL(deadline.timeLeft()).equals(leaderListener.getAddress())) { - leadingJobManagerProcess = jobManagerProcess[0]; - } - else { - leadingJobManagerProcess = jobManagerProcess[1]; - } - - CompletedCheckpointsLatch.await(); - - // Kill the leading job manager process - leadingJobManagerProcess.destroy(); - - { - // Recovery by the standby JobManager - leaderListener.waitForNewLeader(deadline.timeLeft().toMillis()); - - String leaderAddress = leaderListener.getAddress(); - UUID leaderId = leaderListener.getLeaderSessionID(); - - ActorRef leaderRef = AkkaUtils.getActorRef( - leaderAddress, testSystem, deadline.timeLeft()); - ActorGateway leader = new AkkaActorGateway(leaderRef, leaderId); - - JobManagerActorTestUtils.waitForJobStatus(jobGraph.getJobID(), JobStatus.RUNNING, - leader, deadline.timeLeft()); - } - - // Wait to finish - FinalCountLatch.await(); - - assertEquals(expectedSum, (long) FinalCount.get()); - - for (int i = 0; i < Parallelism; i++) { - assertNotEquals(0, RecoveredStates.get(i)); - } - } - catch (Throwable t) { - // Reset all static state for test retries - CompletedCheckpointsLatch = new CountDownLatch(2); - RecoveredStates = new AtomicLongArray(Parallelism); - FinalCountLatch = new CountDownLatch(1); - FinalCount = new AtomicReference<>(); - LastElement = -1; - - // Print early (in some situations the process logs get too big - // for Travis and the root problem is not shown) - t.printStackTrace(); - - // In case of an error, print the job manager process logs. - if (jobManagerProcess[0] != null) { - jobManagerProcess[0].printProcessLog(); - } - - if (jobManagerProcess[1] != null) { - jobManagerProcess[1].printProcessLog(); - } - - throw t; - } - finally { - if (jobManagerProcess[0] != null) { - jobManagerProcess[0].destroy(); - } - - if (jobManagerProcess[1] != null) { - jobManagerProcess[1].destroy(); - } - - if (leaderRetrievalService != null) { - leaderRetrievalService.stop(); - } - - if (taskManagerSystem != null) { - taskManagerSystem.shutdown(); - } - - if (testSystem != null) { - testSystem.shutdown(); - } - - highAvailabilityServices.closeAndCleanupAllData(); - } - } + private static final int retainedCheckpoints = 2; /** * Tests that the JobManager logs failures during recovery properly. @@ -480,13 +324,110 @@ public void testCheckpointRecoveryFailure() throws Exception { } } + @Test + public void testCheckpointedStreamingProgramIncrementalRocksDB() throws Exception { + testCheckpointedStreamingProgram( + new RocksDBStateBackend( + new FsStateBackend(FileStateBackendBasePath.getAbsoluteFile().toURI(), 16), + true)); + } + + private void testCheckpointedStreamingProgram(AbstractStateBackend stateBackend) throws Exception { + + // Config + final int checkpointingInterval = 100; + final int sequenceEnd = 5000; + final long expectedSum = Parallelism * sequenceEnd * (sequenceEnd + 1) / 2; + + final ActorSystem system = ActorSystem.create("Test", AkkaUtils.getDefaultAkkaConfig()); + final TestingServer testingServer = new TestingServer(); + final TemporaryFolder temporaryFolder = new TemporaryFolder(); + temporaryFolder.create(); + + LocalFlinkMiniCluster miniCluster = null; + + final int numJMs = 2; + final int numTMs = 4; + final int numSlots = 8; + + try { + Configuration config = new Configuration(); + config.setInteger(CoreOptions.MAX_RETAINED_CHECKPOINTS, retainedCheckpoints); + config.setInteger(ConfigConstants.LOCAL_NUMBER_JOB_MANAGER, numJMs); + config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTMs); + config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, numSlots); + + + String tmpFolderString = temporaryFolder.newFolder().toString(); + config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, tmpFolderString); + config.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, testingServer.getConnectString()); + config.setString(HighAvailabilityOptions.HA_MODE, "zookeeper"); + + miniCluster = new LocalFlinkMiniCluster(config, true); + + miniCluster.start(); + + ActorGateway jmGateway = miniCluster.getLeaderGateway(TestingUtils.TESTING_DURATION()); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(Parallelism); + env.enableCheckpointing(checkpointingInterval); + + //TODO parameterize + env.setStateBackend(stateBackend); + env + .addSource(new CheckpointedSequenceSource(sequenceEnd, 1)) + .keyBy(new KeySelector() { + + private static final long serialVersionUID = -8572892067702489025L; + + @Override + public Object getKey(Long value) throws Exception { + return value; + } + }) + .flatMap(new StatefulFlatMap()).setParallelism(1) + .addSink(new CountingSink()) + .setParallelism(1); + + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + miniCluster.submitJobDetached(jobGraph); + + CompletedCheckpointsLatch.await(); + + jmGateway.tell(PoisonPill.getInstance()); + + // Wait to finish + FinalCountLatch.await(); + + assertEquals(expectedSum, (long) FinalCount.get()); + + for (int i = 0; i < Parallelism; i++) { + assertNotEquals(0, RecoveredStates.get(i)); + } + + } finally { + if (miniCluster != null) { + miniCluster.stop(); + miniCluster.awaitTermination(); + } + + system.shutdown(); + system.awaitTermination(); + + testingServer.stop(); + testingServer.close(); + + } + } + // --------------------------------------------------------------------------------------------- /** * A checkpointed source, which emits elements from 0 to a configured number. */ public static class CheckpointedSequenceSource extends RichParallelSourceFunction - implements ListCheckpointed { + implements ListCheckpointed> { private static final Logger LOG = LoggerFactory.getLogger(CheckpointedSequenceSource.class); @@ -496,13 +437,22 @@ public static class CheckpointedSequenceSource extends RichParallelSourceFunctio private final long end; - private long current = 0; + private int repeat; + + private long current; private volatile boolean isRunning = true; public CheckpointedSequenceSource(long end) { + this(end, 1); + + } + + public CheckpointedSequenceSource(long end, int repeat) { checkArgument(end >= 0, "Negative final count"); + this.current = 0; this.end = end; + this.repeat = repeat; } @Override @@ -511,8 +461,10 @@ public void run(SourceContext ctx) throws Exception { synchronized (ctx.getCheckpointLock()) { if (current <= end) { ctx.collect(current++); - } - else { + } else if(repeat > 0) { + --repeat; + current = 0; + } else { ctx.collect(LastElement); return; } @@ -520,32 +472,33 @@ public void run(SourceContext ctx) throws Exception { // Slow down until some checkpoints are completed if (sync.getCount() != 0) { - Thread.sleep(100); + Thread.sleep(50); } } } @Override - public List snapshotState(long checkpointId, long timestamp) throws Exception { + public List> snapshotState(long checkpointId, long timestamp) throws Exception { LOG.debug("Snapshotting state {} @ ID {}.", current, checkpointId); - return Collections.singletonList(this.current); + return Collections.singletonList(new Tuple2<>(this.current, this.repeat)); } @Override - public void restoreState(List state) throws Exception { - if (state.isEmpty() || state.size() > 1) { - throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size()); + public void restoreState(List> list) throws Exception { + if (list.isEmpty() || list.size() > 1) { + throw new RuntimeException("Test failed due to unexpected recovered state size " + list.size()); } - Long s = state.get(0); - LOG.debug("Restoring state {}", s); + Tuple2 state = list.get(0); + LOG.debug("Restoring state {}", state); // This is necessary to make sure that something is recovered at all. Otherwise it // might happen that the job is restarted from the beginning. - RecoveredStates.set(getRuntimeContext().getIndexOfThisSubtask(), s); + RecoveredStates.set(getRuntimeContext().getIndexOfThisSubtask(), 1); sync.countDown(); - current = s; + current = state._1; + repeat = state._2; } @Override @@ -571,6 +524,7 @@ public static class CountingSink implements SinkFunction, ListCheckpointed @Override public void invoke(Long value) throws Exception { + if (value == LastElement) { numberOfReceivedLastElements++; @@ -611,4 +565,41 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception { CompletedCheckpointsLatch.countDown(); } } + + public static class StatefulFlatMap extends RichFlatMapFunction implements CheckpointedFunction { + + private static final long serialVersionUID = 9031079547885320663L; + + private transient ValueState alreadySeen; + + @Override + public void flatMap(Long input, Collector out) throws Exception { + + Integer seen = this.alreadySeen.value(); + if (seen >= Parallelism || input == -1) { + out.collect(input); + } + this.alreadySeen.update(seen + 1); + } + + @Override + public void open(Configuration config) { + + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + ValueStateDescriptor descriptor = + new ValueStateDescriptor<>( + "seenCountState", + TypeInformation.of(new TypeHint() {}), + 0); + alreadySeen = context.getKeyedStateStore().getState(descriptor); + } + } }