Skip to content

Commit

Permalink
[FLINK-31261][runtime] Make AdaptiveScheduler aware of local state size
Browse files Browse the repository at this point in the history
  • Loading branch information
rkhachatryan committed Jul 19, 2024
1 parent caf9b23 commit 203acad
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,22 @@ public class JobAllocationsInformation {
}

public static JobAllocationsInformation fromGraph(@Nullable ExecutionGraph graph) {
return graph == null ? empty() : new JobAllocationsInformation(calculateAllocations(graph));
return graph == null
? empty()
: new JobAllocationsInformation(
calculateAllocations(graph, StateSizeEstimates.fromGraph(graph)));
}

public List<VertexAllocationInformation> getAllocations(JobVertexID jobVertexID) {
return vertexAllocations.getOrDefault(jobVertexID, emptyList());
}

private static Map<JobVertexID, List<VertexAllocationInformation>> calculateAllocations(
ExecutionGraph graph) {
ExecutionGraph graph, StateSizeEstimates stateSizeEstimates) {
final Map<JobVertexID, List<VertexAllocationInformation>> allocations = new HashMap<>();
for (ExecutionJobVertex vertex : graph.getVerticesTopologically()) {
JobVertexID jobVertexId = vertex.getJobVertexId();
long avgKgSize = stateSizeEstimates.estimate(jobVertexId).orElse(0L);
for (ExecutionVertex executionVertex : vertex.getTaskVertices()) {
AllocationID allocationId =
executionVertex.getCurrentExecutionAttempt().getAssignedAllocationID();
Expand All @@ -70,7 +74,9 @@ private static Map<JobVertexID, List<VertexAllocationInformation>> calculateAllo
executionVertex.getParallelSubtaskIndex());
allocations
.computeIfAbsent(jobVertexId, ignored -> new ArrayList<>())
.add(new VertexAllocationInformation(allocationId, jobVertexId, kgr));
.add(
new VertexAllocationInformation(
allocationId, jobVertexId, kgr, avgKgSize));
}
}
return allocations;
Expand All @@ -89,12 +95,17 @@ public static class VertexAllocationInformation {
private final AllocationID allocationID;
private final JobVertexID jobVertexID;
private final KeyGroupRange keyGroupRange;
public final long averageKeyGroupSizeInBytes;

public VertexAllocationInformation(
AllocationID allocationID, JobVertexID jobVertexID, KeyGroupRange keyGroupRange) {
AllocationID allocationID,
JobVertexID jobVertexID,
KeyGroupRange keyGroupRange,
long averageKeyGroupSizeInBytes) {
this.allocationID = allocationID;
this.jobVertexID = jobVertexID;
this.keyGroupRange = keyGroupRange;
this.averageKeyGroupSizeInBytes = averageKeyGroupSizeInBytes;
}

public AllocationID getAllocationID() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.jobmaster.SlotInfo;
import org.apache.flink.runtime.scheduler.adaptive.JobSchedulingPlan.SlotAssignment;
import org.apache.flink.runtime.scheduler.adaptive.allocator.JobAllocationsInformation.VertexAllocationInformation;
import org.apache.flink.runtime.scheduler.adaptive.allocator.SlotSharingSlotAllocator.ExecutionSlotSharingGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.state.KeyGroupRange;
Expand Down Expand Up @@ -183,11 +184,7 @@ public Collection<AllocationScore> calculateScore(
.getAllocations(evi.getJobVertexId())
.forEach(
allocation -> {
long value =
allocation
.getKeyGroupRange()
.getIntersection(kgr)
.getNumberOfKeyGroups();
long value = estimateSize(kgr, allocation);
if (value > 0) {
score.merge(allocation.getAllocationID(), value, Long::sum);
}
Expand All @@ -198,4 +195,14 @@ public Collection<AllocationScore> calculateScore(
.map(e -> new AllocationScore(group.getId(), e.getKey(), e.getValue()))
.collect(Collectors.toList());
}

private static long estimateSize(
KeyGroupRange newRange, VertexAllocationInformation allocation) {
KeyGroupRange oldRange = allocation.getKeyGroupRange();
// Estimate state size per key group. For scoring, assume 1 if size estimate is 0 to
// accommodate for averaging non-zero states
long keyGroupSize = Math.max(allocation.averageKeyGroupSizeInBytes, 1L);
int numberOfKeyGroups = oldRange.getIntersection(newRange).getNumberOfKeyGroups();
return numberOfKeyGroups * keyGroupSize;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.scheduler.adaptive.allocator;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.KeyedStateHandle;

import javax.annotation.Nullable;

import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toMap;

/** Managed Keyed State size estimates used to make scheduling decisions. */
@Internal
public class StateSizeEstimates {
private final Map<JobVertexID, Long> averages;

public StateSizeEstimates() {
this(Collections.emptyMap());
}

public StateSizeEstimates(Map<JobVertexID, Long> averages) {
this.averages = averages;
}

public Optional<Long> estimate(JobVertexID jobVertexId) {
return Optional.ofNullable(averages.get(jobVertexId));
}

static StateSizeEstimates empty() {
return new StateSizeEstimates();
}

public static StateSizeEstimates fromGraph(@Nullable ExecutionGraph executionGraph) {
return Optional.ofNullable(executionGraph)
.flatMap(graph -> Optional.ofNullable(graph.getCheckpointCoordinator()))
.flatMap(coordinator -> Optional.ofNullable(coordinator.getCheckpointStore()))
.flatMap(store -> Optional.ofNullable(store.getLatestCheckpoint()))
.map(
cp ->
build(
fromCompletedCheckpoint(cp),
mapVerticesToOperators(executionGraph)))
.orElse(empty());
}

private static StateSizeEstimates build(
Map<OperatorID, Long> sizePerOperator,
Map<JobVertexID, Set<OperatorID>> verticesToOperators) {
Map<JobVertexID, Long> verticesToSizes =
verticesToOperators.entrySet().stream()
.collect(
toMap(Map.Entry::getKey, e -> size(e.getValue(), sizePerOperator)));
return new StateSizeEstimates(verticesToSizes);
}

private static long size(Set<OperatorID> ids, Map<OperatorID, Long> sizes) {
return ids.stream().mapToLong(key -> sizes.getOrDefault(key, 0L)).sum();
}

private static Map<JobVertexID, Set<OperatorID>> mapVerticesToOperators(
ExecutionGraph executionGraph) {
return executionGraph.getAllVertices().entrySet().stream()
.collect(toMap(Map.Entry::getKey, e -> getOperatorIDS(e.getValue())));
}

private static Set<OperatorID> getOperatorIDS(ExecutionJobVertex v) {
return v.getOperatorIDs().stream()
.map(OperatorIDPair::getGeneratedOperatorID)
.collect(Collectors.toSet());
}

private static Map<OperatorID, Long> fromCompletedCheckpoint(CompletedCheckpoint cp) {
Stream<Map.Entry<OperatorID, OperatorState>> states =
cp.getOperatorStates().entrySet().stream();
return states.collect(
toMap(
Map.Entry::getKey,
e -> calculateAverageKeyGroupStateSizeInBytes(e.getValue())));
}

private static long calculateAverageKeyGroupStateSizeInBytes(OperatorState state) {
Stream<KeyedStateHandle> handles =
state.getSubtaskStates().values().stream()
.flatMap(s -> s.getManagedKeyedState().stream());
Stream<Tuple2<Long, Integer>> sizeAndCount =
handles.map(
h ->
Tuple2.of(
h.getStateSize(),
h.getKeyGroupRange().getNumberOfKeyGroups()));
Optional<Tuple2<Long, Integer>> totalSizeAndCount =
sizeAndCount.reduce(
(left, right) -> Tuple2.of(left.f0 + right.f0, left.f1 + right.f1));
Optional<Long> average = totalSizeAndCount.filter(t2 -> t2.f1 > 0).map(t2 -> t2.f0 / t2.f1);
return average.orElse(0L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -388,20 +388,29 @@ void testStickyAllocation() {
vertex1.getJobVertexID(),
Collections.singletonList(
new VertexAllocationInformation(
allocation1, vertex1.getJobVertexID(), KeyGroupRange.of(1, 100))));
allocation1,
vertex1.getJobVertexID(),
KeyGroupRange.of(1, 100),
1)));
locality.put(
vertex2.getJobVertexID(),
Collections.singletonList(
new VertexAllocationInformation(
allocation1, vertex2.getJobVertexID(), KeyGroupRange.of(1, 100))));
allocation1,
vertex2.getJobVertexID(),
KeyGroupRange.of(1, 100),
1)));

// previous allocation allocation2: v3
AllocationID allocation2 = new AllocationID();
locality.put(
vertex3.getJobVertexID(),
Collections.singletonList(
new VertexAllocationInformation(
allocation2, vertex3.getJobVertexID(), KeyGroupRange.of(1, 100))));
allocation2,
vertex3.getJobVertexID(),
KeyGroupRange.of(1, 100),
1)));

