Skip to content

Commit

Permalink
[hotfix] Introduce BufferIndexAndChannel and make HsSpillingStrategy …
Browse files Browse the repository at this point in the history
…using it instead of BufferWithIdentity.
  • Loading branch information
reswqa authored and xintongsong committed Jul 26, 2022
1 parent 7611928 commit 5494fe6
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.io.network.partition.hybrid;

/** Integrate the buffer index and the channel id which it belongs. */
public class BufferIndexAndChannel {
private final int bufferIndex;

private final int channel;

public BufferIndexAndChannel(int bufferIndex, int channel) {
this.bufferIndex = bufferIndex;
this.channel = channel;
}

public int getBufferIndex() {
return bufferIndex;
}

public int getChannel() {
return channel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Optional<Decision> onBufferFinished(int numTotalUnSpillBuffers) {

// For the case of buffer consumed, there is no need to take action for HsFullSpillingStrategy.
@Override
public Optional<Decision> onBufferConsumed(BufferWithIdentity consumedBuffer) {
public Optional<Decision> onBufferConsumed(BufferIndexAndChannel consumedBuffer) {
return Optional.of(Decision.NO_ACTION);
}

Expand Down Expand Up @@ -85,7 +85,7 @@ private void checkSpill(HsSpillingInfoProvider spillingInfoProvider, Decision.Bu
return;
}
// Spill all not spill buffers.
List<BufferWithIdentity> unSpillBuffers = new ArrayList<>();
List<BufferIndexAndChannel> unSpillBuffers = new ArrayList<>();
for (int i = 0; i < spillingInfoProvider.getNumSubpartitions(); i++) {
unSpillBuffers.addAll(
spillingInfoProvider.getBuffersInOrder(
Expand All @@ -105,23 +105,23 @@ private void checkRelease(
int releaseNum = (int) (spillingInfoProvider.getPoolSize() * releaseBufferRatio);

// first, release all consumed buffers
TreeMap<Integer, Deque<BufferWithIdentity>> consumedBuffersToRelease = new TreeMap<>();
TreeMap<Integer, Deque<BufferIndexAndChannel>> consumedBuffersToRelease = new TreeMap<>();
int numConsumedBuffers = 0;
for (int subpartitionId = 0;
subpartitionId < spillingInfoProvider.getNumSubpartitions();
subpartitionId++) {

Deque<BufferWithIdentity> consumedSpillSubpartitionBuffers =
Deque<BufferIndexAndChannel> consumedSpillSubpartitionBuffers =
spillingInfoProvider.getBuffersInOrder(
subpartitionId, SpillStatus.SPILL, ConsumeStatus.CONSUMED);
numConsumedBuffers += consumedSpillSubpartitionBuffers.size();
consumedBuffersToRelease.put(subpartitionId, consumedSpillSubpartitionBuffers);
}

// make up the releaseNum with unconsumed buffers, if needed, w.r.t. the consuming priority
TreeMap<Integer, List<BufferWithIdentity>> unconsumedBufferToRelease = new TreeMap<>();
TreeMap<Integer, List<BufferIndexAndChannel>> unconsumedBufferToRelease = new TreeMap<>();
if (releaseNum > numConsumedBuffers) {
TreeMap<Integer, Deque<BufferWithIdentity>> unconsumedBuffers = new TreeMap<>();
TreeMap<Integer, Deque<BufferIndexAndChannel>> unconsumedBuffers = new TreeMap<>();
for (int subpartitionId = 0;
subpartitionId < spillingInfoProvider.getNumSubpartitions();
subpartitionId++) {
Expand All @@ -138,7 +138,7 @@ private void checkRelease(
}

// collect results in order
List<BufferWithIdentity> toRelease = new ArrayList<>();
List<BufferIndexAndChannel> toRelease = new ArrayList<>();
for (int i = 0; i < spillingInfoProvider.getNumSubpartitions(); i++) {
toRelease.addAll(consumedBuffersToRelease.getOrDefault(i, new ArrayDeque<>()));
toRelease.addAll(unconsumedBufferToRelease.getOrDefault(i, new ArrayList<>()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public Optional<Decision> onBufferFinished(int numTotalUnSpillBuffers) {
// For the case of buffer consumed, this buffer need release. The control of the buffer is taken
// over by the downstream task.
@Override
public Optional<Decision> onBufferConsumed(BufferWithIdentity consumedBuffer) {
public Optional<Decision> onBufferConsumed(BufferIndexAndChannel consumedBuffer) {
return Optional.of(Decision.builder().addBufferToRelease(consumedBuffer).build());
}

Expand Down Expand Up @@ -80,15 +80,15 @@ public Decision decideActionWithGlobalInfo(HsSpillingInfoProvider spillingInfoPr

int spillNum = (int) (spillingInfoProvider.getPoolSize() * spillBufferRatio);

TreeMap<Integer, Deque<BufferWithIdentity>> subpartitionToBuffers = new TreeMap<>();
TreeMap<Integer, Deque<BufferIndexAndChannel>> subpartitionToBuffers = new TreeMap<>();
for (int channel = 0; channel < spillingInfoProvider.getNumSubpartitions(); channel++) {
subpartitionToBuffers.put(
channel,
spillingInfoProvider.getBuffersInOrder(
channel, SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED));
}

TreeMap<Integer, List<BufferWithIdentity>> subpartitionToHighPriorityBuffers =
TreeMap<Integer, List<BufferIndexAndChannel>> subpartitionToHighPriorityBuffers =
getBuffersByConsumptionPriorityInOrder(
spillingInfoProvider.getNextBufferIndexToConsume(),
subpartitionToBuffers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public interface HsSpillingInfoProvider {
* according to bufferIndex from small to large, in other words, head is the buffer with the
* minimum bufferIndex in the current subpartition.
*/
Deque<BufferWithIdentity> getBuffersInOrder(
Deque<BufferIndexAndChannel> getBuffersInOrder(
int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus);

/** Get total number of not decided to spill buffers. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public interface HsSpillingStrategy {
* @return A {@link Decision} based on the provided information, or {@link Optional#empty()} if
* the decision cannot be made, which indicates global information is needed.
*/
Optional<Decision> onBufferConsumed(BufferWithIdentity consumedBuffer);
Optional<Decision> onBufferConsumed(BufferIndexAndChannel consumedBuffer);

/**
* Make a decision based on global information. Because this method will directly touch the
Expand All @@ -74,25 +74,26 @@ public interface HsSpillingStrategy {
*/
class Decision {
/** A collection of buffer that needs to be spilled to disk. */
private final List<BufferWithIdentity> bufferToSpill;
private final List<BufferIndexAndChannel> bufferToSpill;

/** A collection of buffer that needs to be released. */
private final List<BufferWithIdentity> bufferToRelease;
private final List<BufferIndexAndChannel> bufferToRelease;

public static final Decision NO_ACTION =
new Decision(Collections.emptyList(), Collections.emptyList());

private Decision(
List<BufferWithIdentity> bufferToSpill, List<BufferWithIdentity> bufferToRelease) {
List<BufferIndexAndChannel> bufferToSpill,
List<BufferIndexAndChannel> bufferToRelease) {
this.bufferToSpill = bufferToSpill;
this.bufferToRelease = bufferToRelease;
}

public List<BufferWithIdentity> getBufferToSpill() {
public List<BufferIndexAndChannel> getBufferToSpill() {
return bufferToSpill;
}

public List<BufferWithIdentity> getBufferToRelease() {
public List<BufferIndexAndChannel> getBufferToRelease() {
return bufferToRelease;
}

Expand All @@ -103,29 +104,29 @@ public static Builder builder() {
/** Builder for {@link Decision}. */
static class Builder {
/** A collection of buffer that needs to be spilled to disk. */
private final List<BufferWithIdentity> bufferToSpill = new ArrayList<>();
private final List<BufferIndexAndChannel> bufferToSpill = new ArrayList<>();

/** A collection of buffer that needs to be released. */
private final List<BufferWithIdentity> bufferToRelease = new ArrayList<>();
private final List<BufferIndexAndChannel> bufferToRelease = new ArrayList<>();

private Builder() {}

public Builder addBufferToSpill(BufferWithIdentity buffer) {
public Builder addBufferToSpill(BufferIndexAndChannel buffer) {
bufferToSpill.add(buffer);
return this;
}

public Builder addBufferToSpill(List<BufferWithIdentity> buffers) {
public Builder addBufferToSpill(List<BufferIndexAndChannel> buffers) {
bufferToSpill.addAll(buffers);
return this;
}

public Builder addBufferToRelease(BufferWithIdentity buffer) {
public Builder addBufferToRelease(BufferIndexAndChannel buffer) {
bufferToRelease.add(buffer);
return this;
}

public Builder addBufferToRelease(List<BufferWithIdentity> buffers) {
public Builder addBufferToRelease(List<BufferIndexAndChannel> buffers) {
bufferToRelease.addAll(buffers);
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ public class HsSpillingStrategyUtils {
* @return mapping for subpartitionId to buffers, the value of map entry must be order by
* bufferIndex ascending.
*/
public static TreeMap<Integer, List<BufferWithIdentity>> getBuffersByConsumptionPriorityInOrder(
List<Integer> nextBufferIndexToConsume,
TreeMap<Integer, Deque<BufferWithIdentity>> subpartitionToAllBuffers,
int expectedSize) {
public static TreeMap<Integer, List<BufferIndexAndChannel>>
getBuffersByConsumptionPriorityInOrder(
List<Integer> nextBufferIndexToConsume,
TreeMap<Integer, Deque<BufferIndexAndChannel>> subpartitionToAllBuffers,
int expectedSize) {
if (expectedSize <= 0) {
return new TreeMap<>();
}
Expand All @@ -63,17 +64,17 @@ public static TreeMap<Integer, List<BufferWithIdentity>> getBuffersByConsumption
}
});

TreeMap<Integer, List<BufferWithIdentity>> subpartitionToHighPriorityBuffers =
TreeMap<Integer, List<BufferIndexAndChannel>> subpartitionToHighPriorityBuffers =
new TreeMap<>();
for (int i = 0; i < expectedSize; i++) {
if (heap.isEmpty()) {
break;
}
BufferConsumptionPriorityIterator bufferConsumptionPriorityIterator = heap.poll();
BufferWithIdentity bufferWithIdentity = bufferConsumptionPriorityIterator.next();
BufferIndexAndChannel bufferIndexAndChannel = bufferConsumptionPriorityIterator.next();
subpartitionToHighPriorityBuffers
.computeIfAbsent(bufferWithIdentity.getChannelIndex(), ArrayList::new)
.add(bufferWithIdentity);
.computeIfAbsent(bufferIndexAndChannel.getChannel(), ArrayList::new)
.add(bufferIndexAndChannel);
// if this iterator has next, re-added it.
if (bufferConsumptionPriorityIterator.hasNext()) {
heap.add(bufferConsumptionPriorityIterator);
Expand All @@ -89,24 +90,25 @@ public static TreeMap<Integer, List<BufferWithIdentity>> getBuffersByConsumption

/**
* Special {@link Iterator} for hybrid shuffle mode that wrapped a deque of {@link
* BufferWithIdentity}. Tow iterator can compare by compute consumption priority of peek
* BufferIndexAndChannel}. Tow iterator can compare by compute consumption priority of peek
* element.
*/
private static class BufferConsumptionPriorityIterator
implements Comparable<BufferConsumptionPriorityIterator>, Iterator<BufferWithIdentity> {
implements Comparable<BufferConsumptionPriorityIterator>,
Iterator<BufferIndexAndChannel> {

private final int consumptionProgress;

private final PeekingIterator<BufferWithIdentity> bufferIterator;
private final PeekingIterator<BufferIndexAndChannel> bufferIterator;

public BufferConsumptionPriorityIterator(
Deque<BufferWithIdentity> bufferQueue, int consumptionProgress) {
Deque<BufferIndexAndChannel> bufferQueue, int consumptionProgress) {
this.consumptionProgress = consumptionProgress;
this.bufferIterator = Iterators.peekingIterator(bufferQueue.descendingIterator());
}

// move the iterator to next item.
public BufferWithIdentity next() {
public BufferIndexAndChannel next() {
return bufferIterator.next();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
import java.util.List;
import java.util.Optional;

import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBuffer;
import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferWithIdentitiesList;
import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList;
import static org.assertj.core.api.Assertions.assertThat;

/** Tests for {@link HsFullSpillingStrategy}. */
Expand Down Expand Up @@ -67,9 +66,9 @@ void testOnBufferFinishedUnSpillBufferEqualToOrGreatThenThreshold() {

@Test
void testOnBufferConsumed() {
BufferWithIdentity bufferWithIdentity = new BufferWithIdentity(createBuffer(), 0, 0);
BufferIndexAndChannel bufferIndexAndChannel = new BufferIndexAndChannel(0, 0);
Optional<Decision> bufferConsumedDecision =
spillStrategy.onBufferConsumed(bufferWithIdentity);
spillStrategy.onBufferConsumed(bufferIndexAndChannel);
assertThat(bufferConsumedDecision).hasValue(Decision.NO_ACTION);
}

Expand All @@ -96,16 +95,16 @@ void testDecideActionWithGlobalInfo() {
final int progress1 = 10;
final int progress2 = 20;

List<BufferWithIdentity> subpartitionBuffers1 =
createBufferWithIdentitiesList(
List<BufferIndexAndChannel> subpartitionBuffers1 =
createBufferIndexAndChannelsList(
subpartition1,
progress1,
progress1 + 2,
progress1 + 4,
progress1 + 6,
progress1 + 8);
List<BufferWithIdentity> subpartitionBuffers2 =
createBufferWithIdentitiesList(
List<BufferIndexAndChannel> subpartitionBuffers2 =
createBufferIndexAndChannelsList(
subpartition2,
progress2 + 1,
progress2 + 3,
Expand All @@ -132,13 +131,13 @@ void testDecideActionWithGlobalInfo() {
Decision decision = spillStrategy.decideActionWithGlobalInfo(spillInfoProvider);

// all not spilled buffers need to spill.
ArrayList<BufferWithIdentity> expectedSpillBuffers =
ArrayList<BufferIndexAndChannel> expectedSpillBuffers =
new ArrayList<>(subpartitionBuffers1.subList(4, 5));
expectedSpillBuffers.add(subpartitionBuffers2.get(0));
expectedSpillBuffers.addAll(subpartitionBuffers2.subList(4, 5));
assertThat(decision.getBufferToSpill()).isEqualTo(expectedSpillBuffers);

ArrayList<BufferWithIdentity> expectedReleaseBuffers = new ArrayList<>();
ArrayList<BufferIndexAndChannel> expectedReleaseBuffers = new ArrayList<>();
// all consumed spill buffers should release.
expectedReleaseBuffers.addAll(subpartitionBuffers1.subList(0, 2));
// priority higher buffers should release.
Expand All @@ -154,8 +153,8 @@ void testDecideActionWithGlobalInfo() {
@Test
void testDecideActionWithGlobalInfoAllConsumedSpillBufferShouldRelease() {
final int subpartitionId = 0;
List<BufferWithIdentity> subpartitionBuffers =
createBufferWithIdentitiesList(subpartitionId, 0, 1, 2, 3, 4);
List<BufferIndexAndChannel> subpartitionBuffers =
createBufferIndexAndChannelsList(subpartitionId, 0, 1, 2, 3, 4);

final int poolSize = 5;
TestingSpillingInfoProvider spillInfoProvider =
Expand Down
Loading

0 comments on commit 5494fe6

Please sign in to comment.