Skip to content

Commit

Permalink
[FLINK-35858][Runtime/State] Create state with namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Zakelly committed Jul 25, 2024
1 parent b07b0b4 commit 055e11e
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.state.v2.StateDescriptor;
Expand All @@ -46,13 +47,17 @@ public interface AsyncKeyedStateBackend extends Disposable, Closeable {
/**
* Creates and returns a new state.
*
* @param stateDesc The {@code StateDescriptor} that contains the name of the state.
* @param <SV> The type of the stored state value.
* @param <N> the type of namespace for partitioning.
* @param <S> The type of the public API state.
* @param <SV> The type of the stored state value.
* @param namespaceSerializer the serializer for namespace.
* @param stateDesc The {@code StateDescriptor} that contains the name of the state.
* @throws Exception Exceptions may occur during initialization of the state.
*/
@Nonnull
<SV, S extends State> S createState(@Nonnull StateDescriptor<SV> stateDesc) throws Exception;
<N, S extends State, SV> S createState(
TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<SV> stateDesc)
throws Exception;

/**
* Creates a {@code StateExecutor} which supports to execute a batch of state requests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.api.common.state.v2.ReducingState;
import org.apache.flink.api.common.state.v2.ValueState;
import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nonnull;
Expand All @@ -41,7 +42,8 @@ public DefaultKeyedStateStoreV2(@Nonnull AsyncKeyedStateBackend asyncKeyedStateB
public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(stateProperties);
return asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand All @@ -51,7 +53,8 @@ public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> statePro
public <T> ListState<T> getListState(@Nonnull ListStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(stateProperties);
return asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand All @@ -62,7 +65,8 @@ public <UK, UV> MapState<UK, UV> getMapState(
@Nonnull MapStateDescriptor<UK, UV> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(stateProperties);
return asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand All @@ -73,7 +77,8 @@ public <T> ReducingState<T> getReducingState(
@Nonnull ReducingStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(stateProperties);
return asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand All @@ -84,7 +89,8 @@ public <IN, ACC, OUT> AggregatingState<IN, OUT> getAggregatingState(
@Nonnull AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(stateProperties);
return asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateBackendTestUtils;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.v2.InternalValueState;
import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
import org.apache.flink.util.FlinkRuntimeException;
Expand Down Expand Up @@ -108,7 +109,9 @@ void setup(
asyncKeyedStateBackend.setup(aec);

try {
valueState = asyncKeyedStateBackend.createState(stateDescriptor);
valueState =
asyncKeyedStateBackend.createState(
VoidNamespaceSerializer.INSTANCE, stateDescriptor);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ public void setup(@Nonnull StateRequestHandler stateRequestHandler) {
@Nonnull
@Override
@SuppressWarnings("unchecked")
public <SV, S extends org.apache.flink.api.common.state.v2.State> S createState(
public <N, S extends org.apache.flink.api.common.state.v2.State, SV> S createState(
TypeSerializer<N> namespaceSerializer,
@Nonnull org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc) {
return (S) innerStateSupplier.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.runtime.state.v2;

import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequest;
Expand Down Expand Up @@ -126,7 +127,9 @@ public void setup(@Nonnull StateRequestHandler stateRequestHandler) {}

@Nonnull
@Override
public <SV, S extends State> S createState(@Nonnull StateDescriptor<SV> stateDesc)
public <N, S extends State, SV> S createState(
TypeSerializer<N> namespaceSerializer,
@Nonnull StateDescriptor<SV> stateDesc)
throws Exception {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ public void setup(@Nonnull StateRequestHandler stateRequestHandler) {
@Nonnull
@Override
@SuppressWarnings("unchecked")
public <SV, S extends State> S createState(@Nonnull StateDescriptor<SV> stateDesc) {
public <N, S extends State, SV> S createState(
TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<SV> stateDesc) {
Preconditions.checkNotNull(
stateRequestHandler,
"A non-null stateRequestHandler must be setup before createState");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,22 @@ public <N, S extends State, T> S getOrCreateKeyedState(
}
}

/** Create new state (v2) based on new state descriptor. */
public <N, S extends org.apache.flink.api.common.state.v2.State, T> S getOrCreateKeyedState(
TypeSerializer<N> namespaceSerializer,
org.apache.flink.runtime.state.v2.StateDescriptor<T> stateDescriptor)
throws Exception {

if (asyncKeyedStateBackend != null) {
return asyncKeyedStateBackend.createState(namespaceSerializer, stateDescriptor);
} else {
throw new IllegalStateException(
"Cannot create partitioned state. "
+ "The keyed state backend has not been set."
+ "This indicates that the operator is not partitioned/keyed.");
}
}

/**
* Creates a partitioned state handle, using the state backend configured for this task.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
Expand All @@ -31,6 +32,7 @@
import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.v2.StateDescriptor;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.Input;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
Expand Down Expand Up @@ -175,6 +177,13 @@ public final <T> ThrowingConsumer<StreamRecord<T>, Exception> getRecordProcessor
getClass().getName(), inputId));
}

/** Create new state (v2) based on new state descriptor. */
protected <N, S extends State, T> S getOrCreateKeyedState(
TypeSerializer<N> namespaceSerializer, StateDescriptor<T> stateDescriptor)
throws Exception {
return stateHandler.getOrCreateKeyedState(namespaceSerializer, stateDescriptor);
}

@Override
public final OperatorSnapshotFutures snapshotState(
long checkpointId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
Expand All @@ -30,6 +31,7 @@
import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.v2.StateDescriptor;
import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.InternalTimerService;
Expand Down Expand Up @@ -150,6 +152,13 @@ public final <T> ThrowingConsumer<StreamRecord<T>, Exception> getRecordProcessor
+ " since this part is handled by the Input.");
}

/** Create new state (v2) based on new state descriptor. */
protected <N, S extends State, T> S getOrCreateKeyedState(
TypeSerializer<N> namespaceSerializer, StateDescriptor<T> stateDescriptor)
throws Exception {
return stateHandler.getOrCreateKeyedState(namespaceSerializer, stateDescriptor);
}

@Override
public final OperatorSnapshotFutures snapshotState(
long checkpointId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,13 @@ public <T> TypeSerializer<T> createSerializer(
doAnswer(
(Answer<Object>)
invocationOnMock -> {
ref.set(invocationOnMock.getArguments()[0]);
ref.set(invocationOnMock.getArguments()[1]);
return null;
})
.when(asyncKeyedStateBackend)
.createState(any(org.apache.flink.runtime.state.v2.StateDescriptor.class));
.createState(
any(TypeSerializer.class),
any(org.apache.flink.runtime.state.v2.StateDescriptor.class));

operator.initializeState(streamTaskStateManager);
operator.getRuntimeContext().setKeyedStateStore(keyedStateStore);
Expand Down

0 comments on commit 055e11e

Please sign in to comment.