List<SlotInfo> freeSlots = new ArrayList<>();
IntStream.range(0, 10).forEach(i -> freeSlots.add(new TestSlotInfo(new AllocationID())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ void testSlotsAreNotWasted() {
List<VertexAllocationInformation> allocations =
Arrays.asList(
new VertexAllocationInformation(
alloc1, vertex.getJobVertexID(), KeyGroupRange.of(0, 9)),
alloc1, vertex.getJobVertexID(), KeyGroupRange.of(0, 9), 1),
new VertexAllocationInformation(
alloc2, vertex.getJobVertexID(), KeyGroupRange.of(10, 19)));
alloc2, vertex.getJobVertexID(), KeyGroupRange.of(10, 19), 1));

assign(vertex, Arrays.asList(alloc1, alloc2), allocations);
}
Expand All @@ -76,7 +76,8 @@ void testUpScaling() {
iterator.next(),
vertex.getJobVertexID(),
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
vertex.getMaxParallelism(), oldParallelism, i)));
vertex.getMaxParallelism(), oldParallelism, i),
1));
}

Collection<SlotAssignment> assignments = assign(vertex, allocationIDs, prevAllocations);
Expand Down Expand Up @@ -106,7 +107,8 @@ void testDownScaling() {
new VertexAllocationInformation(
biggestAllocation,
vertex.getJobVertexID(),
KeyGroupRange.of(0, halfOfKeyGroupRange - 1)));
KeyGroupRange.of(0, halfOfKeyGroupRange - 1),
1));

// and the remaining subtasks had only one key group each
for (int subtaskIdx = 1; subtaskIdx < oldParallelism; subtaskIdx++) {
Expand All @@ -115,7 +117,8 @@ void testDownScaling() {
new VertexAllocationInformation(
iterator.next(),
vertex.getJobVertexID(),
KeyGroupRange.of(keyGroup, keyGroup)));
KeyGroupRange.of(keyGroup, keyGroup),
1));
}

Collection<SlotAssignment> assignments = assign(vertex, allocationIDs, prevAllocations);
Expand Down

0 comments on commit 203acad

Please sign in to comment.