Skip to content

Commit

Permalink
[FLINK-20491] Add preferred/pass-though inputs in MultiInputSortingDa…
Browse files Browse the repository at this point in the history
…taInput

This will allow processing the broadcast side of a broadcast operator
first, before processing the keyed side that requires sorting for
stateful BATCH execution.

For now, the wiring from the API is not there, this will be added in
follow-up changes.
  • Loading branch information
aljoscha committed Jan 7, 2021
1 parent e31b162 commit 00f8de7
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.sort.ExternalSorter;
import org.apache.flink.runtime.operators.sort.PushSorter;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.watermark.Watermark;
Expand All @@ -48,8 +49,13 @@
import javax.annotation.Nonnull;

import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
Expand Down Expand Up @@ -107,29 +113,38 @@ private MultiInputSortingDataInput(
*/
public static class SelectableSortingInputs {
private final InputSelectable inputSelectable;
private final StreamTaskInput<?>[] sortingInputs;
private final StreamTaskInput<?>[] sortedInputs;
private final StreamTaskInput<?>[] passThroughInputs;

public SelectableSortingInputs(
StreamTaskInput<?>[] sortingInputs, InputSelectable inputSelectable) {
this.sortingInputs = sortingInputs;
StreamTaskInput<?>[] sortedInputs,
StreamTaskInput<?>[] passThroughInputs,
InputSelectable inputSelectable) {
this.sortedInputs = sortedInputs;
this.passThroughInputs = passThroughInputs;
this.inputSelectable = inputSelectable;
}

public InputSelectable getInputSelectable() {
return inputSelectable;
}

public StreamTaskInput<?>[] getSortingInputs() {
return sortingInputs;
public StreamTaskInput<?>[] getSortedInputs() {
return sortedInputs;
}

public StreamTaskInput<?>[] getPassThroughInputs() {
return passThroughInputs;
}
}

public static <K> SelectableSortingInputs wrapInputs(
AbstractInvokable containingTask,
StreamTaskInput<Object>[] inputs,
StreamTaskInput<Object>[] sortingInputs,
KeySelector<Object, K>[] keySelectors,
TypeSerializer<Object>[] inputSerializers,
TypeSerializer<K> keySerializer,
StreamTaskInput<Object>[] passThroughInputs,
MemoryManager memoryManager,
IOManager ioManager,
boolean objectReuse,
Expand All @@ -146,20 +161,28 @@ public static <K> SelectableSortingInputs wrapInputs(
comparator = new VariableLengthByteKeyComparator<>();
}

int numberOfInputs = inputs.length;
CommonContext commonContext = new CommonContext(numberOfInputs);
StreamTaskInput<?>[] sortingInputs =
IntStream.range(0, numberOfInputs)
List<Integer> passThroughInputIndices =
Arrays.stream(passThroughInputs)
.map(StreamTaskInput::getInputIndex)
.collect(Collectors.toList());
int numberOfInputs = sortingInputs.length + passThroughInputs.length;
CommonContext commonContext = new CommonContext(sortingInputs);
InputSelector inputSelector =
new InputSelector(commonContext, numberOfInputs, passThroughInputIndices);

StreamTaskInput<?>[] wrappedSortingInputs =
IntStream.range(0, sortingInputs.length)
.mapToObj(
idx -> {
try {
KeyAndValueSerializer<Object> keyAndValueSerializer =
new KeyAndValueSerializer<>(
inputSerializers[idx], keyLength);

return new MultiInputSortingDataInput<>(
commonContext,
inputs[idx],
idx,
sortingInputs[idx],
sortingInputs[idx].getInputIndex(),
ExternalSorter.newBuilder(
memoryManager,
containingTask,
Expand Down Expand Up @@ -189,8 +212,14 @@ public static <K> SelectableSortingInputs wrapInputs(
}
})
.toArray(StreamTaskInput[]::new);

StreamTaskInput<?>[] wrappedPassThroughInputs =
Arrays.stream(passThroughInputs)
.map(input -> new ObservableStreamTaskInput<>(input, inputSelector))
.toArray(StreamTaskInput[]::new);

return new SelectableSortingInputs(
sortingInputs, new InputSelector(commonContext, numberOfInputs));
wrappedSortingInputs, wrappedPassThroughInputs, inputSelector);
}

@Override
Expand Down Expand Up @@ -318,23 +347,40 @@ public CompletableFuture<?> getAvailableFuture() {
* all sorting inputs. Should be used by the {@link StreamInputProcessor} to choose the next
* input to consume from.
*/
private static class InputSelector implements InputSelectable {
private static class InputSelector implements InputSelectable, BoundedMultiInput {

private final CommonContext commonContext;
private final int numberOfInputs;
private final int numInputs;
private final Queue<Integer> passThroughInputsIndices;

private InputSelector(CommonContext commonContext, int numberOfInputs) {
private InputSelector(
CommonContext commonContext, int numInputs, List<Integer> passThroughInputIndices) {
this.commonContext = commonContext;
this.numberOfInputs = numberOfInputs;
this.numInputs = numInputs;
this.passThroughInputsIndices = new LinkedList<>(passThroughInputIndices);
}

@Override
public void endInput(int inputId) throws Exception {
passThroughInputsIndices.remove(inputId);
}

@Override
public InputSelection nextSelection() {
Integer currentPassThroughInputIndex = passThroughInputsIndices.peek();

if (currentPassThroughInputIndex != null) {
// yes, 0-based to 1-based mapping ... 🙏
return new InputSelection.Builder()
.select(currentPassThroughInputIndex + 1)
.build(numInputs);
}

if (commonContext.allSorted()) {
HeadElement headElement = commonContext.getQueueOfHeads().peek();
if (headElement != null) {
int headIdx = headElement.inputIndex;
return new InputSelection.Builder().select(headIdx + 1).build(numberOfInputs);
return new InputSelection.Builder().select(headIdx + 1).build(numInputs);
}
}
return InputSelection.ALL;
Expand Down Expand Up @@ -419,9 +465,10 @@ private static final class CommonContext {
private long notFinishedSortingMask = 0;
private long finishedEmitting = 0;

public CommonContext(int numberOfInputs) {
for (int i = 0; i < numberOfInputs; i++) {
notFinishedSortingMask = setBitMask(notFinishedSortingMask, i);
public CommonContext(StreamTaskInput<Object>[] sortingInputs) {
for (StreamTaskInput<Object> sortingInput : sortingInputs) {
notFinishedSortingMask =
setBitMask(notFinishedSortingMask, sortingInput.getInputIndex());
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators.sort;

import org.apache.flink.core.io.InputStatus;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.runtime.io.StreamTaskInput;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;

/**
* A wrapping {@link StreamTaskInput} that invokes a given {@link BoundedMultiInput} when reaching
* {@link InputStatus#END_OF_INPUT}.
*/
class ObservableStreamTaskInput<T> implements StreamTaskInput<T> {

private final StreamTaskInput<T> wrappedInput;
private final BoundedMultiInput endOfInputObserver;

public ObservableStreamTaskInput(
StreamTaskInput<T> wrappedInput, BoundedMultiInput endOfInputObserver) {
this.wrappedInput = wrappedInput;
this.endOfInputObserver = endOfInputObserver;
}

@Override
public InputStatus emitNext(DataOutput<T> output) throws Exception {
InputStatus result = wrappedInput.emitNext(output);
if (result == InputStatus.END_OF_INPUT) {
endOfInputObserver.endInput(wrappedInput.getInputIndex());
}
return result;
}

@Override
public int getInputIndex() {
return wrappedInput.getInputIndex();
}

@Override
public CompletableFuture<Void> prepareSnapshot(
ChannelStateWriter channelStateWriter, long checkpointId) throws IOException {
return wrappedInput.prepareSnapshot(channelStateWriter, checkpointId);
}

@Override
public void close() throws IOException {
wrappedInput.close();
}

@Override
public CompletableFuture<?> getAvailableFuture() {
return wrappedInput.getAvailableFuture();
}

@Override
public boolean isAvailable() {
return wrappedInput.isAvailable();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public static StreamMultipleInputProcessor create(
idx, userClassloader))
.toArray(TypeSerializer[]::new),
streamConfig.getStateKeySerializer(userClassloader),
new StreamTaskInput[0],
memoryManager,
ioManager,
executionConfig.isObjectReuseEnabled(),
Expand All @@ -151,7 +152,7 @@ public static StreamMultipleInputProcessor create(
userClassloader),
jobConfig);

inputs = selectableSortingInputs.getSortingInputs();
inputs = selectableSortingInputs.getSortedInputs();
inputSelectable = selectableSortingInputs.getInputSelectable();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public static <IN1, IN2> StreamTwoInputProcessor<IN1, IN2> create(
},
new TypeSerializer[] {typeSerializer1, typeSerializer2},
streamConfig.getStateKeySerializer(userClassloader),
new StreamTaskInput[0],
memoryManager,
ioManager,
executionConfig.isObjectReuseEnabled(),
Expand All @@ -117,8 +118,8 @@ public static <IN1, IN2> StreamTwoInputProcessor<IN1, IN2> create(
userClassloader),
jobConfig);
inputSelectable = selectableSortingInputs.getInputSelectable();
input1 = getSortedInput(selectableSortingInputs.getSortingInputs()[0]);
input2 = getSortedInput(selectableSortingInputs.getSortingInputs()[1]);
input1 = getSortedInput(selectableSortingInputs.getSortedInputs()[0]);
input2 = getSortedInput(selectableSortingInputs.getSortedInputs()[1]);
}

StreamTaskNetworkOutput<IN1> output1 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,14 @@ public void multiInputKeySorting() throws Exception {
GeneratedRecordsDataInput.SERIALIZER
},
new StringSerializer(),
new StreamTaskInput[0],
environment.getMemoryManager(),
environment.getIOManager(),
true,
1.0,
new Configuration());

StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortingInputs();
StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortedInputs();
try (StreamTaskInput<Tuple3<Integer, String, byte[]>> sortedInput1 =
(StreamTaskInput<Tuple3<Integer, String, byte[]>>)
sortingDataInputs[0];
Expand Down
Loading

0 comments on commit 00f8de7

Please sign in to comment.