Skip to content

Commit

Permalink
[FLINK-9926][Kinesis Connector] Allow for ShardConsumer override in K…
Browse files Browse the repository at this point in the history
…inesis consumer.

This closes apache#6427.
  • Loading branch information
tweise authored and tzulitai committed Aug 1, 2018
1 parent 58ca87a commit dd4b3d1
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ public class KinesisDataFetcher<T> {
/** Reference to the first error thrown by any of the {@link ShardConsumer} threads. */
private final AtomicReference<Throwable> error;

/** The Kinesis proxy factory that will be used to create instances for discovery and shard consumers. */
private final FlinkKinesisProxyFactory kinesisProxyFactory;

/** The Kinesis proxy that the fetcher will be using to discover new shards. */
private final KinesisProxyInterface kinesis;

Expand All @@ -179,6 +182,13 @@ public class KinesisDataFetcher<T> {

private volatile boolean running = true;

/**
* Factory to create Kinesis proxy instances used by a fetcher.
*/
public interface FlinkKinesisProxyFactory {
KinesisProxyInterface create(Properties configProps);
}

/**
* Creates a Kinesis Data Fetcher.
*
Expand All @@ -204,7 +214,7 @@ public KinesisDataFetcher(List<String> streams,
new AtomicReference<>(),
new ArrayList<>(),
createInitialSubscribedStreamsToLastDiscoveredShardsState(streams),
KinesisProxy.create(configProps));
KinesisProxy::create);
}

@VisibleForTesting
Expand All @@ -218,7 +228,7 @@ protected KinesisDataFetcher(List<String> streams,
AtomicReference<Throwable> error,
List<KinesisStreamShardState> subscribedShardsState,
HashMap<String, String> subscribedStreamsToLastDiscoveredShardIds,
KinesisProxyInterface kinesis) {
FlinkKinesisProxyFactory kinesisProxyFactory) {
this.streams = checkNotNull(streams);
this.configProps = checkNotNull(configProps);
this.sourceContext = checkNotNull(sourceContext);
Expand All @@ -228,7 +238,8 @@ protected KinesisDataFetcher(List<String> streams,
this.indexOfThisConsumerSubtask = runtimeContext.getIndexOfThisSubtask();
this.deserializationSchema = checkNotNull(deserializationSchema);
this.shardAssigner = checkNotNull(shardAssigner);
this.kinesis = checkNotNull(kinesis);
this.kinesisProxyFactory = checkNotNull(kinesisProxyFactory);
this.kinesis = kinesisProxyFactory.create(configProps);

this.consumerMetricGroup = runtimeContext.getMetricGroup()
.addGroup(KinesisConsumerMetricConstants.KINESIS_CONSUMER_METRICS_GROUP);
Expand All @@ -241,6 +252,29 @@ protected KinesisDataFetcher(List<String> streams,
createShardConsumersThreadPool(runtimeContext.getTaskNameWithSubtasks());
}

/**
* Create a new shard consumer.
* Override this method to customize shard consumer behavior in subclasses.
* @param subscribedShardStateIndex the state index of the shard this consumer is subscribed to
* @param subscribedShard the shard this consumer is subscribed to
* @param lastSequenceNum the sequence number in the shard to start consuming
* @param shardMetricsReporter the reporter to report metrics to
* @return shard consumer
*/
protected ShardConsumer createShardConsumer(
Integer subscribedShardStateIndex,
StreamShardHandle subscribedShard,
SequenceNumber lastSequenceNum,
ShardMetricsReporter shardMetricsReporter) {
return new ShardConsumer<>(
this,
subscribedShardStateIndex,
subscribedShard,
lastSequenceNum,
this.kinesisProxyFactory.create(configProps),
shardMetricsReporter);
}

/**
* Starts the fetcher. After starting the fetcher, it can only
* be stopped by calling {@link KinesisDataFetcher#shutdownFetcher()}.
Expand Down Expand Up @@ -297,8 +331,7 @@ public void runFetcher() throws Exception {
}

shardConsumersExecutor.submit(
new ShardConsumer<>(
this,
createShardConsumer(
seededStateIndex,
subscribedShardsState.get(seededStateIndex).getStreamShardHandle(),
subscribedShardsState.get(seededStateIndex).getLastProcessedSequenceNum(),
Expand Down Expand Up @@ -344,8 +377,7 @@ public void runFetcher() throws Exception {
}

shardConsumersExecutor.submit(
new ShardConsumer<>(
this,
createShardConsumer(
newStateIndex,
newShardState.getStreamShardHandle(),
newShardState.getLastProcessedSequenceNum(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;

Expand Down Expand Up @@ -87,28 +86,15 @@ public class ShardConsumer<T> implements Runnable {
* @param subscribedShardStateIndex the state index of the shard this consumer is subscribed to
* @param subscribedShard the shard this consumer is subscribed to
* @param lastSequenceNum the sequence number in the shard to start consuming
* @param kinesis the proxy instance to interact with Kinesis
* @param shardMetricsReporter the reporter to report metrics to
*/
public ShardConsumer(KinesisDataFetcher<T> fetcherRef,
Integer subscribedShardStateIndex,
StreamShardHandle subscribedShard,
SequenceNumber lastSequenceNum,
KinesisProxyInterface kinesis,
ShardMetricsReporter shardMetricsReporter) {
this(fetcherRef,
subscribedShardStateIndex,
subscribedShard,
lastSequenceNum,
KinesisProxy.create(fetcherRef.getConsumerConfiguration()),
shardMetricsReporter);
}

/** This constructor is exposed for testing purposes. */
protected ShardConsumer(KinesisDataFetcher<T> fetcherRef,
Integer subscribedShardStateIndex,
StreamShardHandle subscribedShard,
SequenceNumber lastSequenceNum,
KinesisProxyInterface kinesis,
ShardMetricsReporter shardMetricsReporter) {
this.fetcherRef = checkNotNull(fetcherRef);
this.subscribedShardStateIndex = checkNotNull(subscribedShardStateIndex);
this.subscribedShard = checkNotNull(subscribedShard);
Expand Down Expand Up @@ -152,62 +138,73 @@ protected ShardConsumer(KinesisDataFetcher<T> fetcherRef,
}
}

@SuppressWarnings("unchecked")
@Override
public void run() {
/**
* Find the initial shard iterator to start getting records from.
* @return shard iterator
* @throws Exception
*/
protected String getInitialShardIterator() throws Exception {
String nextShardItr;

try {
// before infinitely looping, we set the initial nextShardItr appropriately
// before infinitely looping, we set the initial nextShardItr appropriately

if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.get())) {
// if the shard is already closed, there will be no latest next record to get for this shard
if (subscribedShard.isClosed()) {
nextShardItr = null;
} else {
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.LATEST.toString(), null);
}
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())) {
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.TRIM_HORIZON.toString(), null);
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get())) {
if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.get())) {
// if the shard is already closed, there will be no latest next record to get for this shard
if (subscribedShard.isClosed()) {
nextShardItr = null;
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_AT_TIMESTAMP_SEQUENCE_NUM.get())) {
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.AT_TIMESTAMP.toString(), initTimestamp);
} else {
// we will be starting from an actual sequence number (due to restore from failure).
// if the last sequence number refers to an aggregated record, we need to clean up any dangling sub-records
// from the last aggregated record; otherwise, we can simply start iterating from the record right after.

if (lastSequenceNum.isAggregated()) {
String itrForLastAggregatedRecord =
kinesis.getShardIterator(subscribedShard, ShardIteratorType.AT_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());

// get only the last aggregated record
GetRecordsResult getRecordsResult = getRecords(itrForLastAggregatedRecord, 1);

List<UserRecord> fetchedRecords = deaggregateRecords(
getRecordsResult.getRecords(),
subscribedShard.getShard().getHashKeyRange().getStartingHashKey(),
subscribedShard.getShard().getHashKeyRange().getEndingHashKey());

long lastSubSequenceNum = lastSequenceNum.getSubSequenceNumber();
for (UserRecord record : fetchedRecords) {
// we have found a dangling sub-record if it has a larger subsequence number
// than our last sequence number; if so, collect the record and update state
if (record.getSubSequenceNumber() > lastSubSequenceNum) {
deserializeRecordForCollectionAndUpdateState(record);
}
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.LATEST.toString(), null);
}
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())) {
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.TRIM_HORIZON.toString(), null);
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get())) {
nextShardItr = null;
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_AT_TIMESTAMP_SEQUENCE_NUM.get())) {
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.AT_TIMESTAMP.toString(), initTimestamp);
} else {
// we will be starting from an actual sequence number (due to restore from failure).
// if the last sequence number refers to an aggregated record, we need to clean up any dangling sub-records
// from the last aggregated record; otherwise, we can simply start iterating from the record right after.

if (lastSequenceNum.isAggregated()) {
String itrForLastAggregatedRecord =
kinesis.getShardIterator(subscribedShard, ShardIteratorType.AT_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());

// get only the last aggregated record
GetRecordsResult getRecordsResult = getRecords(itrForLastAggregatedRecord, 1);

List<UserRecord> fetchedRecords = deaggregateRecords(
getRecordsResult.getRecords(),
subscribedShard.getShard().getHashKeyRange().getStartingHashKey(),
subscribedShard.getShard().getHashKeyRange().getEndingHashKey());

long lastSubSequenceNum = lastSequenceNum.getSubSequenceNumber();
for (UserRecord record : fetchedRecords) {
// we have found a dangling sub-record if it has a larger subsequence number
// than our last sequence number; if so, collect the record and update state
if (record.getSubSequenceNumber() > lastSubSequenceNum) {
deserializeRecordForCollectionAndUpdateState(record);
}

// set the nextShardItr so we can continue iterating in the next while loop
nextShardItr = getRecordsResult.getNextShardIterator();
} else {
// the last record was non-aggregated, so we can simply start from the next record
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());
}

// set the nextShardItr so we can continue iterating in the next while loop
nextShardItr = getRecordsResult.getNextShardIterator();
} else {
// the last record was non-aggregated, so we can simply start from the next record
nextShardItr = kinesis.getShardIterator(subscribedShard, ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());
}
}
return nextShardItr;
}

@SuppressWarnings("unchecked")
@Override
public void run() {
try {
String nextShardItr = getInitialShardIterator();

long processingStartTimeNanos = System.nanoTime();

while (isRunning()) {
if (nextShardItr == null) {
fetcherRef.updateState(subscribedShardStateIndex, SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
import com.amazonaws.regions.Regions;
import com.amazonaws.services.kinesis.AmazonKinesis;
import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.deser.BeanDeserializerFactory;
import com.fasterxml.jackson.databind.deser.BeanDeserializerModifier;
import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext;
import com.fasterxml.jackson.databind.deser.DeserializerFactory;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;

import java.io.IOException;
import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import org.apache.flink.annotation.Internal;

/**
* Internal use.
*/
@Internal
public class TimeoutLatch {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public TestableKinesisDataFetcher(
thrownErrorUnderTest,
subscribedShardsStateUnderTest,
subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
fakeKinesis);
(properties) -> fakeKinesis);

this.runWaiter = new OneShotLatch();
this.initialDiscoveryWaiter = new OneShotLatch();
Expand Down

0 comments on commit dd4b3d1

Please sign in to comment.