[streaming] Allow using both partitioned and non-partitioned state in an operator + refactor
gyfora committed Jun 25, 2015
1 parent 0a4144e commit 474ff4d
Showing 14 changed files with 128 additions and 157 deletions.
docs/apis/streaming_guide.md (9 changes: 3 additions & 6 deletions)

@@ -1191,17 +1191,14 @@ Stateful computation
Flink supports the checkpointing and persistence of user defined operator state, so in case of a failure this state can be restored to the latest checkpoint and the processing will continue from there. This gives exactly-once processing semantics with respect to the operator states when the sources follow this stateful pattern as well. In practice this usually means that sources keep track of their current offset as their OperatorState. The `PersistentKafkaSource` provides this stateful functionality for reading streams from Kafka. A minimal sketch of such an offset-tracking source is shown below.
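As an illustration of this pattern, a source that records its position in non-partitioned `OperatorState` could look like the following. This is only a sketch against the API touched by this commit; the class name and emitted values are made up, and the synchronization between emission and checkpointing is omitted.

```java
public static class OffsetTrackingSource extends RichParallelSourceFunction<Long> {

	private OperatorState<Long> offset;
	private volatile boolean isRunning = true;

	@Override
	public void open(Configuration conf) {
		// Non-partitioned state: each parallel source instance keeps its own
		// offset, which is checkpointed and restored on failure
		offset = getRuntimeContext().getOperatorState("offset", 0L, false);
	}

	@Override
	public void run(SourceContext<Long> ctx) throws Exception {
		while (isRunning) {
			// Emit the current position, then advance the checkpointed offset
			ctx.collect(offset.getState());
			offset.updateState(offset.getState() + 1);
		}
	}

	@Override
	public void cancel() {
		isRunning = false;
	}
}
```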

Flink supports two ways of accessing operator states: partitioned and non-partitioned state access.
In case of non-partitioned state access, an operator state is maintained for each parallel instance of a given operator. When `OperatorState.getState()` is called, a separate state is returned in each parallel instance. In practice this means that if we keep a counter for the received inputs in a mapper, `getState()` will return the number of inputs processed by each parallel mapper.

In case of partitioned state access, the user needs to define a `KeyExtractor` which will assign a key to each input of the stateful operator:

`stream.map(counter).setStatePartitioner(…)`

In case of partitioned `OperatorState`, a separate state is maintained for each received key. This can be used, for instance, to count received inputs by different keys, or to store and update summary statistics of different sub-streams. A sketch of a per-key counter follows.
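The following is a sketch only; the class name is illustrative, and the partitioner is attached with `setStatePartitioner(…)` as shown above:

```java
public static class InputCounter extends RichMapFunction<String, Long> {

	private OperatorState<Long> counter;

	@Override
	public void open(Configuration conf) {
		// Partitioned state: a separate count is maintained for every key
		// assigned to the incoming records by the state partitioner
		counter = getRuntimeContext().getOperatorState("counter", 0L, true);
	}

	@Override
	public Long map(String input) throws Exception {
		// Increment and return the count of the current input's key
		counter.updateState(counter.getState() + 1);
		return counter.getState();
	}
}
```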

Checkpointing of the states needs to be enabled from the `StreamExecutionEnvironment` using the `enableCheckpointing(…)` method, where additional parameters can be passed to modify the default 5 second checkpoint interval:
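For instance, assuming the interval argument is given in milliseconds:

```java
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// Draw a state snapshot every 10 seconds instead of the default 5
env.enableCheckpointing(10000);
```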

Operator states can be accessed from the `RuntimeContext` using the `getOperatorState("name", defaultValue, partitioned)` method, so they are only accessible in `RichFunction`s. A recommended usage pattern is to retrieve the operator state in the `open(…)` method of the operator and set it as a field in the operator instance for runtime usage. Multiple `OperatorState`s can be used simultaneously by the same operator by using different names to identify them.

By default, operator states are checkpointed using default Java serialization, so they need to be `Serializable`. The user can gain more control over the state checkpoint mechanism by passing a `StateCheckpointer` instance when retrieving the `OperatorState` from the `RuntimeContext`. The `StateCheckpointer` allows custom implementations of the checkpointing logic for increased efficiency and to store arbitrary non-serializable states, for example:
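As a sketch, assuming the `snapshotState(…)`/`restoreState(…)` pair on the `StateCheckpointer` interface of this era of the API, a checkpointer that snapshots a non-serializable `StringBuilder` state as a plain `String` could look like this (the class name is illustrative):

```java
public static class BuilderCheckpointer implements StateCheckpointer<StringBuilder, String> {

	@Override
	public String snapshotState(StringBuilder state, long checkpointId, long checkpointTimestamp) {
		// Store a serializable snapshot instead of the live state object
		return state.toString();
	}

	@Override
	public StringBuilder restoreState(String stateSnapshot) {
		// Rebuild the runtime state from the snapshot on restore
		return new StringBuilder(stateSnapshot);
	}
}
```

It would then be registered when retrieving the state, e.g. `getRuntimeContext().getOperatorState("log", new StringBuilder(), false, new BuilderCheckpointer())`.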

@@ -186,12 +186,16 @@ public interface RuntimeContext {
* the first time {@link OperatorState#getState()} (for every
* state partition) is called before
* {@link OperatorState#updateState(Object)}.
* @param partitioned
* Sets whether partitioning should be applied for the given
* state. If true, a partitioner key must be used.
* @param checkpointer
* The {@link StateCheckpointer} that will be used to draw
* snapshots from the user state.
* @return The {@link OperatorState} for the underlying operator.
*/
<S, C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
boolean partitioned, StateCheckpointer<S, C> checkpointer);

/**
* Returns the {@link OperatorState} with the given name of the underlying
@@ -205,14 +209,18 @@ public interface RuntimeContext {
* </p>
*
* @param name
* Identifier for the state, allowing multiple operator states
* to be used by the same operator.
* @param defaultState
* Default value for the operator state. This will be returned
* the first time {@link OperatorState#getState()} (for every
* state partition) is called before
* {@link OperatorState#updateState(Object)}.
* @param partitioned
* Sets whether partitioning should be applied for the given
* state. If true, a partitioner key must be used.
* @return The {@link OperatorState} for the underlying operator.
*/
<S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
boolean partitioned);
}
@@ -174,12 +174,14 @@ private <V, A extends Serializable> Accumulator<V, A> getAccumulator(String name
}

@Override
public <S, C extends Serializable> OperatorState<S> getOperatorState(String name,
S defaultState, boolean partitioned, StateCheckpointer<S, C> checkpointer) {
throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
}

@Override
public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
boolean partitioned) {
throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
}
}
@@ -145,7 +145,7 @@ public void open(Configuration parameters) throws Exception {
// most likely the number of offsets we're going to store here will be lower than the number of partitions.
int numPartitions = getNumberOfPartitions();
LOG.debug("The topic {} has {} partitions", topicName, numPartitions);
this.lastOffsets = getRuntimeContext().getOperatorState("offset", new long[numPartitions], false);
this.commitedOffsets = new long[numPartitions];
// check if there are offsets to restore
if (!Arrays.equals(lastOffsets.getState(), new long[numPartitions])) {
@@ -74,7 +74,7 @@ public void run(SourceContext<Long> ctx) throws Exception {

@Override
public void open(Configuration conf){
collected = getRuntimeContext().getOperatorState("collected", 0L, false);
}

@Override
@@ -32,6 +32,8 @@
import org.apache.flink.streaming.api.state.StreamOperatorState;
import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;

import com.google.common.collect.ImmutableMap;

/**
* This is used as the base class for operators that have a user-defined
* function.
@@ -72,25 +74,25 @@ public void close() throws Exception {

@SuppressWarnings({ "unchecked", "rawtypes" })
public void restoreInitialState(Map<String, PartitionedStateHandle> snapshots) throws Exception {

// We iterate over the states registered for this operator, initialize and restore them
for (Entry<String, PartitionedStateHandle> snapshot : snapshots.entrySet()) {
Map<Serializable, StateHandle<Serializable>> handles = snapshot.getValue().getState();
StreamOperatorState restoredState = runtimeContext.getState(snapshot.getKey(),
!(handles instanceof ImmutableMap));
restoredState.restoreState(snapshot.getValue().getState());
}

}

@SuppressWarnings({ "rawtypes", "unchecked" })
public Map<String, PartitionedStateHandle> getStateSnapshotFromFunction(long checkpointId, long timestamp)
throws Exception {

// Get all the states for the operator
Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
if (operatorStates.isEmpty()) {
// We return null to signal that there is nothing to checkpoint
return null;
} else {
// Checkpoint the states and store the handles in a map
Map<String, PartitionedStateHandle> snapshots = new HashMap<String, PartitionedStateHandle>();

for (Entry<String, StreamOperatorState> state : operatorStates.entrySet()) {
@@ -41,7 +41,7 @@
*/
public class StreamOperatorState<S, C extends Serializable> implements OperatorState<S> {

public static final Serializable DEFAULTKEY = -1;

private S state;
private StateCheckpointer<S, C> checkpointer;
@@ -35,6 +35,7 @@
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.runtime.io.RecordWriterFactory;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -51,7 +52,7 @@ public class OutputHandler<OUT> {
private ClassLoader cl;
private Output<OUT> outerOutput;

public List<StreamOperator<?>> chainedOperators;

private Map<StreamEdge, StreamOutput<?>> outputMap;

@@ -63,7 +64,7 @@ public OutputHandler(StreamTask<OUT, ?> vertex) {
// Initialize some fields
this.vertex = vertex;
this.configuration = new StreamConfig(vertex.getTaskConfiguration());
this.chainedOperators = new ArrayList<StreamOperator<?>>();
this.outputMap = new HashMap<StreamEdge, StreamOutput<?>>();
this.cl = vertex.getUserCodeClassLoader();

@@ -88,6 +89,9 @@ public OutputHandler(StreamTask<OUT, ?> vertex) {
// We create the outer output that will be passed to the first task
// in the chain
this.outerOutput = createChainedCollector(configuration);

// Add the head operator to the end of the list
this.chainedOperators.add(vertex.streamOperator);
}

public void broadcastBarrier(long id, long timestamp) throws IOException, InterruptedException {
@@ -101,7 +105,7 @@ public Collection<StreamOutput<?>> getOutputs() {
return outputMap.values();
}

public List<StreamOperator<?>> getChainedOperators(){
return chainedOperators;
}

@@ -43,7 +43,6 @@
import org.apache.flink.runtime.util.SerializedValue;
import org.apache.flink.runtime.util.event.EventListener;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.StatefulStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.state.WrapperStateHandle;
@@ -74,8 +73,6 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs

protected ClassLoader userClassLoader;

private EventListener<TaskEvent> superstepListener;

public StreamTask() {
Expand All @@ -88,11 +85,10 @@ public StreamTask() {
public void registerInputOutput() {
this.userClassLoader = getUserCodeClassLoader();
this.configuration = new StreamConfig(getTaskConfiguration());

streamOperator = configuration.getStreamOperator(userClassLoader);

outputHandler = new OutputHandler<OUT>(this);

if (streamOperator != null) {
// IterationHead and IterationTail don't have an Operator...
Expand All @@ -103,7 +99,7 @@ public void registerInputOutput() {
streamOperator.setup(outputHandler.getOutput(), headContext);
}

hasChainedOperators = !(outputHandler.getChainedOperators().size() == 1);
}

public String getName() {
@@ -167,20 +163,16 @@ private enum StateBackend {
}

protected void openOperator() throws Exception {
for (StreamOperator<?> operator : outputHandler.getChainedOperators()) {
operator.open(getTaskConfiguration());
}
}

protected void closeOperator() throws Exception {
// We need to close them first to last, since upstream operators in the chain might emit
// elements in their close methods.
for (int i = outputHandler.getChainedOperators().size()-1; i >= 0; i--) {
outputHandler.getChainedOperators().get(i).close();
}
}

@@ -206,28 +198,19 @@ public EventListener<TaskEvent> getSuperstepListener() {
@SuppressWarnings("unchecked")
@Override
public void setInitialState(StateHandle<Serializable> stateHandle) throws Exception {
// We retrieve and restore the states for the chained operators.
List<Serializable> chainedStates = (List<Serializable>) stateHandle.getState();

// We restore all stateful chained operators
for (int i = 0; i < chainedStates.size(); i++) {
Serializable state = chainedStates.get(i);
// If state is not null we need to restore it
if (state != null) {
StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i);
if (chainedOperator instanceof StatefulStreamOperator) {
((StatefulStreamOperator<?>) chainedOperator)
.restoreInitialState((Map<String, PartitionedStateHandle>) state);
}
}
}
}
@@ -242,22 +225,21 @@ public void triggerCheckpoint(long checkpointId, long timestamp) throws Exceptio
try {
LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());

// We wrap the states of the chained operators in a list, marking non-stateful operators with null
List<Map<String, PartitionedStateHandle>> chainedStates = new ArrayList<Map<String, PartitionedStateHandle>>();

// A wrapper handle is created for the List of statehandles
WrapperStateHandle stateHandle;
try {

// We construct a list of states for chained tasks
for (StreamOperator<?> chainedOperator : outputHandler.getChainedOperators()) {
if (chainedOperator instanceof StatefulStreamOperator) {
chainedStates.add(((StatefulStreamOperator<?>) chainedOperator)
.getStateSnapshotFromFunction(checkpointId, timestamp));
} else {
chainedStates.add(null);
}
}

@@ -296,23 +278,10 @@ public void confirmCheckpoint(long checkpointId, SerializedValue<StateHandle<?>>

List<Map<String, PartitionedStateHandle>> chainedStates = (List<Map<String, PartitionedStateHandle>>) stateHandle.deserializeValue(getUserCodeClassLoader()).getState();

for (int i = 0; i < chainedStates.size(); i++) {
Map<String, PartitionedStateHandle> chainedState = chainedStates.get(i);
if (chainedState != null) {
StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i);
if (chainedOperator instanceof StatefulStreamOperator) {
for (Entry<String, PartitionedStateHandle> stateEntry : chainedState
.entrySet()) {