Skip to content

Commit

Permalink
[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesi…
Browse files Browse the repository at this point in the history
…sConsumer

This commit adds some general improvements to the rescalable
implementation of FlinkKinesisConsumer, including:
- Refactor setup procedures in KinesisDataFetcher so that duplicate work
  isn't done on a restored run
- Strengthen corner cases where fetcher was not fully seeded with
  initial state when snapshot is taken

This closes apache#3001.
  • Loading branch information
tzulitai committed May 7, 2017
1 parent a05b574 commit e5b65a7
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
Expand Down Expand Up @@ -67,9 +68,9 @@
* @param <T> the type of data emitted
*/
public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
ResultTypeQueryable<T>,
CheckpointedFunction,
CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
ResultTypeQueryable<T>,
CheckpointedFunction,
CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {

private static final long serialVersionUID = 4724006128720664870L;

Expand All @@ -86,7 +87,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
* shard list retrieval behaviours, etc */
private final Properties configProps;

/** User supplied deseriliazation schema to convert Kinesis byte messages to Flink objects */
/** User supplied deserialization schema to convert Kinesis byte messages to Flink objects */
private final KinesisDeserializationSchema<T> deserializer;

// ------------------------------------------------------------------------
Expand All @@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
/** Per-task fetcher for Kinesis data records, where each fetcher pulls data from one or more Kinesis shards */
private transient KinesisDataFetcher<T> fetcher;

/** The sequence numbers in the last state snapshot of this subtask */
private transient HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot;

/** The sequence numbers to restore to upon restore from failure */
private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore;

Expand All @@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
// State for Checkpoint
// ------------------------------------------------------------------------

/** The name is the key for sequence numbers state, and cannot be changed. */
/** State name to access shard sequence number states; cannot be changed */
private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";

private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
Expand Down Expand Up @@ -190,68 +188,55 @@ public FlinkKinesisConsumer(List<String> streams, KinesisDeserializationSchema<T
// Source life cycle
// ------------------------------------------------------------------------

@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);

// restore to the last known sequence numbers from the latest complete snapshot
if (sequenceNumsToRestore != null) {
if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} is restoring sequence numbers {} from previous checkpointed state",
getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString());
}

// initialize sequence numbers with restored state
lastStateSnapshot = sequenceNumsToRestore;
} else {
// start fresh with empty sequence numbers if there are no snapshots to restore from.
lastStateSnapshot = new HashMap<>();
}
}

@Override
public void run(SourceContext<T> sourceContext) throws Exception {

// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
// can potentially have new shards to subscribe to later on
fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);

boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
fetcher.setIsRestoringFromFailure(isRestoringFromFailure);

// if we are restoring from a checkpoint, we iterate over the restored
// state and accordingly seed the fetcher with subscribed shards states
if (isRestoringFromFailure) {
// Since there may have a situation that some subtasks did not finish discovering before rescale,
// and KinesisDataFetcher will always discover the shard from the largest shard id. To prevent from
// missing some shards which didn't be discovered and whose id is not the largest one, we force the
// consumer to discover once from the smallest id and make sure each shard have its initial sequence
// number from restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
SequenceNumber startingStateForNewShard;

if (lastStateSnapshot.containsKey(shard)) {
startingStateForNewShard = lastStateSnapshot.get(shard);
KinesisDataFetcher<T> fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);

// initial discovery
List<KinesisStreamShard> allShards = fetcher.discoverNewShardsToSubscribe();

for (KinesisStreamShard shard : allShards) {
if (sequenceNumsToRestore != null) {
if (sequenceNumsToRestore.containsKey(shard)) {
// if the shard was already seen and is contained in the state,
// just use the sequence number stored in the state
fetcher.registerNewSubscribedShardState(
new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard)));

if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
" starting state set to the restored sequence number {}",
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard));
}
} else {
startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
// the shard wasn't discovered in the previous run, therefore should be consumed from the beginning
fetcher.registerNewSubscribedShardState(
new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()));

