diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java index a49fdd26ff79c..9ea3198a147f0 100644 --- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.contrib.streaming.state.PredefinedOptions; import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend; +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; import org.apache.flink.queryablestate.client.VoidNamespace; import org.apache.flink.queryablestate.client.VoidNamespaceSerializer; import org.apache.flink.runtime.query.TaskKvStateRegistry; @@ -74,9 +75,12 @@ public void testListSerialization() throws Exception { columnFamilyOptions, mock(TaskKvStateRegistry.class), LongSerializer.INSTANCE, - 1, new KeyGroupRange(0, 0), - new ExecutionConfig(), false, - TestLocalRecoveryConfig.disabled() + 1, + new KeyGroupRange(0, 0), + new ExecutionConfig(), + false, + TestLocalRecoveryConfig.disabled(), + RocksDBStateBackend.PriorityQueueStateType.HEAP ); longHeapKeyedStateBackend.restore(null); longHeapKeyedStateBackend.setCurrentKey(key); @@ -112,10 +116,12 @@ public void testMapSerialization() throws Exception { columnFamilyOptions, mock(TaskKvStateRegistry.class), LongSerializer.INSTANCE, - 1, new KeyGroupRange(0, 0), + 1, + new KeyGroupRange(0, 0), new ExecutionConfig(), false, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + RocksDBStateBackend.PriorityQueueStateType.HEAP); longHeapKeyedStateBackend.restore(null); longHeapKeyedStateBackend.setCurrentKey(key); diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java index 2ba7507457c14..73f88319d402a 100644 --- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java @@ -32,6 +32,7 @@ import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.TestLocalRecoveryConfig; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory; import org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.state.internal.InternalListState; import org.apache.flink.runtime.state.internal.InternalMapState; @@ -185,18 +186,19 @@ public void testDeserializeValueTooMany2() throws Exception { @Test public void testListSerialization() throws Exception { final long key = 0L; - + final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0); // objects for heap state list serialisation final HeapKeyedStateBackend longHeapKeyedStateBackend = new 
HeapKeyedStateBackend<>( mock(TaskKvStateRegistry.class), LongSerializer.INSTANCE, ClassLoader.getSystemClassLoader(), - 1, - new KeyGroupRange(0, 0), + keyGroupRange.getNumberOfKeyGroups(), + keyGroupRange, async, new ExecutionConfig(), - TestLocalRecoveryConfig.disabled() + TestLocalRecoveryConfig.disabled(), + new HeapPriorityQueueSetFactory(keyGroupRange, keyGroupRange.getNumberOfKeyGroups(), 128) ); longHeapKeyedStateBackend.setCurrentKey(key); @@ -292,18 +294,19 @@ public void testDeserializeListTooShort2() throws Exception { @Test public void testMapSerialization() throws Exception { final long key = 0L; - + final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0); // objects for heap state list serialisation final HeapKeyedStateBackend longHeapKeyedStateBackend = new HeapKeyedStateBackend<>( mock(TaskKvStateRegistry.class), LongSerializer.INSTANCE, ClassLoader.getSystemClassLoader(), - 1, - new KeyGroupRange(0, 0), + keyGroupRange.getNumberOfKeyGroups(), + keyGroupRange, async, new ExecutionConfig(), - TestLocalRecoveryConfig.disabled() + TestLocalRecoveryConfig.disabled(), + new HeapPriorityQueueSetFactory(keyGroupRange, keyGroupRange.getNumberOfKeyGroups(), 128) ); longHeapKeyedStateBackend.setCurrentKey(key); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java index fb3ee82f9849e..dc46c8ab6373f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java @@ -26,6 +26,8 @@ import javax.annotation.Nullable; import java.util.Collection; +import java.util.function.Consumer; +import java.util.function.Predicate; /** * Interface for collection that gives in order access to elements w.r.t their priority. @@ -35,6 +37,16 @@ @Internal public interface InternalPriorityQueue { + /** + * Polls from the top of the queue as long as the queue is not empty and passes the elements to + * {@link Consumer} until a {@link Predicate} rejects an offered element. The rejected element is not + * removed from the queue and becomes the new head. + * + * @param canConsume bulk polling ends once this returns false. The rejected element is neither removed nor consumed. + * @param consumer consumer function for elements accepted by canConsume. + */ + void bulkPoll(@Nonnull Predicate canConsume, @Nonnull Consumer consumer); + /** * Retrieves and removes the first element (w.r.t. the order) of this set, * or returns {@code null} if this set is empty. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java new file mode 100644 index 0000000000000..68472e2ef3d6f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import javax.annotation.Nonnull; + +import java.util.Set; + +/** + * This interface exists as (temporary) adapter between the new {@link InternalPriorityQueue} and the old way in which + * timers are written in a snapshot. This interface can probably go away once timer state becomes part of the + * keyed state backend snapshot. + */ +public interface KeyGroupedInternalPriorityQueue extends InternalPriorityQueue { + + /** + * Returns the subset of elements in the priority queue that belongs to the given key-group, within the operator's + * key-group range. + */ + @Nonnull + Set getSubsetForKeyGroup(int keyGroupId); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java index ad75a1f86c0df..7ba14b3d0070a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java @@ -31,7 +31,8 @@ * * @param The key by which state is keyed. */ -public interface KeyedStateBackend extends InternalKeyContext, KeyedStateFactory, Disposable { +public interface KeyedStateBackend + extends InternalKeyContext, KeyedStateFactory, PriorityQueueSetFactory, Disposable { /** * Sets the current key that is used for partitioned state. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java new file mode 100644 index 0000000000000..2f6f5a790d87a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +/** + * This interface works similarly to {@link Comparable} and is used to prioritize between two objects. The main difference + * between this interface and {@link Comparable} is that it is not required to follow the usual contract between + * {@link Comparable#compareTo(Object)} and {@link Object#equals(Object)}. 
The contract of this interface is: + * When two objects are equal, they indicate the same priority, but indicating the same priority does not require that + * both objects are equal. + * + * @param type of the compared objects. + */ +@FunctionalInterface +public interface PriorityComparator { + + /** + * Compares two objects for priority. Returns a negative integer, zero, or a positive integer as the first + * argument has lower, equal to, or higher priority than the second. + * @param left left operand in the comparison by priority. + * @param right right operand in the comparison by priority. + * @return a negative integer, zero, or a positive integer as the first argument has lower, equal to, or higher + * priority than the second. + */ + int comparePriority(T left, T right); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java new file mode 100644 index 0000000000000..6f509c092805d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement; + +import javax.annotation.Nonnull; + +/** + * Factory for {@link KeyGroupedInternalPriorityQueue} instances. + */ +public interface PriorityQueueSetFactory { + + /** + * Creates a {@link KeyGroupedInternalPriorityQueue}. + * + * @param stateName unique name associated with this queue. + * @param byteOrderedElementSerializer a serializer with a format that is lexicographically ordered in + * alignment with elementPriorityComparator. + * @param type of the stored elements. + * @return the queue with the specified unique name. + */ + @Nonnull + KeyGroupedInternalPriorityQueue create( + @Nonnull String stateName, + @Nonnull TypeSerializer byteOrderedElementSerializer, + @Nonnull PriorityComparator elementPriorityComparator, + @Nonnull KeyExtractorFunction keyExtractor); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java new file mode 100644 index 0000000000000..4384eb7c816cb --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.util.FlinkRuntimeException; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.Comparator; + +/** + * This class is an adapter between {@link PriorityComparator} and a full {@link Comparator} that respects the + * contract between {@link Comparator#compare(Object, Object)} and {@link Object#equals(Object)}. This is currently + * needed for implementations of + * {@link org.apache.flink.runtime.state.heap.CachingInternalPriorityQueueSet.OrderedSetCache} that are implemented + * on top of a data structure that relies on this contract, e.g. a tree set. We should replace this in the near + * future. + * + * @param type of the compared elements. + */ +public class TieBreakingPriorityComparator implements Comparator, PriorityComparator { + + /** The {@link PriorityComparator} to which we delegate in a first step. */ + @Nonnull + private final PriorityComparator priorityComparator; + + /** Serializer for instances of the compared objects. */ + @Nonnull + private final TypeSerializer serializer; + + /** Stream that we use in serialization. */ + @Nonnull + private final ByteArrayOutputStreamWithPos outStream; + + /** {@link org.apache.flink.core.memory.DataOutputView} around outStream. */ + @Nonnull + private final DataOutputViewStreamWrapper outView; + + public TieBreakingPriorityComparator( + @Nonnull PriorityComparator priorityComparator, + @Nonnull TypeSerializer serializer, + @Nonnull ByteArrayOutputStreamWithPos outStream, + @Nonnull DataOutputViewStreamWrapper outView) { + + this.priorityComparator = priorityComparator; + this.serializer = serializer; + this.outStream = outStream; + this.outView = outView; + } + + @SuppressWarnings("unchecked") + @Override + public int compare(T o1, T o2) { + + // first we compare priority, this should be the most commonly hit case + int cmp = priorityComparator.comparePriority(o1, o2); + + if (cmp != 0) { + return cmp; + } + + // here we start tie breaking and do our best to comply with the compareTo/equals contract, first we try + // to simply find an existing way to fully compare. + if (o1 instanceof Comparable && o1.getClass().equals(o2.getClass())) { + return ((Comparable) o1).compareTo(o2); + } + + // we catch this case before moving to more expensive tie breaks. + if (o1.equals(o2)) { + return 0; + } + + // if objects are not equal, their serialized form should somehow differ as well. this can be costly, and... + // TODO we should have an alternative approach in the future, e.g. a cache that does not rely on compare to check equality. 
+ try { + outStream.reset(); + serializer.serialize(o1, outView); + int leftLen = outStream.getPosition(); + serializer.serialize(o2, outView); + int rightLen = outStream.getPosition() - leftLen; + return compareBytes(outStream.getBuf(), 0, leftLen, leftLen, rightLen); + } catch (IOException ex) { + throw new FlinkRuntimeException("Serializer problem in comparator.", ex); + } + } + + @Override + public int comparePriority(T left, T right) { + return priorityComparator.comparePriority(left, right); + } + + public static int compareBytes(byte[] bytes, int offLeft, int leftLen, int offRight, int rightLen) { + int maxLen = Math.min(leftLen, rightLen); + for (int i = 0; i < maxLen; ++i) { + int cmp = bytes[offLeft + i] - bytes[offRight + i]; + if (cmp != 0) { + return cmp; + } + } + return leftLen - rightLen; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index 637effde06314..ad1581bf6ad1b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -36,6 +36,7 @@ import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory; import org.apache.flink.util.TernaryBoolean; import org.slf4j.LoggerFactory; @@ -457,6 +458,8 @@ public AbstractKeyedStateBackend createKeyedStateBackend( TaskStateManager taskStateManager = env.getTaskStateManager(); LocalRecoveryConfig localRecoveryConfig = taskStateManager.createLocalRecoveryConfig(); + HeapPriorityQueueSetFactory priorityQueueSetFactory = + new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128); return new HeapKeyedStateBackend<>( kvStateRegistry, @@ -466,7 +469,8 @@ public AbstractKeyedStateBackend createKeyedStateBackend( keyGroupRange, isUsingAsynchronousSnapshots(), env.getExecutionConfig(), - localRecoveryConfig); + localRecoveryConfig, + priorityQueueSetFactory); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java index 771315d3850da..6dc8cf3838d79 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java @@ -27,6 +27,8 @@ import javax.annotation.Nullable; import java.util.Collection; +import java.util.function.Consumer; +import java.util.function.Predicate; /** * This class is an implementation of a {@link InternalPriorityQueue} with set semantics that internally consists of @@ -76,6 +78,15 @@ public E peek() { return orderedCache.peekFirst(); } + @Override + public void bulkPoll(@Nonnull Predicate canConsume, @Nonnull Consumer consumer) { + E element; + while ((element = peek()) != null && canConsume.test(element)) { + poll(); + consumer.accept(element); + } + } + @Nullable @Override public E poll() { @@ -158,7 +169,11 @@ public boolean isEmpty() { @Nonnull @Override public CloseableIterator iterator() { - return orderedStore.orderedIterator(); + if (storeOnlyElements) { + return orderedStore.orderedIterator(); + } 
+ else { + return orderedCache.orderedIterator(); + } } @Override @@ -184,7 +199,7 @@ private void checkRefillCacheFromStore() { } storeOnlyElements = iterator.hasNext(); } catch (Exception e) { - throw new FlinkRuntimeException("Exception while closing RocksDB iterator.", e); + throw new FlinkRuntimeException("Exception while refilling store from iterator.", e); } } } @@ -249,6 +264,13 @@ public interface OrderedSetCache { */ @Nullable E peekLast(); + + /** + * Returns an iterator over the cache that returns elements in order. The iterator must be closed by the client + * after usage. + */ + @Nonnull + CloseableIterator orderedIterator(); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index 82ce5847627a1..b5b262639b05b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -44,13 +44,17 @@ import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider; import org.apache.flink.runtime.state.CheckpointedStateScope; import org.apache.flink.runtime.state.DoneFuture; +import org.apache.flink.runtime.state.KeyExtractorFunction; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedBackendSerializationProxy; import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.LocalRecoveryConfig; +import org.apache.flink.runtime.state.PriorityComparator; +import org.apache.flink.runtime.state.PriorityQueueSetFactory; import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo; import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator; import org.apache.flink.runtime.state.SnapshotResult; @@ -102,6 +106,21 @@ public class HeapKeyedStateBackend extends AbstractKeyedStateBackend { Tuple2.of(FoldingStateDescriptor.class, (StateFactory) HeapFoldingState::create) ).collect(Collectors.toMap(t -> t.f0, t -> t.f1)); + @Nonnull + @Override + public KeyGroupedInternalPriorityQueue create( + @Nonnull String stateName, + @Nonnull TypeSerializer byteOrderedElementSerializer, + @Nonnull PriorityComparator elementPriorityComparator, + @Nonnull KeyExtractorFunction keyExtractor) { + + return priorityQueueSetFactory.create( + stateName, + byteOrderedElementSerializer, + elementPriorityComparator, + keyExtractor); + } + private interface StateFactory { IS createState( StateDescriptor stateDesc, @@ -137,6 +156,11 @@ IS createState( */ private final HeapSnapshotStrategy snapshotStrategy; + /** + * Factory for state that is organized as a priority queue. 
+ */ + private final PriorityQueueSetFactory priorityQueueSetFactory; + public HeapKeyedStateBackend( TaskKvStateRegistry kvStateRegistry, TypeSerializer keySerializer, @@ -145,7 +169,8 @@ public HeapKeyedStateBackend( KeyGroupRange keyGroupRange, boolean asynchronousSnapshots, ExecutionConfig executionConfig, - LocalRecoveryConfig localRecoveryConfig) { + LocalRecoveryConfig localRecoveryConfig, + PriorityQueueSetFactory priorityQueueSetFactory) { super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig); this.localRecoveryConfig = Preconditions.checkNotNull(localRecoveryConfig); @@ -157,6 +182,7 @@ public HeapKeyedStateBackend( this.snapshotStrategy = new HeapSnapshotStrategy(synchronicityTrait); LOG.info("Initializing heap keyed state backend with stream factory."); this.restoredKvStateMetaInfos = new HashMap<>(); + this.priorityQueueSetFactory = priorityQueueSetFactory; } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java index 7017905d97c81..e5f610eab73dc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.runtime.state.InternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityComparator; import org.apache.flink.util.CloseableIterator; import javax.annotation.Nonnegative; @@ -27,9 +28,10 @@ import java.util.Arrays; import java.util.Collection; -import java.util.Comparator; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.function.Consumer; +import java.util.function.Predicate; import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE; @@ -56,9 +58,9 @@ public class HeapPriorityQueue implements In private static final int QUEUE_HEAD_INDEX = 1; /** - * Comparator for the contained elements. + * Comparator for the priority of contained elements. */ - private final Comparator elementComparator; + private final PriorityComparator elementPriorityComparator; /** * The array that represents the heap-organized priority queue. @@ -73,18 +75,27 @@ public class HeapPriorityQueue implements In /** * Creates an empty {@link HeapPriorityQueue} with the requested initial capacity. * - * @param elementComparator comparator for the contained elements. + * @param elementPriorityComparator comparator for the priority of contained elements. * @param minimumCapacity the minimum and initial capacity of this priority queue. 
*/ @SuppressWarnings("unchecked") public HeapPriorityQueue( - @Nonnull Comparator elementComparator, + @Nonnull PriorityComparator elementPriorityComparator, @Nonnegative int minimumCapacity) { - this.elementComparator = elementComparator; + this.elementPriorityComparator = elementPriorityComparator; this.queue = (T[]) new HeapPriorityQueueElement[QUEUE_HEAD_INDEX + minimumCapacity]; } + @Override + public void bulkPoll(@Nonnull Predicate canConsume, @Nonnull Consumer consumer) { + T element; + while ((element = peek()) != null && canConsume.test(element)) { + poll(); + consumer.accept(element); + } + } + @Override @Nullable public T poll() { @@ -227,7 +238,7 @@ private void siftUp(int idx) { final T currentElement = heap[idx]; int parentIdx = idx >>> 1; - while (parentIdx > 0 && isElementLessThen(currentElement, heap[parentIdx])) { + while (parentIdx > 0 && isElementPriorityLessThen(currentElement, heap[parentIdx])) { moveElementToIdx(heap[parentIdx], idx); idx = parentIdx; parentIdx >>>= 1; @@ -245,19 +256,19 @@ private void siftDown(int idx) { int secondChildIdx = firstChildIdx + 1; if (isElementIndexValid(secondChildIdx, heapSize) && - isElementLessThen(heap[secondChildIdx], heap[firstChildIdx])) { + isElementPriorityLessThen(heap[secondChildIdx], heap[firstChildIdx])) { firstChildIdx = secondChildIdx; } while (isElementIndexValid(firstChildIdx, heapSize) && - isElementLessThen(heap[firstChildIdx], currentElement)) { + isElementPriorityLessThen(heap[firstChildIdx], currentElement)) { moveElementToIdx(heap[firstChildIdx], idx); idx = firstChildIdx; firstChildIdx = idx << 1; secondChildIdx = firstChildIdx + 1; if (isElementIndexValid(secondChildIdx, heapSize) && - isElementLessThen(heap[secondChildIdx], heap[firstChildIdx])) { + isElementPriorityLessThen(heap[secondChildIdx], heap[firstChildIdx])) { firstChildIdx = secondChildIdx; } } @@ -269,8 +280,8 @@ private boolean isElementIndexValid(int elementIndex, int heapSize) { return elementIndex <= heapSize; } - private boolean isElementLessThen(T a, T b) { - return elementComparator.compare(a, b) < 0; + private boolean isElementPriorityLessThen(T a, T b) { + return elementPriorityComparator.comparePriority(a, b) < 0; } private void moveElementToIdx(T element, int idx) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java index 61313e91152d1..79f319c8d8a56 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java @@ -18,20 +18,17 @@ package org.apache.flink.runtime.state.heap; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.state.KeyExtractorFunction; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityComparator; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; -import java.util.List; import java.util.Set; import static org.apache.flink.util.Preconditions.checkArgument; @@ -49,7 +46,9 @@ * * @param type of the contained elements. 
*/ -public class HeapPriorityQueueSet extends HeapPriorityQueue { +public class HeapPriorityQueueSet + extends HeapPriorityQueue + implements KeyGroupedInternalPriorityQueue { /** * Function to extract the key from contained elements. @@ -74,7 +73,7 @@ public class HeapPriorityQueueSet extends He /** * Creates an empty {@link HeapPriorityQueueSet} with the requested initial capacity. * - * @param elementComparator comparator for the contained elements. + * @param elementPriorityComparator comparator for the priority of contained elements. * @param keyExtractor function to extract a key from the contained elements. * @param minimumCapacity the minimum and initial capacity of this priority queue. * @param keyGroupRange the key-group range of the elements in this set. @@ -82,13 +81,13 @@ public class HeapPriorityQueueSet extends He */ @SuppressWarnings("unchecked") public HeapPriorityQueueSet( - @Nonnull Comparator elementComparator, + @Nonnull PriorityComparator elementPriorityComparator, @Nonnull KeyExtractorFunction keyExtractor, @Nonnegative int minimumCapacity, @Nonnull KeyGroupRange keyGroupRange, @Nonnegative int totalNumberOfKeyGroups) { - super(elementComparator, minimumCapacity); + super(elementPriorityComparator, minimumCapacity); this.keyExtractor = keyExtractor; @@ -147,28 +146,9 @@ public void clear() { } } - /** - * Returns an unmodifiable set of all elements in the given key-group. - */ - @Nonnull - public Set getElementsForKeyGroup(@Nonnegative int keyGroupIdx) { - return Collections.unmodifiableSet(getDedupMapForKeyGroup(keyGroupIdx).keySet()); - } - - @VisibleForTesting - @SuppressWarnings("unchecked") - @Nonnull - public List> getElementsByKeyGroup() { - List> result = new ArrayList<>(deduplicationMapsByKeyGroup.length); - for (int i = 0; i < deduplicationMapsByKeyGroup.length; ++i) { - result.add(i, Collections.unmodifiableSet(deduplicationMapsByKeyGroup[i].keySet())); - } - return result; - } - private HashMap getDedupMapForKeyGroup( - @Nonnegative int keyGroupIdx) { - return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupIdx)]; + @Nonnegative int keyGroupId) { + return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupId)]; } private HashMap getDedupMapForElement(T element) { @@ -182,4 +162,10 @@ private int globalKeyGroupToLocalIndex(int keyGroup) { checkArgument(keyGroupRange.contains(keyGroup)); return keyGroup - keyGroupRange.getStartKeyGroup(); } + + @Nonnull + @Override + public Set getSubsetForKeyGroup(int keyGroupId) { + return getDedupMapForKeyGroup(keyGroupId).keySet(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java new file mode 100644 index 0000000000000..ee6fda90c0ba8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.KeyExtractorFunction; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityComparator; +import org.apache.flink.runtime.state.PriorityQueueSetFactory; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Factory for {@link HeapPriorityQueueSet} instances. + */ +public class HeapPriorityQueueSetFactory implements PriorityQueueSetFactory { + + @Nonnull + private final KeyGroupRange keyGroupRange; + + @Nonnegative + private final int totalKeyGroups; + + @Nonnegative + private final int minimumCapacity; + + public HeapPriorityQueueSetFactory( + @Nonnull KeyGroupRange keyGroupRange, + @Nonnegative int totalKeyGroups, + @Nonnegative int minimumCapacity) { + + this.keyGroupRange = keyGroupRange; + this.totalKeyGroups = totalKeyGroups; + this.minimumCapacity = minimumCapacity; + } + + @Nonnull + @Override + public KeyGroupedInternalPriorityQueue create( + @Nonnull String stateName, + @Nonnull TypeSerializer byteOrderedElementSerializer, + @Nonnull PriorityComparator elementPriorityComparator, + @Nonnull KeyExtractorFunction keyExtractor) { + return new HeapPriorityQueueSet<>( + elementPriorityComparator, + keyExtractor, + minimumCapacity, + keyGroupRange, + totalKeyGroups); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java index af4d54fee94cd..6f4f9110748d7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java @@ -22,7 +22,10 @@ import org.apache.flink.runtime.state.KeyExtractorFunction; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityComparator; import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.util.IOUtils; import javax.annotation.Nonnegative; @@ -30,7 +33,10 @@ import javax.annotation.Nullable; import java.util.Collection; -import java.util.Comparator; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Predicate; /** * This implementation of {@link InternalPriorityQueue} is internally partitioned into sub-queues per key-group and @@ -41,7 +47,7 @@ * @param type type of sub-queue used for each key-group partition. */ public class KeyGroupPartitionedPriorityQueue & HeapPriorityQueueElement> - implements InternalPriorityQueue { + implements InternalPriorityQueue, KeyGroupedInternalPriorityQueue { /** A heap of heap sets. 
Each sub-heap represents the partition for a key-group.*/ @Nonnull @@ -66,7 +72,7 @@ public class KeyGroupPartitionedPriorityQueue keyExtractor, - @Nonnull Comparator elementComparator, + @Nonnull PriorityComparator elementPriorityComparator, @Nonnull PartitionQueueSetFactory orderedCacheFactory, @Nonnull KeyGroupRange keyGroupRange, @Nonnegative int totalKeyGroups) { @@ -76,16 +82,25 @@ public KeyGroupPartitionedPriorityQueue( this.firstKeyGroup = keyGroupRange.getStartKeyGroup(); this.keyGroupedHeaps = (PQ[]) new InternalPriorityQueue[keyGroupRange.getNumberOfKeyGroups()]; this.heapOfkeyGroupedHeaps = new HeapPriorityQueue<>( - new InternalPriorityQueueComparator<>(elementComparator), + new InternalPriorityQueueComparator<>(elementPriorityComparator), keyGroupRange.getNumberOfKeyGroups()); for (int i = 0; i < keyGroupedHeaps.length; i++) { final PQ keyGroupSubHeap = - orderedCacheFactory.create(firstKeyGroup + i, totalKeyGroups, elementComparator); + orderedCacheFactory.create(firstKeyGroup + i, totalKeyGroups, elementPriorityComparator); keyGroupedHeaps[i] = keyGroupSubHeap; heapOfkeyGroupedHeaps.add(keyGroupSubHeap); } } + @Override + public void bulkPoll(@Nonnull Predicate canConsume, @Nonnull Consumer consumer) { + T element; + while ((element = peek()) != null && canConsume.test(element)) { + poll(); + consumer.accept(element); + } + } + @Nullable @Override public T poll() { @@ -173,9 +188,28 @@ private PQ getKeyGroupSubHeapForElement(T element) { private int computeKeyGroupIndex(T element) { final Object extractKeyFromElement = keyExtractor.extractKeyFromElement(element); final int keyGroupId = KeyGroupRangeAssignment.assignToKeyGroup(extractKeyFromElement, totalKeyGroups); + return globalKeyGroupToLocalIndex(keyGroupId); + } + + private int globalKeyGroupToLocalIndex(int keyGroupId) { return keyGroupId - firstKeyGroup; } + @Nonnull + @Override + public Set getSubsetForKeyGroup(int keyGroupId) { + HashSet result = new HashSet<>(); + PQ partitionQueue = keyGroupedHeaps[globalKeyGroupToLocalIndex(keyGroupId)]; + try (CloseableIterator iterator = partitionQueue.iterator()) { + while (iterator.hasNext()) { + result.add(iterator.next()); + } + } catch (Exception e) { + throw new FlinkRuntimeException("Exception while iterating key group.", e); + } + return result; + } + /** * Iterator for {@link KeyGroupPartitionedPriorityQueue}. This iterator is not guaranteeing any order of elements. * Using code must {@link #close()} after usage. @@ -236,24 +270,24 @@ public void close() throws Exception { * @param type of queue. */ private static final class InternalPriorityQueueComparator> - implements Comparator { + implements PriorityComparator { /** Comparator for the queue elements, so we can compare their heads. */ @Nonnull - private final Comparator elementComparator; + private final PriorityComparator elementPriorityComparator; - InternalPriorityQueueComparator(@Nonnull Comparator elementComparator) { - this.elementComparator = elementComparator; + InternalPriorityQueueComparator(@Nonnull PriorityComparator elementPriorityComparator) { + this.elementPriorityComparator = elementPriorityComparator; } @Override - public int compare(Q o1, Q o2) { + public int comparePriority(Q o1, Q o2) { final T left = o1.peek(); final T right = o2.peek(); if (left == null) { return (right == null ? 0 : 1); } else { - return (right == null ? -1 : elementComparator.compare(left, right)); + return (right == null ? 
-1 : elementPriorityComparator.comparePriority(left, right)); } } } @@ -271,10 +305,13 @@ public interface PartitionQueueSetFactory elementComparator); + PQS create( + @Nonnegative int keyGroupId, + @Nonnegative int numKeyGroups, + @Nonnull PriorityComparator elementPriorityComparator); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java index 0e7d9dd28a104..14c281effc910 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.state.heap; +import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.Preconditions; import it.unimi.dsi.fastutil.objects.ObjectAVLTreeSet; @@ -125,4 +126,10 @@ public E peekFirst() { public E peekLast() { return !avlTree.isEmpty() ? avlTree.last() : null; } + + @Nonnull + @Override + public CloseableIterator orderedIterator() { + return CloseableIterator.adapterForIterator(avlTree.iterator()); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java index 3da60e4861f9f..d78944c78ec59 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.filesystem.AbstractFileStateBackend; import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory; import org.apache.flink.util.TernaryBoolean; import javax.annotation.Nullable; @@ -309,7 +310,8 @@ public AbstractKeyedStateBackend createKeyedStateBackend( TaskKvStateRegistry kvStateRegistry) { TaskStateManager taskStateManager = env.getTaskStateManager(); - + HeapPriorityQueueSetFactory priorityQueueSetFactory = + new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128); return new HeapKeyedStateBackend<>( kvStateRegistry, keySerializer, @@ -318,7 +320,8 @@ public AbstractKeyedStateBackend createKeyedStateBackend( keyGroupRange, isUsingAsynchronousSnapshots(), env.getExecutionConfig(), - taskStateManager.createLocalRecoveryConfig()); + taskStateManager.createLocalRecoveryConfig(), + priorityQueueSetFactory); } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java index c0c3ba4a2adf8..0cd551ca0e56b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java @@ -51,8 +51,16 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger { protected static final KeyGroupRange KEY_GROUP_RANGE = new KeyGroupRange(0, 2); protected static final KeyExtractorFunction KEY_EXTRACTOR_FUNCTION = TestElement::getKey; - protected static final Comparator TEST_ELEMENT_COMPARATOR = - 
Comparator.comparingLong(TestElement::getPriority).thenComparingLong(TestElement::getKey); + protected static final PriorityComparator TEST_ELEMENT_PRIORITY_COMPARATOR = + (left, right) -> Long.compare(left.getPriority(), right.getPriority()); + protected static final Comparator TEST_ELEMENT_COMPARATOR = (o1, o2) -> { + int priorityCmp = TEST_ELEMENT_PRIORITY_COMPARATOR.comparePriority(o1, o2); + if (priorityCmp != 0) { + return priorityCmp; + } + // to fully comply with compareTo/equals contract. + return Long.compare(o1.getKey(), o2.getKey()); + }; protected static void insertRandomElements( @Nonnull InternalPriorityQueue priorityQueue, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java index 3c06b71a293d9..dfcdffc786912 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java @@ -53,7 +53,8 @@ public void testCompressionConfiguration() { new KeyGroupRange(0, 15), true, executionConfig, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + mock(PriorityQueueSetFactory.class)); try { Assert.assertTrue( @@ -75,7 +76,8 @@ public void testCompressionConfiguration() { new KeyGroupRange(0, 15), true, executionConfig, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + mock(PriorityQueueSetFactory.class)); try { Assert.assertTrue( @@ -115,7 +117,8 @@ private void snapshotRestoreRoundtrip(boolean useCompression) throws Exception { new KeyGroupRange(0, 15), true, executionConfig, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + mock(PriorityQueueSetFactory.class)); try { @@ -156,7 +159,8 @@ private void snapshotRestoreRoundtrip(boolean useCompression) throws Exception { new KeyGroupRange(0, 15), true, executionConfig, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + mock(PriorityQueueSetFactory.class)); try { stateBackend.restore(StateObjectCollection.singleton(stateHandle)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java index 618da4ea097ca..415497da38767 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java @@ -25,7 +25,7 @@ public class HeapPriorityQueueSetTest extends HeapPriorityQueueTest { @Override protected HeapPriorityQueueSet newPriorityQueue(int initialCapacity) { return new HeapPriorityQueueSet<>( - TEST_ELEMENT_COMPARATOR, + TEST_ELEMENT_PRIORITY_COMPARATOR, KEY_EXTRACTOR_FUNCTION, initialCapacity, KEY_GROUP_RANGE, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java index 8ffb8b8137c3a..6ba5a6835a644 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java @@ -89,7 +89,7 @@ public void testToArray() { @Override protected HeapPriorityQueue newPriorityQueue(int initialCapacity) { - return new 
HeapPriorityQueue<>(TEST_ELEMENT_COMPARATOR, initialCapacity); + return new HeapPriorityQueue<>(TEST_ELEMENT_PRIORITY_COMPARATOR, initialCapacity); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java index bf428dc0776af..cf6aef463aa58 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java @@ -49,14 +49,18 @@ public HeapKeyedStateBackend createKeyedBackend() throws Exception { } public HeapKeyedStateBackend createKeyedBackend(TypeSerializer keySerializer) throws Exception { + final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 15); + final int numKeyGroups = keyGroupRange.getNumberOfKeyGroups(); + return new HeapKeyedStateBackend<>( mock(TaskKvStateRegistry.class), keySerializer, HeapStateBackendTestBase.class.getClassLoader(), - 16, - new KeyGroupRange(0, 15), + numKeyGroups, + keyGroupRange, async, new ExecutionConfig(), - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + new HeapPriorityQueueSetFactory(keyGroupRange, numKeyGroups, 128)); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java index 277de19a6438b..d348e10458f23 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java @@ -29,7 +29,7 @@ public class KeyGroupPartitionedPriorityQueueTest extends InternalPriorityQueueT protected InternalPriorityQueue newPriorityQueue(int initialCapacity) { return new KeyGroupPartitionedPriorityQueue<>( KEY_EXTRACTOR_FUNCTION, - TEST_ELEMENT_COMPARATOR, + TEST_ELEMENT_PRIORITY_COMPARATOR, newFactory(initialCapacity), KEY_GROUP_RANGE, KEY_GROUP_RANGE.getNumberOfKeyGroups()); } diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java new file mode 100644 index 0000000000000..ede45e368c8df --- /dev/null +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.contrib.streaming.state; + +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.configuration.ConfigOptions; + +/** + * Configuration options for the RocksDB backend. + */ +public class RockDBBackendOptions { + + /** + * Choice of implementation for priority queue state (e.g. timers). + */ + public static final ConfigOption PRIORITY_QUEUE_STATE_TYPE = ConfigOptions + .key("backend.rocksdb.priority_queue_state_type") + .defaultValue(RocksDBStateBackend.PriorityQueueStateType.HEAP.name()) + .withDescription("This determines the implementation for the priority queue state (e.g. timers). Options are " + + "either " + RocksDBStateBackend.PriorityQueueStateType.HEAP.name() + " (heap-based, default) or " + + RocksDBStateBackend.PriorityQueueStateType.ROCKS.name() + " for an implementation based on RocksDB."); +} diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java index 21d2a655593df..f2430ae19df4c 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java @@ -58,14 +58,18 @@ import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.IncrementalKeyedStateHandle; import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle; +import org.apache.flink.runtime.state.KeyExtractorFunction; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; import org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedBackendSerializationProxy; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.LocalRecoveryConfig; import org.apache.flink.runtime.state.LocalRecoveryDirectoryProvider; import org.apache.flink.runtime.state.PlaceholderStreamStateHandle; +import org.apache.flink.runtime.state.PriorityComparator; +import org.apache.flink.runtime.state.PriorityQueueSetFactory; import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo; import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator; import org.apache.flink.runtime.state.SnapshotDirectory; @@ -76,7 +80,13 @@ import org.apache.flink.runtime.state.StateUtil; import org.apache.flink.runtime.state.StreamCompressionDecorator; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TieBreakingPriorityComparator; import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator; +import org.apache.flink.runtime.state.heap.CachingInternalPriorityQueueSet; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory; +import org.apache.flink.runtime.state.heap.KeyGroupPartitionedPriorityQueue; +import org.apache.flink.runtime.state.heap.TreeOrderedSetCache; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FileUtils; import org.apache.flink.util.FlinkRuntimeException; @@ -243,6 +253,9 @@ IS createState( /** The snapshot strategy, e.g., if we 
use full or incremental checkpoints, local state, and so on. */ private final SnapshotStrategy> snapshotStrategy; + /** Factory for priority queue state. */ + private PriorityQueueSetFactory priorityQueueFactory; + public RocksDBKeyedStateBackend( String operatorIdentifier, ClassLoader userCodeClassLoader, @@ -255,7 +268,8 @@ public RocksDBKeyedStateBackend( KeyGroupRange keyGroupRange, ExecutionConfig executionConfig, boolean enableIncrementalCheckpointing, - LocalRecoveryConfig localRecoveryConfig + LocalRecoveryConfig localRecoveryConfig, + RocksDBStateBackend.PriorityQueueStateType priorityQueueStateType ) throws IOException { super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig); @@ -296,6 +310,17 @@ public RocksDBKeyedStateBackend( this.writeOptions = new WriteOptions().setDisableWAL(true); + switch (priorityQueueStateType) { + case HEAP: + this.priorityQueueFactory = new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128); + break; + case ROCKS: + this.priorityQueueFactory = new RocksDBPriorityQueueSetFactory(); + break; + default: + break; + } + LOG.debug("Setting initial keyed backend uid for operator {} to {}.", this.operatorIdentifier, this.backendUID); } @@ -378,6 +403,11 @@ public void dispose() { IOUtils.closeQuietly(columnMetaData.f0); } + // ... then close the priority queue related resources ... + if (priorityQueueFactory instanceof AutoCloseable) { + IOUtils.closeQuietly((AutoCloseable) priorityQueueFactory); + } + // ... and finally close the DB instance ... IOUtils.closeQuietly(db); @@ -394,6 +424,17 @@ public void dispose() { } } + @Nonnull + @Override + public KeyGroupedInternalPriorityQueue create( + @Nonnull String stateName, + @Nonnull TypeSerializer byteOrderedElementSerializer, + @Nonnull PriorityComparator elementComparator, + @Nonnull KeyExtractorFunction keyExtractor) { + + return priorityQueueFactory.create(stateName, byteOrderedElementSerializer, elementComparator, keyExtractor); + } + private void cleanInstanceBasePath() { LOG.info("Deleting existing instance base directory {}.", instanceBasePath); @@ -1290,7 +1331,7 @@ private Tuple2 Tuple2 priorityQueueColumnFamilies; + + /** The mandatory default column family, so that we can close it later. */ + @Nonnull + private final ColumnFamilyHandle defaultColumnFamily; + + /** Path of the RocksDB instance that holds the priority queues. */ + @Nonnull + private final File pqInstanceRocksDBPath; + + /** RocksDB instance that holds the priority queues. 
*/ + @Nonnull + private final RocksDB pqDb; + + RocksDBPriorityQueueSetFactory() throws IOException { + this.pqInstanceRocksDBPath = new File(instanceBasePath, "pqdb"); + if (pqInstanceRocksDBPath.exists()) { + try { + FileUtils.deleteDirectory(pqInstanceRocksDBPath); + } catch (IOException ex) { + LOG.warn("Could not delete instance path for PQ RocksDB: " + pqInstanceRocksDBPath, ex); + } + } + List columnFamilyHandles = new ArrayList<>(1); + this.pqDb = openDB(pqInstanceRocksDBPath.getAbsolutePath(), Collections.emptyList(), columnFamilyHandles); + this.elementSerializationOutStream = new ByteArrayOutputStreamWithPos(); + this.elementSerializationOutView = new DataOutputViewStreamWrapper(elementSerializationOutStream); + this.writeBatchWrapper = new RocksDBWriteBatchWrapper(pqDb, writeOptions); + this.defaultColumnFamily = columnFamilyHandles.get(0); + this.priorityQueueColumnFamilies = new HashMap<>(); + } + + @Nonnull + @Override + public KeyGroupedInternalPriorityQueue create( + @Nonnull String stateName, + @Nonnull TypeSerializer byteOrderedElementSerializer, + @Nonnull PriorityComparator elementPriorityComparator, + @Nonnull KeyExtractorFunction keyExtractor) { + + final ColumnFamilyHandle columnFamilyHandle = + priorityQueueColumnFamilies.computeIfAbsent( + stateName, + (name) -> RocksDBKeyedStateBackend.this.createColumnFamily(name, pqDb)); + + @Nonnull + TieBreakingPriorityComparator tieBreakingComparator = + new TieBreakingPriorityComparator<>( + elementPriorityComparator, + byteOrderedElementSerializer, + elementSerializationOutStream, + elementSerializationOutView); + + return new KeyGroupPartitionedPriorityQueue<>( + keyExtractor, + elementPriorityComparator, + new KeyGroupPartitionedPriorityQueue.PartitionQueueSetFactory>() { + @Nonnull + @Override + public CachingInternalPriorityQueueSet create( + int keyGroupId, + int numKeyGroups, + @Nonnull PriorityComparator elementPriorityComparator) { + + CachingInternalPriorityQueueSet.OrderedSetCache cache = + new TreeOrderedSetCache<>(tieBreakingComparator, DEFAULT_CACHES_SIZE); + CachingInternalPriorityQueueSet.OrderedSetStore store = + new RocksDBOrderedSetStore<>( + keyGroupId, + keyGroupPrefixBytes, + pqDb, + columnFamilyHandle, + byteOrderedElementSerializer, + elementSerializationOutStream, + elementSerializationOutView, + writeBatchWrapper); + + return new CachingInternalPriorityQueueSet<>(cache, store); + } + }, + keyGroupRange, + numberOfKeyGroups); + } + + @Override + public void close() { + IOUtils.closeQuietly(writeBatchWrapper); + for (ColumnFamilyHandle columnFamilyHandle : priorityQueueColumnFamilies.values()) { + IOUtils.closeQuietly(columnFamilyHandle); + } + IOUtils.closeQuietly(defaultColumnFamily); + IOUtils.closeQuietly(pqDb); + try { + FileUtils.deleteDirectory(pqInstanceRocksDBPath); + } catch (IOException ex) { + LOG.warn("Could not delete instance path for PQ RocksDB: " + pqInstanceRocksDBPath, ex); + } + } + } } diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java index e512933124e97..52843146f5548 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java @@ -28,7 
+28,6 @@ import org.apache.flink.util.FlinkRuntimeException; import org.rocksdb.ColumnFamilyHandle; -import org.rocksdb.ReadOptions; import org.rocksdb.RocksDB; import org.rocksdb.RocksDBException; @@ -61,10 +60,6 @@ public class RocksDBOrderedSetStore implements CachingInternalPriorityQueueSe @Nonnull private final ColumnFamilyHandle columnFamilyHandle; - /** Read options for RocksDB. */ - @Nonnull - private final ReadOptions readOptions; - /** * Serializer for the contained elements. The lexicographical order of the bytes of serialized objects must be * aligned with their logical order. @@ -93,14 +88,12 @@ public RocksDBOrderedSetStore( @Nonnegative int keyGroupPrefixBytes, @Nonnull RocksDB db, @Nonnull ColumnFamilyHandle columnFamilyHandle, - @Nonnull ReadOptions readOptions, @Nonnull TypeSerializer byteOrderProducingSerializer, @Nonnull ByteArrayOutputStreamWithPos outputStream, @Nonnull DataOutputViewStreamWrapper outputView, @Nonnull RocksDBWriteBatchWrapper batchWrapper) { this.db = db; this.columnFamilyHandle = columnFamilyHandle; - this.readOptions = readOptions; this.byteOrderProducingSerializer = byteOrderProducingSerializer; this.outputStream = outputStream; this.outputView = outputView; @@ -169,7 +162,7 @@ public RocksToJavaIteratorAdapter orderedIterator() { return new RocksToJavaIteratorAdapter( new RocksIteratorWrapper( - db.newIterator(columnFamilyHandle, readOptions))); + db.newIterator(columnFamilyHandle))); } /** @@ -232,6 +225,10 @@ private class RocksToJavaIteratorAdapter implements CloseableIterator { private RocksToJavaIteratorAdapter(@Nonnull RocksIteratorWrapper iterator) { this.iterator = iterator; try { + // TODO we could check if it is more efficient to make the seek more specific, e.g. with a provided hint + // that is lexicographically closer the first expected element in the key-group. I wonder if this could + // help to improve the seek if there are many tombstones for elements at the beginning of the key-group + // (like for elements that have been removed in previous polling, before they are compacted away). iterator.seek(groupPrefixBytes); deserializeNextElementIfAvailable(); } catch (Exception ex) { diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java index 81d62653a40e4..998521b810af9 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java @@ -59,6 +59,7 @@ import java.util.Random; import java.util.UUID; +import static org.apache.flink.contrib.streaming.state.RockDBBackendOptions.PRIORITY_QUEUE_STATE_TYPE; import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -76,6 +77,14 @@ */ public class RocksDBStateBackend extends AbstractStateBackend implements ConfigurableStateBackend { + /** + * The options to chose for the type of priority queue state. + */ + public enum PriorityQueueStateType { + HEAP, + ROCKS + } + private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(RocksDBStateBackend.class); @@ -109,6 +118,9 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu /** This determines if incremental checkpointing is enabled. 
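The new PriorityQueueStateType enum above is resolved from the backend.rocksdb.priority_queue_state_type option when the backend is configured against the cluster configuration (see the constructor hunk that follows, which falls back to the HEAP default when the option is unset). A minimal selection sketch, assuming the usual ConfigurableStateBackend#configure(Configuration) path; the checkpoint URI and class name are only examples:

import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.state.StateBackend;

public class PriorityQueueStateTypeSelectionSketch {

    public static void main(String[] args) throws Exception {
        // Key as defined by RockDBBackendOptions.PRIORITY_QUEUE_STATE_TYPE; the value is
        // parsed case-insensitively via PriorityQueueStateType.valueOf(...).
        Configuration config = new Configuration();
        config.setString("backend.rocksdb.priority_queue_state_type", "ROCKS");

        RocksDBStateBackend original = new RocksDBStateBackend("file:///tmp/checkpoints");

        // configure(...) creates a copy whose priorityQueueStateType is taken from the
        // configuration; without the option it keeps the original's HEAP default.
        StateBackend configured = original.configure(config);
        System.out.println(configured);
    }
}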
*/ private final TernaryBoolean enableIncrementalCheckpointing; + /** This determines the type of priority queue state. */ + private final PriorityQueueStateType priorityQueueStateType; + // -- runtime values, set on TaskManager when initializing / using the backend /** Base paths for RocksDB directory, as initialized. */ @@ -221,6 +233,8 @@ public RocksDBStateBackend(StateBackend checkpointStreamBackend) { public RocksDBStateBackend(StateBackend checkpointStreamBackend, TernaryBoolean enableIncrementalCheckpointing) { this.checkpointStreamBackend = checkNotNull(checkpointStreamBackend); this.enableIncrementalCheckpointing = enableIncrementalCheckpointing; + // for now, we use still the heap-based implementation as default + this.priorityQueueStateType = PriorityQueueStateType.HEAP; } /** @@ -256,6 +270,11 @@ private RocksDBStateBackend(RocksDBStateBackend original, Configuration config) this.enableIncrementalCheckpointing = original.enableIncrementalCheckpointing.resolveUndefined( config.getBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS)); + final String priorityQueueTypeString = config.getString(PRIORITY_QUEUE_STATE_TYPE.key(), ""); + + this.priorityQueueStateType = priorityQueueTypeString.length() > 0 ? + PriorityQueueStateType.valueOf(priorityQueueTypeString.toUpperCase()) : original.priorityQueueStateType; + // configure local directories if (original.localRocksDbDirectories != null) { this.localRocksDbDirectories = original.localRocksDbDirectories; @@ -422,7 +441,8 @@ public AbstractKeyedStateBackend createKeyedStateBackend( keyGroupRange, env.getExecutionConfig(), isIncrementalCheckpointsEnabled(), - localRecoveryConfig); + localRecoveryConfig, + priorityQueueStateType); } @Override diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java index ae20cf2071c77..5f26835282c74 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java @@ -57,7 +57,6 @@ public static CachingInternalPriorityQueueSet.OrderedSetStore creat prefixBytes, rocksDBResource.getRocksDB(), rocksDBResource.getDefaultColumnFamily(), - rocksDBResource.getReadOptions(), TestElementSerializer.INSTANCE, outputStream, outputView, diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java index 256a83b400a6e..0b1d07bd0823f 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java @@ -124,7 +124,6 @@ public static RocksDBOrderedSetStore createRocksDBOrderedStore( keyGroupPrefixBytes, rocksDBResource.getRocksDB(), rocksDBResource.getDefaultColumnFamily(), - rocksDBResource.getReadOptions(), byteOrderSerializer, 
outputStreamWithPos, outputView, diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java index ad895838b4950..69069d6e784b0 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java @@ -240,7 +240,8 @@ public void testCorrectMergeOperatorSet() throws IOException { new KeyGroupRange(0, 0), new ExecutionConfig(), enableIncrementalCheckpointing, - TestLocalRecoveryConfig.disabled()); + TestLocalRecoveryConfig.disabled(), + RocksDBStateBackend.PriorityQueueStateType.HEAP); verify(columnFamilyOptions, Mockito.times(1)) .setMergeOperatorName(RocksDBKeyedStateBackend.MERGE_OPERATOR_NAME); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 9915dd518168a..797a26a61b494 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -730,9 +730,11 @@ public InternalTimerService getInternalTimerService( checkTimerServiceInitialization(); // the following casting is to overcome type restrictions. - TypeSerializer keySerializer = (TypeSerializer) getKeyedStateBackend().getKeySerializer(); + KeyedStateBackend keyedStateBackend = getKeyedStateBackend(); + TypeSerializer keySerializer = keyedStateBackend.getKeySerializer(); InternalTimeServiceManager keyedTimeServiceHandler = (InternalTimeServiceManager) timeServiceManager; - return keyedTimeServiceHandler.getInternalTimerService(name, keySerializer, namespaceSerializer, triggerable); + TimerSerializer timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); + return keyedTimeServiceHandler.getInternalTimerService(name, timerSerializer, triggerable); } public void processWatermark(Watermark mark) throws Exception { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java index 7bf652f0c914d..6c1b1886b941b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java @@ -24,13 +24,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.state.InternalPriorityQueue; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.util.Preconditions; +import java.util.ArrayList; +import 
java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.ScheduledFuture; @@ -50,12 +52,12 @@ public class HeapInternalTimerService implements InternalTimerService, /** * Processing time timers that are currently in-flight. */ - private final HeapPriorityQueueSet> processingTimeTimersQueue; + private final KeyGroupedInternalPriorityQueue> processingTimeTimersQueue; /** * Event time timers that are currently in-flight. */ - private final HeapPriorityQueueSet> eventTimeTimersQueue; + private final KeyGroupedInternalPriorityQueue> eventTimeTimersQueue; /** * Information concerning the local key-group range. @@ -94,14 +96,17 @@ public class HeapInternalTimerService implements InternalTimerService, private InternalTimersSnapshot restoredTimersSnapshot; HeapInternalTimerService( - int totalKeyGroups, KeyGroupRange localKeyGroupRange, KeyContext keyContext, - ProcessingTimeService processingTimeService) { + ProcessingTimeService processingTimeService, + KeyGroupedInternalPriorityQueue> processingTimeTimersQueue, + KeyGroupedInternalPriorityQueue> eventTimeTimersQueue) { this.keyContext = checkNotNull(keyContext); this.processingTimeService = checkNotNull(processingTimeService); this.localKeyGroupRange = checkNotNull(localKeyGroupRange); + this.processingTimeTimersQueue = checkNotNull(processingTimeTimersQueue); + this.eventTimeTimersQueue = checkNotNull(eventTimeTimersQueue); // find the starting index of the local key-group range int startIdx = Integer.MAX_VALUE; @@ -109,9 +114,6 @@ public class HeapInternalTimerService implements InternalTimerService, startIdx = Math.min(keyGroupIdx, startIdx); } this.localKeyGroupRangeStartIdx = startIdx; - - this.eventTimeTimersQueue = createPriorityQueue(localKeyGroupRange, totalKeyGroups); - this.processingTimeTimersQueue = createPriorityQueue(localKeyGroupRange, totalKeyGroups); } /** @@ -225,16 +227,20 @@ public void onProcessingTime(long time) throws Exception { // inside the callback. 
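The hunks below replace the explicit peek()/poll() loops in onProcessingTime(...) and advanceWatermark(...) with a single bulkPoll(...) call on the key-grouped timer queue. As a reference sketch only (not the actual queue implementation, which may batch per key-group cache/store; element-type generics are flattened in this diff and assumed here), the bulk operation behaves like the loop it replaces:

import java.util.function.Consumer;
import java.util.function.Predicate;

import org.apache.flink.runtime.state.InternalPriorityQueue;

final class BulkPollSketch {

    // Mirrors the removed while-loop: the head is only polled after the predicate accepts
    // the peeked element, so the first rejected element is left in place as the new head.
    static <T> void bulkPoll(
            InternalPriorityQueue<T> queue,
            Predicate<T> canConsume,
            Consumer<T> consumer) {
        T element;
        while ((element = queue.peek()) != null && canConsume.test(element)) {
            queue.poll();
            consumer.accept(element);
        }
    }

    private BulkPollSketch() {
    }
}

Note that in the rewritten methods the trigger callbacks are wrapped into FlinkRuntimeException, because the Consumer passed to bulkPoll cannot throw the checked exceptions that onProcessingTime/onEventTime declare.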
nextTimer = null; - InternalTimer timer; - - while ((timer = processingTimeTimersQueue.peek()) != null && timer.getTimestamp() <= time) { - processingTimeTimersQueue.poll(); - keyContext.setCurrentKey(timer.getKey()); - triggerTarget.onProcessingTime(timer); - } + processingTimeTimersQueue.bulkPoll( + (timer) -> (timer.getTimestamp() <= time), + (timer) -> { + keyContext.setCurrentKey(timer.getKey()); + try { + triggerTarget.onProcessingTime(timer); + } catch (Exception e) { + throw new FlinkRuntimeException("Problem in trigger target.", e); + } + }); - if (timer != null) { - if (nextTimer == null) { + if (nextTimer == null) { + final TimerHeapInternalTimer timer = processingTimeTimersQueue.peek(); + if (timer != null) { nextTimer = processingTimeService.registerTimer(timer.getTimestamp(), this); } } @@ -242,14 +248,16 @@ public void onProcessingTime(long time) throws Exception { public void advanceWatermark(long time) throws Exception { currentWatermark = time; - - InternalTimer timer; - - while ((timer = eventTimeTimersQueue.peek()) != null && timer.getTimestamp() <= time) { - eventTimeTimersQueue.poll(); - keyContext.setCurrentKey(timer.getKey()); - triggerTarget.onEventTime(timer); - } + eventTimeTimersQueue.bulkPoll( + (timer) -> (timer.getTimestamp() <= time), + (timer) -> { + keyContext.setCurrentKey(timer.getKey()); + try { + triggerTarget.onEventTime(timer); + } catch (Exception e) { + throw new FlinkRuntimeException("Problem in trigger target.", e); + } + }); } /** @@ -264,8 +272,8 @@ public InternalTimersSnapshot snapshotTimersForKeyGroup(int keyGroupIdx) { keySerializer.snapshotConfiguration(), namespaceSerializer, namespaceSerializer.snapshotConfiguration(), - eventTimeTimersQueue.getElementsForKeyGroup(keyGroupIdx), - processingTimeTimersQueue.getElementsForKeyGroup(keyGroupIdx)); + eventTimeTimersQueue.getSubsetForKeyGroup(keyGroupIdx), + processingTimeTimersQueue.getSubsetForKeyGroup(keyGroupIdx)); } /** @@ -339,27 +347,24 @@ int getLocalKeyGroupRangeStartIdx() { @VisibleForTesting List>> getEventTimeTimersPerKeyGroup() { - return eventTimeTimersQueue.getElementsByKeyGroup(); + return partitionElementsByKeyGroup(eventTimeTimersQueue); } @VisibleForTesting List>> getProcessingTimeTimersPerKeyGroup() { - return processingTimeTimersQueue.getElementsByKeyGroup(); + return partitionElementsByKeyGroup(processingTimeTimersQueue); + } + + private List> partitionElementsByKeyGroup(KeyGroupedInternalPriorityQueue keyGroupedQueue) { + List> result = new ArrayList<>(localKeyGroupRange.getNumberOfKeyGroups()); + for (int keyGroup : localKeyGroupRange) { + result.add(Collections.unmodifiableSet(keyGroupedQueue.getSubsetForKeyGroup(keyGroup))); + } + return result; } private boolean areSnapshotSerializersIncompatible(InternalTimersSnapshot restoredSnapshot) { return (this.keyDeserializer != null && !this.keyDeserializer.equals(restoredSnapshot.getKeySerializer())) || (this.namespaceDeserializer != null && !this.namespaceDeserializer.equals(restoredSnapshot.getNamespaceSerializer())); } - - private static HeapPriorityQueueSet> createPriorityQueue( - KeyGroupRange localKeyGroupRange, - int totalKeyGroups) { - return new HeapPriorityQueueSet<>( - TimerHeapInternalTimer.getTimerComparator(), - TimerHeapInternalTimer.getKeyExtractorFunction(), - 128, - localKeyGroupRange, - totalKeyGroups); - } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java index e62883aed3fe0..ad1617e30949c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java @@ -20,16 +20,17 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityQueueSetFactory; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.util.Preconditions; import java.io.IOException; import java.io.InputStream; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -49,6 +50,7 @@ public class InternalTimeServiceManager { private final KeyGroupRange localKeyGroupRange; private final KeyContext keyContext; + private final PriorityQueueSetFactory priorityQueueSetFactory; private final ProcessingTimeService processingTimeService; private final Map> timerServices; @@ -57,52 +59,66 @@ public class InternalTimeServiceManager { int totalKeyGroups, KeyGroupRange localKeyGroupRange, KeyContext keyContext, + PriorityQueueSetFactory priorityQueueSetFactory, ProcessingTimeService processingTimeService) { Preconditions.checkArgument(totalKeyGroups > 0); this.totalKeyGroups = totalKeyGroups; this.localKeyGroupRange = Preconditions.checkNotNull(localKeyGroupRange); - + this.priorityQueueSetFactory = Preconditions.checkNotNull(priorityQueueSetFactory); this.keyContext = Preconditions.checkNotNull(keyContext); this.processingTimeService = Preconditions.checkNotNull(processingTimeService); this.timerServices = new HashMap<>(); } - /** - * Returns a {@link InternalTimerService} that can be used to query current processing time - * and event time and to set timers. An operator can have several timer services, where - * each has its own namespace serializer. Timer services are differentiated by the string - * key that is given when requesting them, if you call this method with the same key - * multiple times you will get the same timer service instance in subsequent requests. - * - *
<p>
Timers are always scoped to a key, the currently active key of a keyed stream operation. - * When a timer fires, this key will also be set as the currently active key. - * - *
<p>
Each timer has attached metadata, the namespace. Different timer services - * can have a different namespace type. If you don't need namespace differentiation you - * can use {@link VoidNamespaceSerializer} as the namespace serializer. - * - * @param name The name of the requested timer service. If no service exists under the given - * name a new one will be created and returned. - * @param keySerializer {@code TypeSerializer} for the timer keys. - * @param namespaceSerializer {@code TypeSerializer} for the timer namespace. - * @param triggerable The {@link Triggerable} that should be invoked when timers fire - */ @SuppressWarnings("unchecked") - public InternalTimerService getInternalTimerService(String name, TypeSerializer keySerializer, - TypeSerializer namespaceSerializer, Triggerable triggerable) { + public InternalTimerService getInternalTimerService( + String name, + TimerSerializer timerSerializer, + Triggerable triggerable) { + + HeapInternalTimerService timerService = registerOrGetTimerService(name, timerSerializer); + + timerService.startTimerService( + timerSerializer.getKeySerializer(), + timerSerializer.getNamespaceSerializer(), + triggerable); + + return timerService; + } + @SuppressWarnings("unchecked") + HeapInternalTimerService registerOrGetTimerService(String name, TimerSerializer timerSerializer) { HeapInternalTimerService timerService = (HeapInternalTimerService) timerServices.get(name); if (timerService == null) { - timerService = new HeapInternalTimerService<>(totalKeyGroups, - localKeyGroupRange, keyContext, processingTimeService); + + timerService = new HeapInternalTimerService<>( + localKeyGroupRange, + keyContext, + processingTimeService, + createTimerPriorityQueue("__ts_" + name + "/processing_timers", timerSerializer), + createTimerPriorityQueue("__ts_" + name + "/event_timers", timerSerializer)); + timerServices.put(name, timerService); } - timerService.startTimerService(keySerializer, namespaceSerializer, triggerable); return timerService; } + Map> getRegisteredTimerServices() { + return Collections.unmodifiableMap(timerServices); + } + + private KeyGroupedInternalPriorityQueue> createTimerPriorityQueue( + String name, + TimerSerializer timerSerializer) { + return priorityQueueSetFactory.create( + name, + timerSerializer, + InternalTimer.getTimerComparator(), + InternalTimer.getKeyExtractorFunction()); + } + public void advanceWatermark(Watermark watermark) throws Exception { for (HeapInternalTimerService service : timerServices.values()) { service.advanceWatermark(watermark.getTimestamp()); @@ -113,7 +129,7 @@ public void advanceWatermark(Watermark watermark) throws Exception { public void snapshotStateForKeyGroup(DataOutputView stream, int keyGroupIdx) throws IOException { InternalTimerServiceSerializationProxy serializationProxy = - new InternalTimerServiceSerializationProxy<>(timerServices, keyGroupIdx); + new InternalTimerServiceSerializationProxy<>(this, keyGroupIdx); serializationProxy.write(stream); } @@ -125,12 +141,8 @@ public void restoreStateForKeyGroup( InternalTimerServiceSerializationProxy serializationProxy = new InternalTimerServiceSerializationProxy<>( - timerServices, + this, userCodeClassLoader, - totalKeyGroups, - localKeyGroupRange, - keyContext, - processingTimeService, keyGroupIdx); serializationProxy.read(stream); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java index 
5ba1a0facb084..f88b4fb2aa2b5 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java @@ -19,6 +19,10 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.state.KeyExtractorFunction; +import org.apache.flink.runtime.state.PriorityComparator; + +import javax.annotation.Nonnull; /** * Internal interface for in-flight timers. @@ -29,6 +33,12 @@ @Internal public interface InternalTimer { + /** Function to extract the key from a {@link InternalTimer}. */ + KeyExtractorFunction> KEY_EXTRACTOR_FUNCTION = InternalTimer::getKey; + + /** Function to compare instances of {@link InternalTimer}. */ + PriorityComparator> TIMER_COMPARATOR = + (left, right) -> Long.compare(left.getTimestamp(), right.getTimestamp()); /** * Returns the timestamp of the timer. This value determines the point in time when the timer will fire. */ @@ -37,10 +47,22 @@ public interface InternalTimer { /** * Returns the key that is bound to this timer. */ + @Nonnull K getKey(); /** * Returns the namespace that is bound to this timer. */ + @Nonnull N getNamespace(); + + @SuppressWarnings("unchecked") + static PriorityComparator getTimerComparator() { + return (PriorityComparator) TIMER_COMPARATOR; + } + + @SuppressWarnings("unchecked") + static KeyExtractorFunction getKeyExtractorFunction() { + return (KeyExtractorFunction) KEY_EXTRACTOR_FUNCTION; + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java index efa93d3e266a0..ce490b5dfc0f3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java @@ -19,11 +19,10 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.io.PostVersionedIOReadableWritable; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import java.io.IOException; import java.util.Map; @@ -39,36 +38,24 @@ public class InternalTimerServiceSerializationProxy extends PostVersionedIORe public static final int VERSION = 1; /** The key-group timer services to write / read. */ - private Map> timerServices; + private final InternalTimeServiceManager timerServicesManager; /** The user classloader; only relevant if the proxy is used to restore timer services. */ private ClassLoader userCodeClassLoader; /** Properties of restored timer services. */ - private int keyGroupIdx; - private int totalKeyGroups; - private KeyGroupRange localKeyGroupRange; - private KeyContext keyContext; - private ProcessingTimeService processingTimeService; + private final int keyGroupIdx; + /** * Constructor to use when restoring timer services. 
*/ public InternalTimerServiceSerializationProxy( - Map> timerServicesMapToPopulate, - ClassLoader userCodeClassLoader, - int totalKeyGroups, - KeyGroupRange localKeyGroupRange, - KeyContext keyContext, - ProcessingTimeService processingTimeService, - int keyGroupIdx) { - - this.timerServices = checkNotNull(timerServicesMapToPopulate); + InternalTimeServiceManager timerServicesManager, + ClassLoader userCodeClassLoader, + int keyGroupIdx) { + this.timerServicesManager = checkNotNull(timerServicesManager); this.userCodeClassLoader = checkNotNull(userCodeClassLoader); - this.totalKeyGroups = totalKeyGroups; - this.localKeyGroupRange = checkNotNull(localKeyGroupRange); - this.keyContext = checkNotNull(keyContext); - this.processingTimeService = checkNotNull(processingTimeService); this.keyGroupIdx = keyGroupIdx; } @@ -76,10 +63,9 @@ public InternalTimerServiceSerializationProxy( * Constructor to use when writing timer services. */ public InternalTimerServiceSerializationProxy( - Map> timerServices, - int keyGroupIdx) { - - this.timerServices = checkNotNull(timerServices); + InternalTimeServiceManager timerServicesManager, + int keyGroupIdx) { + this.timerServicesManager = checkNotNull(timerServicesManager); this.keyGroupIdx = keyGroupIdx; } @@ -91,9 +77,11 @@ public int getVersion() { @Override public void write(DataOutputView out) throws IOException { super.write(out); + final Map> registeredTimerServices = + timerServicesManager.getRegisteredTimerServices(); - out.writeInt(timerServices.size()); - for (Map.Entry> entry : timerServices.entrySet()) { + out.writeInt(registeredTimerServices.size()); + for (Map.Entry> entry : registeredTimerServices.entrySet()) { String serviceName = entry.getKey(); HeapInternalTimerService timerService = entry.getValue(); @@ -111,22 +99,25 @@ protected void read(DataInputView in, boolean wasVersioned) throws IOException { for (int i = 0; i < noOfTimerServices; i++) { String serviceName = in.readUTF(); - HeapInternalTimerService timerService = timerServices.get(serviceName); - if (timerService == null) { - timerService = new HeapInternalTimerService<>( - totalKeyGroups, - localKeyGroupRange, - keyContext, - processingTimeService); - timerServices.put(serviceName, timerService); - } - int readerVersion = wasVersioned ? 
getReadVersion() : InternalTimersSnapshotReaderWriters.NO_VERSION; InternalTimersSnapshot restoredTimersSnapshot = InternalTimersSnapshotReaderWriters .getReaderForVersion(readerVersion, userCodeClassLoader) .readTimersSnapshot(in); + HeapInternalTimerService timerService = registerOrGetTimerService( + serviceName, + restoredTimersSnapshot); + timerService.restoreTimersForKeyGroup(restoredTimersSnapshot, keyGroupIdx); } } + + @SuppressWarnings("unchecked") + private HeapInternalTimerService registerOrGetTimerService( + String serviceName, InternalTimersSnapshot restoredTimersSnapshot) { + final TypeSerializer keySerializer = (TypeSerializer) restoredTimersSnapshot.getKeySerializer(); + final TypeSerializer namespaceSerializer = (TypeSerializer) restoredTimersSnapshot.getNamespaceSerializer(); + TimerSerializer timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); + return timerServicesManager.registerOrGetTimerService(serviceName, timerSerializer); + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java index 578302ba1e5ee..594f337f41693 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java @@ -207,6 +207,7 @@ protected InternalTimeServiceManager internalTimeServiceManager( keyedStatedBackend.getNumberOfKeyGroups(), keyGroupRange, keyContext, + keyedStatedBackend, processingTimeService); // and then initialize the timer services diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java index bd821c47527f8..b9ef88edaf8c9 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java @@ -19,21 +19,18 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.annotation.Internal; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.typeutils.CompatibilityResult; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.runtime.state.KeyExtractorFunction; import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement; import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet; import javax.annotation.Nonnull; import java.io.IOException; -import java.util.Comparator; /** * Implementation of {@link InternalTimer} to use with a {@link HeapPriorityQueueSet}. @@ -44,14 +41,6 @@ @Internal public final class TimerHeapInternalTimer implements InternalTimer, HeapPriorityQueueElement { - /** Function to extract the key from a {@link TimerHeapInternalTimer}. */ - private static final KeyExtractorFunction> KEY_EXTRACTOR_FUNCTION = - TimerHeapInternalTimer::getKey; - - /** Function to compare instances of {@link TimerHeapInternalTimer}. 
*/ - private static final Comparator> TIMER_COMPARATOR = - (o1, o2) -> Long.compare(o1.getTimestamp(), o2.getTimestamp()); - /** The key for which the timer is scoped. */ @Nonnull private final K key; @@ -144,18 +133,6 @@ public String toString() { '}'; } - @VisibleForTesting - @SuppressWarnings("unchecked") - static Comparator getTimerComparator() { - return (Comparator) TIMER_COMPARATOR; - } - - @SuppressWarnings("unchecked") - @VisibleForTesting - static KeyExtractorFunction getKeyExtractorFunction() { - return (KeyExtractorFunction) KEY_EXTRACTOR_FUNCTION; - } - /** * A {@link TypeSerializer} used to serialize/deserialize a {@link TimerHeapInternalTimer}. */ diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java new file mode 100644 index 0000000000000..87a3159cdb49c --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.typeutils.CompatibilityResult; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.MathUtils; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.Objects; + +/** + * A serializer for {@link TimerHeapInternalTimer} objects that produces a serialization format that is aligned with + * {@link InternalTimer#getTimerComparator()}. + * + * @param type of the timer key. + * @param type of the timer namespace. + */ +public class TimerSerializer extends TypeSerializer> { + + private static final long serialVersionUID = 1L; + + /** Serializer for the key. */ + @Nonnull + private final TypeSerializer keySerializer; + + /** Serializer for the namespace. */ + @Nonnull + private final TypeSerializer namespaceSerializer; + + /** The bytes written for one timer, or -1 if variable size. */ + private final int length; + + /** True iff the serialized type (and composite objects) are immutable. 
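TimerSerializer writes the timestamp first and runs it through MathUtils.flipSignBit(...) (see serialize(...) in the chunk below), so that the unsigned, lexicographic order of the serialized bytes agrees with the timestamp order of TIMER_COMPARATOR, which a byte-ordered store relies on. A small sketch of why the sign-bit flip is needed, assuming flipSignBit(value) toggles the sign bit (value ^ Long.MIN_VALUE):

import org.apache.flink.util.MathUtils;

public class TimestampByteOrderSketch {

    public static void main(String[] args) {
        long earlier = -1L; // sign bit set in two's complement
        long later = 1L;

        // Compared as unsigned bytes (big-endian writeLong), -1L would sort after 1L,
        // which is the wrong order for a store that must iterate timers by timestamp.
        System.out.println(Long.compareUnsigned(earlier, later) > 0);   // true: wrong order

        // Flipping the sign bit restores numeric order under unsigned comparison.
        System.out.println(Long.compareUnsigned(
                MathUtils.flipSignBit(earlier),
                MathUtils.flipSignBit(later)) < 0);                     // true: correct order
    }
}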
*/ + private final boolean immutableType; + + TimerSerializer( + @Nonnull TypeSerializer keySerializer, + @Nonnull TypeSerializer namespaceSerializer) { + this( + keySerializer, + namespaceSerializer, + computeTotalByteLength(keySerializer, namespaceSerializer), + keySerializer.isImmutableType() && namespaceSerializer.isImmutableType()); + } + + private TimerSerializer( + @Nonnull TypeSerializer keySerializer, + @Nonnull TypeSerializer namespaceSerializer, + int length, + boolean immutableType) { + + this.keySerializer = keySerializer; + this.namespaceSerializer = namespaceSerializer; + this.length = length; + this.immutableType = immutableType; + } + + private static int computeTotalByteLength( + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer) { + if (keySerializer.getLength() >= 0 && namespaceSerializer.getLength() >= 0) { + // timestamp + key + namespace + return Long.BYTES + keySerializer.getLength() + namespaceSerializer.getLength(); + } else { + return -1; + } + } + + @Override + public boolean isImmutableType() { + return immutableType; + } + + @Override + public TimerSerializer duplicate() { + + final TypeSerializer keySerializerDuplicate = keySerializer.duplicate(); + final TypeSerializer namespaceSerializerDuplicate = namespaceSerializer.duplicate(); + + if (keySerializerDuplicate == keySerializer && + namespaceSerializerDuplicate == namespaceSerializer) { + // all delegate serializers seem stateless, so this is also stateless. + return this; + } else { + // at least one delegate serializer seems to be stateful, so we return a new instance. + return new TimerSerializer<>( + keySerializerDuplicate, + namespaceSerializerDuplicate, + length, + immutableType); + } + } + + @Override + public TimerHeapInternalTimer createInstance() { + return new TimerHeapInternalTimer<>( + 0L, + keySerializer.createInstance(), + namespaceSerializer.createInstance()); + } + + @Override + public TimerHeapInternalTimer copy(TimerHeapInternalTimer from) { + + K keyDuplicate; + N namespaceDuplicate; + if (isImmutableType()) { + keyDuplicate = from.getKey(); + namespaceDuplicate = from.getNamespace(); + } else { + keyDuplicate = keySerializer.copy(from.getKey()); + namespaceDuplicate = namespaceSerializer.copy(from.getNamespace()); + } + + return new TimerHeapInternalTimer<>(from.getTimestamp(), keyDuplicate, namespaceDuplicate); + } + + @Override + public TimerHeapInternalTimer copy(TimerHeapInternalTimer from, TimerHeapInternalTimer reuse) { + return copy(from); + } + + @Override + public int getLength() { + return length; + } + + @Override + public void serialize(TimerHeapInternalTimer record, DataOutputView target) throws IOException { + target.writeLong(MathUtils.flipSignBit(record.getTimestamp())); + keySerializer.serialize(record.getKey(), target); + namespaceSerializer.serialize(record.getNamespace(), target); + } + + @Override + public TimerHeapInternalTimer deserialize(DataInputView source) throws IOException { + long timestamp = MathUtils.flipSignBit(source.readLong()); + K key = keySerializer.deserialize(source); + N namespace = namespaceSerializer.deserialize(source); + return new TimerHeapInternalTimer<>(timestamp, key, namespace); + } + + @Override + public TimerHeapInternalTimer deserialize( + TimerHeapInternalTimer reuse, + DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.writeLong(source.readLong()); + keySerializer.copy(source, 
target); + namespaceSerializer.copy(source, target); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TimerSerializer that = (TimerSerializer) o; + return Objects.equals(keySerializer, that.keySerializer) && + Objects.equals(namespaceSerializer, that.namespaceSerializer); + } + + @Override + public int hashCode() { + return Objects.hash(keySerializer, namespaceSerializer); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof TimerSerializer; + } + + @Override + public TypeSerializerConfigSnapshot snapshotConfiguration() { + throw new UnsupportedOperationException("This serializer is currently not used to write state."); + } + + @Override + public CompatibilityResult> ensureCompatibility( + TypeSerializerConfigSnapshot configSnapshot) { + throw new UnsupportedOperationException("This serializer is currently not used to write state."); + } + + @Nonnull + public TypeSerializer getKeySerializer() { + return keySerializer; + } + + @Nonnull + public TypeSerializer getNamespaceSerializer() { + return namespaceSerializer; + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java index b008fa2fa1468..519f10e5fb54f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java @@ -18,12 +18,16 @@ package org.apache.flink.streaming.api.operators; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.PriorityQueueSetFactory; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; @@ -85,12 +89,13 @@ public void testKeyGroupStartIndexSetting() { TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); - HeapInternalTimerService service = - new HeapInternalTimerService<>( - testKeyGroupList.getNumberOfKeyGroups(), - testKeyGroupList, - keyContext, - processingTimeService); + HeapInternalTimerService service = createInternalTimerService( + testKeyGroupList, + keyContext, + processingTimeService, + IntSerializer.INSTANCE, + StringSerializer.INSTANCE, + createQueueFactory()); Assert.assertEquals(startKeyGroupIdx, service.getLocalKeyGroupRangeStartIdx()); } @@ -105,14 +110,20 @@ public void testTimerAssignmentToKeyGroups() { @SuppressWarnings("unchecked") Set>[] expectedNonEmptyTimerSets = new HashSet[totalNoOfKeyGroups]; - TestKeyContext keyContext = new TestKeyContext(); - HeapInternalTimerService timerService = - new HeapInternalTimerService<>( - totalNoOfKeyGroups, - new KeyGroupRange(startKeyGroupIdx, endKeyGroupIdx), - keyContext, - 
new TestProcessingTimeService()); + + final KeyGroupRange keyGroupRange = new KeyGroupRange(startKeyGroupIdx, endKeyGroupIdx); + + final PriorityQueueSetFactory priorityQueueSetFactory = + createQueueFactory(keyGroupRange, totalNoOfKeyGroups); + + HeapInternalTimerService timerService = createInternalTimerService( + keyGroupRange, + keyContext, + new TestProcessingTimeService(), + IntSerializer.INSTANCE, + StringSerializer.INSTANCE, + priorityQueueSetFactory); timerService.startTimerService(IntSerializer.INSTANCE, StringSerializer.INSTANCE, mock(Triggerable.class)); @@ -169,9 +180,10 @@ public void testOnlySetsOnePhysicalProcessingTimeTimer() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); - + PriorityQueueSetFactory priorityQueueSetFactory = + new HeapPriorityQueueSetFactory(testKeyGroupRange, maxParallelism, 128); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, priorityQueueSetFactory); int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); keyContext.setCurrentKey(key); @@ -233,7 +245,7 @@ public void testRegisterEarlierProcessingTimerMovesPhysicalProcessingTimer() thr TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -266,7 +278,7 @@ public void testRegisteringProcessingTimeTimerInOnProcessingTimeDoesNotLeakPhysi TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); final HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -317,7 +329,7 @@ public void testCurrentProcessingTime() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); processingTimeService.setCurrentTime(17L); assertEquals(17, timerService.currentProcessingTime()); @@ -335,7 +347,7 @@ public void testCurrentEventTime() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); timerService.advanceWatermark(17); assertEquals(17, timerService.currentWatermark()); @@ 
-355,7 +367,7 @@ public void testSetAndFireEventTimeTimers() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); // get two different keys int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -400,7 +412,7 @@ public void testSetAndFireProcessingTimeTimers() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); // get two different keys int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -447,7 +459,7 @@ public void testDeleteEventTimeTimers() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); // get two different keys int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -504,7 +516,7 @@ public void testDeleteProcessingTimeTimers() throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); // get two different keys int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -579,7 +591,7 @@ private void testSnapshotAndRestore(int snapshotVersion) throws Exception { TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory()); // get two different keys int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism); @@ -631,7 +643,7 @@ private void testSnapshotAndRestore(int snapshotVersion) throws Exception { keyContext, processingTimeService, testKeyGroupRange, - maxParallelism); + createQueueFactory()); processingTimeService.setCurrentTime(10); timerService.advanceWatermark(10); @@ -652,8 +664,9 @@ private void testSnapshotAndRebalancingRestore(int snapshotVersion) throws Excep TestKeyContext keyContext = new TestKeyContext(); TestProcessingTimeService processingTimeService = new TestProcessingTimeService(); + final PriorityQueueSetFactory queueFactory = createQueueFactory(); 
HeapInternalTimerService timerService = - createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism); + createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, queueFactory); int midpoint = testKeyGroupRange.getStartKeyGroup() + (testKeyGroupRange.getEndKeyGroup() - testKeyGroupRange.getStartKeyGroup()) / 2; @@ -724,7 +737,7 @@ private void testSnapshotAndRebalancingRestore(int snapshotVersion) throws Excep keyContext1, processingTimeService1, subKeyGroupRange1, - maxParallelism); + queueFactory); HeapInternalTimerService timerService2 = restoreTimerService( snapshot2, @@ -733,7 +746,7 @@ private void testSnapshotAndRebalancingRestore(int snapshotVersion) throws Excep keyContext2, processingTimeService2, subKeyGroupRange2, - maxParallelism); + queueFactory); processingTimeService1.setCurrentTime(10); timerService1.advanceWatermark(10); @@ -793,18 +806,19 @@ private static int getKeyInKeyGroupRange(KeyGroupRange range, int maxParallelism return result; } - private static HeapInternalTimerService createTimerService( + private static HeapInternalTimerService createAndStartInternalTimerService( Triggerable triggerable, KeyContext keyContext, ProcessingTimeService processingTimeService, KeyGroupRange keyGroupList, - int maxParallelism) { - HeapInternalTimerService service = - new HeapInternalTimerService<>( - maxParallelism, - keyGroupList, - keyContext, - processingTimeService); + PriorityQueueSetFactory priorityQueueSetFactory) { + HeapInternalTimerService service = createInternalTimerService( + keyGroupList, + keyContext, + processingTimeService, + IntSerializer.INSTANCE, + StringSerializer.INSTANCE, + priorityQueueSetFactory); service.startTimerService(IntSerializer.INSTANCE, StringSerializer.INSTANCE, triggerable); return service; @@ -817,15 +831,16 @@ private static HeapInternalTimerService restoreTimerService( KeyContext keyContext, ProcessingTimeService processingTimeService, KeyGroupRange keyGroupsList, - int maxParallelism) throws Exception { + PriorityQueueSetFactory priorityQueueSetFactory) throws Exception { // create an empty service - HeapInternalTimerService service = - new HeapInternalTimerService<>( - maxParallelism, - keyGroupsList, - keyContext, - processingTimeService); + HeapInternalTimerService service = createInternalTimerService( + keyGroupsList, + keyContext, + processingTimeService, + IntSerializer.INSTANCE, + StringSerializer.INSTANCE, + priorityQueueSetFactory); // restore the timers for (Integer keyGroupIndex : keyGroupsList) { @@ -846,6 +861,14 @@ private static HeapInternalTimerService restoreTimerService( return service; } + private PriorityQueueSetFactory createQueueFactory() { + return createQueueFactory(testKeyGroupRange, maxParallelism); + } + + protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange, int numKeyGroups) { + return new HeapPriorityQueueSetFactory(keyGroupRange, numKeyGroups, 128); + } + // ------------------------------------------------------------------------ // Parametrization for testing with different key-group ranges // ------------------------------------------------------------------------ @@ -862,4 +885,33 @@ public static Collection keyRanges(){ {2, 5, 6} }); } + + private static HeapInternalTimerService createInternalTimerService( + KeyGroupRange keyGroupsList, + KeyContext keyContext, + ProcessingTimeService processingTimeService, + TypeSerializer keySerializer, + TypeSerializer namespaceSerializer, + 
PriorityQueueSetFactory priorityQueueSetFactory) { + + TimerSerializer timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); + + return new HeapInternalTimerService<>( + keyGroupsList, + keyContext, + processingTimeService, + createTimerQueue("__test_processing_timers", timerSerializer, priorityQueueSetFactory), + createTimerQueue("__test_event_timers", timerSerializer, priorityQueueSetFactory)); + } + + private static KeyGroupedInternalPriorityQueue> createTimerQueue( + String name, + TimerSerializer timerSerializer, + PriorityQueueSetFactory priorityQueueSetFactory) { + return priorityQueueSetFactory.create( + name, + timerSerializer, + InternalTimer.getTimerComparator(), + InternalTimer.getKeyExtractorFunction()); + } }
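For completeness, a usage sketch of the factory/queue wiring that the test helpers above exercise, with concrete key and namespace types. Generic signatures are flattened in this diff, so the exact type parameters are assumptions, and the sketch assumes package-private visibility of TimerSerializer and the TimerHeapInternalTimer constructor (hence the package declaration):

package org.apache.flink.streaming.api.operators;

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.PriorityQueueSetFactory;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;

public class TimerQueueWiringSketch {

    public static void main(String[] args) {
        KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0);
        PriorityQueueSetFactory factory =
            new HeapPriorityQueueSetFactory(keyGroupRange, keyGroupRange.getNumberOfKeyGroups(), 128);

        TimerSerializer<Integer, String> timerSerializer =
            new TimerSerializer<>(IntSerializer.INSTANCE, StringSerializer.INSTANCE);

        // Same call pattern as createTimerQueue(...) above, but with concrete types.
        KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<Integer, String>> eventTimers =
            factory.create(
                "__example/event_timers",
                timerSerializer,
                InternalTimer.getTimerComparator(),
                InternalTimer.getKeyExtractorFunction());

        eventTimers.add(new TimerHeapInternalTimer<>(42L, 1, "window-1"));
        eventTimers.add(new TimerHeapInternalTimer<>(7L, 1, "window-1"));

        // TIMER_COMPARATOR orders by timestamp, so the head is the timer registered for t=7.
        TimerHeapInternalTimer<Integer, String> head = eventTimers.peek();
        System.out.println(head.getTimestamp());
    }
}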