
Commit

[hotfix][kafka] Extract TransactionalIdsGenerator class from FlinkKafkaProducer011

This is a pure refactor without any functional changes.
pnowojski authored and aljoscha committed Nov 8, 2017
1 parent 6bce2b8 commit ab00d35
Showing 2 changed files with 99 additions and 35 deletions.
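For orientation before reading the diff: after the refactor, FlinkKafkaProducer011 delegates all transactional-id bookkeeping to the new helper class. A condensed sketch of the resulting call sites (excerpted and reordered from the diff below; not compilable on its own):

    // In initializeState(): build the generator from this subtask's runtime context.
    transactionalIdsGenerator = new TransactionalIdsGenerator(
        getRuntimeContext().getTaskName(),
        getRuntimeContext().getIndexOfThisSubtask(),
        kafkaProducersPoolSize,
        SAFE_SCALE_DOWN_FACTOR);

    // Fresh start, or restart after a failure before the first completed checkpoint:
    // abort a conservatively guessed range of previously used transactional ids.
    abortTransactions(transactionalIdsGenerator.generateIdsToAbort());

    // Normal path: derive this subtask's id pool from the checkpointed hint.
    Set<String> transactionalIds =
        transactionalIdsGenerator.generateIdsToUse(nextTransactionalIdHint.nextFreeTransactionalId);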
@@ -37,6 +37,7 @@
import org.apache.flink.streaming.api.functions.sink.TwoPhaseCommitSinkFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.connectors.kafka.internal.FlinkKafkaProducer;
import org.apache.flink.streaming.connectors.kafka.internal.TransactionalIdsGenerator;
import org.apache.flink.streaming.connectors.kafka.internal.metrics.KafkaMetricMuttableWrapper;
import org.apache.flink.streaming.connectors.kafka.partitioner.FlinkFixedPartitioner;
import org.apache.flink.streaming.connectors.kafka.partitioner.FlinkKafkaDelegatePartitioner;
@@ -59,7 +60,6 @@
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.errors.InvalidTxnStateException;
import org.apache.kafka.common.serialization.ByteArraySerializer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -81,8 +81,6 @@
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.LongStream;
import java.util.stream.Stream;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -182,6 +180,11 @@ public enum Semantic {
*/
private transient ListState<NextTransactionalIdHint> nextTransactionalIdHintState;

/**
* Generator for Transactional IDs.
*/
private transient TransactionalIdsGenerator transactionalIdsGenerator;

/**
* Hint for picking next transactional id.
*/
@@ -785,6 +788,11 @@ public void initializeState(FunctionInitializationContext context) throws Exception {

nextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState(
NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR);
transactionalIdsGenerator = new TransactionalIdsGenerator(
getRuntimeContext().getTaskName(),
getRuntimeContext().getIndexOfThisSubtask(),
kafkaProducersPoolSize,
SAFE_SCALE_DOWN_FACTOR);

if (semantic != Semantic.EXACTLY_ONCE) {
nextTransactionalIdHint = null;
@@ -803,15 +811,8 @@ public void initializeState(FunctionInitializationContext context) throws Exception {
// (1) the first execution of this application
// (2) previous execution has failed before first checkpoint completed
//
// in case of (2) we have to abort all previous transactions, but we don't know what parallelism was used
// then, so we must guess using the currently configured pool size, the current parallelism and
// SAFE_SCALE_DOWN_FACTOR
long abortTransactionalIdStart = getRuntimeContext().getIndexOfThisSubtask();
long abortTransactionalIdEnd = abortTransactionalIdStart + 1;

abortTransactionalIdStart *= kafkaProducersPoolSize * SAFE_SCALE_DOWN_FACTOR;
abortTransactionalIdEnd *= kafkaProducersPoolSize * SAFE_SCALE_DOWN_FACTOR;
abortTransactions(LongStream.range(abortTransactionalIdStart, abortTransactionalIdEnd));
// in case of (2) we have to abort all previous transactions
abortTransactions(transactionalIdsGenerator.generateIdsToAbort());
} else {
nextTransactionalIdHint = transactionalIdHints.get(0);
}
@@ -834,16 +835,7 @@ protected Optional<KafkaTransactionContext> initializeUserContext() {
private Set<String> generateNewTransactionalIds() {
checkState(nextTransactionalIdHint != null, "nextTransactionalIdHint must be present for EXACTLY_ONCE");

// the range of available transactional ids is:
// [nextFreeTransactionalId, nextFreeTransactionalId + parallelism * kafkaProducersPoolSize)
// the loop below deterministically picks a subrange of those ids based on the index of
// this subtask
int subtaskId = getRuntimeContext().getIndexOfThisSubtask();
Set<String> transactionalIds = new HashSet<>();
for (int i = 0; i < kafkaProducersPoolSize; i++) {
long transactionalId = nextTransactionalIdHint.nextFreeTransactionalId + subtaskId * kafkaProducersPoolSize + i;
transactionalIds.add(generateTransactionalId(transactionalId));
}
Set<String> transactionalIds = transactionalIdsGenerator.generateIdsToUse(nextTransactionalIdHint.nextFreeTransactionalId);
LOG.info("Generated new transactionalIds {}", transactionalIds);
return transactionalIds;
}
@@ -862,7 +854,7 @@ private void cleanUpUserContext() {
if (!getUserContext().isPresent()) {
return;
}
abortTransactions(getUserContext().get().transactionalIds.stream());
abortTransactions(getUserContext().get().transactionalIds);
}

private void resetAvailableTransactionalIdsPool(Collection<String> transactionalIds) {
@@ -874,22 +866,13 @@ private void resetAvailableTransactionalIdsPool(Collection<String> transactionalIds) {

// ----------------------------------- Utilities --------------------------

private void abortTransactions(LongStream transactionalIds) {
abortTransactions(transactionalIds.mapToObj(this::generateTransactionalId));
}

private void abortTransactions(Stream<String> transactionalIds) {
transactionalIds.forEach(transactionalId -> {
private void abortTransactions(Set<String> transactionalIds) {
for (String transactionalId : transactionalIds) {
try (FlinkKafkaProducer<byte[], byte[]> kafkaProducer =
initTransactionalProducer(transactionalId, false)) {
kafkaProducer.initTransactions();
}
});
}

private String generateTransactionalId(long transactionalId) {
String transactionalIdFormat = getRuntimeContext().getTaskName() + "-%d";
return String.format(transactionalIdFormat, transactionalId);
}
}

int getTransactionCoordinatorId() {
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.connectors.kafka.internal;

import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
* Class responsible for generating transactional ids to use when communicating with Kafka.
*/
public class TransactionalIdsGenerator {
private final String prefix;
private final int subtaskIndex;
private final int poolSize;
private final int safeScaleDownFactor;

public TransactionalIdsGenerator(
String prefix,
int subtaskIndex,
int poolSize,
int safeScaleDownFactor) {
this.prefix = checkNotNull(prefix);
this.subtaskIndex = subtaskIndex;
this.poolSize = poolSize;
this.safeScaleDownFactor = safeScaleDownFactor;
}

/**
* The range of available transactional ids to use is:
* [nextFreeTransactionalId, nextFreeTransactionalId + parallelism * kafkaProducersPoolSize).
* This method deterministically picks a subrange of those ids based on the index of this subtask.
*/
public Set<String> generateIdsToUse(long nextFreeTransactionalId) {
Set<String> transactionalIds = new HashSet<>();
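// each subtask claims a disjoint, contiguous block of poolSize ids, offset from the shared nextFreeTransactionalId hint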
for (int i = 0; i < poolSize; i++) {
long transactionalId = nextFreeTransactionalId + subtaskIndex * poolSize + i;
transactionalIds.add(generateTransactionalId(transactionalId));
}
return transactionalIds;
}

/**
* If we have to abort previous transactional ids after a restart from a failure BEFORE the first checkpoint
* completed, we don't know what parallelism was used in the previous attempt. In that case we must guess the
* range of ids to abort based on the currently configured pool size, the current parallelism and the
* safeScaleDownFactor.
*/
public Set<String> generateIdsToAbort() {
long abortTransactionalIdStart = subtaskIndex;
long abortTransactionalIdEnd = abortTransactionalIdStart + 1;

abortTransactionalIdStart *= poolSize * safeScaleDownFactor;
abortTransactionalIdEnd *= poolSize * safeScaleDownFactor;
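// the resulting range is [subtaskIndex * poolSize * safeScaleDownFactor, (subtaskIndex + 1) * poolSize * safeScaleDownFactor);
// the union of these ranges across all current subtasks covers [0, parallelism * poolSize * safeScaleDownFactor)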
return LongStream.range(abortTransactionalIdStart, abortTransactionalIdEnd)
.mapToObj(this::generateTransactionalId)
.collect(Collectors.toSet());
}

private String generateTransactionalId(long transactionalId) {
return String.format(prefix + "-%d", transactionalId);
}
}
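A hypothetical worked example of the id ranges the two methods produce, assuming prefix "my-sink", subtask index 1, pool size 5 and safe scale-down factor 5 (all values illustrative, not taken from this commit):

    TransactionalIdsGenerator generator =
        new TransactionalIdsGenerator("my-sink", 1, 5, 5);

    // generateIdsToUse(hint) reserves hint + subtaskIndex * poolSize + [0..poolSize):
    // here 0 + 1 * 5 + [0..5) = {5..9}, i.e. "my-sink-5" .. "my-sink-9".
    Set<String> idsToUse = generator.generateIdsToUse(0);

    // generateIdsToAbort() covers [subtaskIndex * poolSize * factor, (subtaskIndex + 1) * poolSize * factor):
    // here [25, 50), i.e. "my-sink-25" .. "my-sink-49". The union of these ranges across all current
    // subtasks is [0, parallelism * poolSize * factor), which covers every id a previous attempt could
    // have used, as long as its parallelism was at most safeScaleDownFactor times the current one
    // (with the same pool size and a hint starting at 0).
    Set<String> idsToAbort = generator.generateIdsToAbort();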