if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
" starting state set to the SENTINEL_EARLIEST_SEQUENCE_NUM",
getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
}
}
} else {
// we're starting fresh; use the configured start position as initial state
SentinelSequenceNumber startingSeqNum =
InitialPosition.valueOf(configProps.getProperty(
ConsumerConfigConstants.STREAM_INITIAL_POSITION,
ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber();

fetcher.registerNewSubscribedShardState(
new KinesisStreamShardState(shard, startingStateForNewShard));
new KinesisStreamShardState(shard, startingSeqNum.get()));

if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}",
getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingSeqNum.get());
}
}
}

Expand All @@ -260,6 +245,10 @@ public void run(SourceContext<T> sourceContext) throws Exception {
return;
}

// expose the fetcher from this point, so that state
// snapshots can be taken from the fetcher's state holders
this.fetcher = fetcher;

// start the fetcher loop. The fetcher will stop running only when cancel() or
// close() is called, or an error is thrown by threads created by the fetcher
fetcher.runFetcher();
Expand Down Expand Up @@ -306,13 +295,12 @@ public TypeInformation<T> getProducedType() {

@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> shardsStateTypeInfo = new TupleTypeInfo<>(
TypeInformation.of(KinesisStreamShard.class),
TypeInformation.of(SequenceNumber.class)
);
TypeInformation.of(SequenceNumber.class));

sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
new ListStateDescriptor<>(sequenceNumsStateStoreName, shardsStateTypeInfo));

if (context.isRestored()) {
if (sequenceNumsToRestore == null) {
Expand All @@ -323,8 +311,6 @@ public void initializeState(FunctionInitializationContext context) throws Except

LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
sequenceNumsToRestore);
} else if (sequenceNumsToRestore.isEmpty()) {
sequenceNumsToRestore = null;
}
} else {
LOG.info("No restore state for FlinkKinesisConsumer.");
Expand All @@ -333,27 +319,41 @@ public void initializeState(FunctionInitializationContext context) throws Except

@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
if (lastStateSnapshot == null) {
LOG.debug("snapshotState() requested on not yet opened source; returning null.");
} else if (fetcher == null) {
LOG.debug("snapshotState() requested on not yet running source; returning null.");
} else if (!running) {
if (!running) {
LOG.debug("snapshotState() called on closed source; returning null.");
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("Snapshotting state ...");
}

sequenceNumsStateForCheckpoint.clear();
lastStateSnapshot = fetcher.snapshotState();

if (LOG.isDebugEnabled()) {
LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
}
if (fetcher == null) {
if (sequenceNumsToRestore != null) {
for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
// sequenceNumsToRestore is the restored global union state;
// should only snapshot shards that actually belong to us

if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
entry.getKey(),
getRuntimeContext().getNumberOfParallelSubtasks(),
getRuntimeContext().getIndexOfThisSubtask())) {

sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
}
}
}
} else {
HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot = fetcher.snapshotState();

if (LOG.isDebugEnabled()) {
LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
}

for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
}
}
}
}
Expand All @@ -366,12 +366,14 @@ public void restoreState(HashMap<KinesisStreamShard, SequenceNumber> restoredSta
sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
}

/** This method is created for tests that can mock the KinesisDataFetcher in the consumer. */
protected KinesisDataFetcher<T> createFetcher(List<String> streams,
SourceFunction.SourceContext<T> sourceContext,
RuntimeContext runtimeContext,
Properties configProps,
KinesisDeserializationSchema<T> deserializationSchema) {
/** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */
protected KinesisDataFetcher<T> createFetcher(
List<String> streams,
SourceFunction.SourceContext<T> sourceContext,
RuntimeContext runtimeContext,
Properties configProps,
KinesisDeserializationSchema<T> deserializationSchema) {

return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
Expand Down Expand Up @@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> {

private final int indexOfThisConsumerSubtask;

/**
* This flag should be set by {@link FlinkKinesisConsumer} using
* {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)}
*/
private boolean isRestoredFromFailure;

// ------------------------------------------------------------------------
// Executor services to run created threads
// ------------------------------------------------------------------------
Expand Down Expand Up @@ -235,41 +227,7 @@ public void runFetcher() throws Exception {
// Procedures before starting the infinite while loop:
// ------------------------------------------------------------------------

// 1. query for any new shards that may have been created while the Kinesis consumer was not running,
// and register them to the subscribedShardState list.
if (LOG.isDebugEnabled()) {
String logFormat = (!isRestoredFromFailure)
? "Subtask {} is trying to discover initial shards ..."
: "Subtask {} is trying to discover any new shards that were created while the consumer wasn't " +
"running due to failure ...";

LOG.debug(logFormat, indexOfThisConsumerSubtask);
}
List<KinesisStreamShard> newShardsCreatedWhileNotRunning = discoverNewShardsToSubscribe();
for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
// the starting state for new shards created while the consumer wasn't running depends on whether or not
// we are starting fresh (not restoring from a checkpoint); when we are starting fresh, this simply means
// all existing shards of streams we are subscribing to are new shards; when we are restoring from checkpoint,
// any new shards due to Kinesis resharding from the time of the checkpoint will be considered new shards.
InitialPosition initialPosition = InitialPosition.valueOf(configProps.getProperty(
ConsumerConfigConstants.STREAM_INITIAL_POSITION, ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION));

SentinelSequenceNumber startingStateForNewShard = (isRestoredFromFailure)
? SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
: initialPosition.toSentinelSequenceNumber();

if (LOG.isInfoEnabled()) {
String logFormat = (!isRestoredFromFailure)
? "Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}"
: "Subtask {} will be seeded with new shard {} that was created while the consumer wasn't " +
"running due to failure, starting state set as sequence number {}";

LOG.info(logFormat, indexOfThisConsumerSubtask, shard.toString(), startingStateForNewShard.get());
}
registerNewSubscribedShardState(new KinesisStreamShardState(shard, startingStateForNewShard.get()));
}

// 2. check that there is at least one shard in the subscribed streams to consume from (can be done by
// 1. check that there is at least one shard in the subscribed streams to consume from (can be done by
// checking if at least one value in subscribedStreamsToLastDiscoveredShardIds is not null)
boolean hasShards = false;
StringBuilder streamsWithNoShardsFound = new StringBuilder();
Expand All @@ -290,7 +248,7 @@ public void runFetcher() throws Exception {
throw new RuntimeException("No shards can be found for all subscribed streams: " + streams);
}

// 3. start consuming any shard state we already have in the subscribedShardState up to this point; the
// 2. start consuming any shard state we already have in the subscribedShardState up to this point; the
// subscribedShardState may already be seeded with values due to step 1., or explicitly added by the
// consumer using a restored state checkpoint
for (int seededStateIndex = 0; seededStateIndex < subscribedShardsState.size(); seededStateIndex++) {
Expand Down Expand Up @@ -489,10 +447,6 @@ public List<KinesisStreamShard> discoverNewShardsToSubscribe() throws Interrupte
// Functions to get / set information about the consumer
// ------------------------------------------------------------------------

public void setIsRestoringFromFailure(boolean bool) {
this.isRestoredFromFailure = bool;
}

protected Properties getConsumerConfiguration() {
return configProps;
}
Expand Down Expand Up @@ -595,7 +549,7 @@ public int registerNewSubscribedShardState(KinesisStreamShardState newSubscribed
* @param totalNumberOfConsumerSubtasks total number of consumer subtasks
* @param indexOfThisConsumerSubtask index of this consumer subtask
*/
private static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
int totalNumberOfConsumerSubtasks,
int indexOfThisConsumerSubtask) {
return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@

/**
* Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
* done using the Flink 1.1 {@link FlinkKinesisConsumer}.
*
* <p>For regenerating the binary snapshot file you have to run the commented out portion
* of each test on a checkout of the Flink 1.1 branch.
* done using the Flink 1.1 {@code FlinkKinesisConsumer}.
*/
public class FlinkKinesisConsumerMigrationTest {

Expand Down
Loading

0 comments on commit e5b65a7

Please sign in to comment.