Skip to content

Commit

Permalink
Introduce BucketNodeMap
Browse files Browse the repository at this point in the history
Splits from the same buckets are scheduled to the same node to allow
bucketed execution. Previously, NodeSelector#computeAssignments takes
NodePartitionMap, which contains extra information.

This commit introduces BucketNodeMap to represent the mapping in a
separate class. This also opens opportunities for engine to schedule
bucket execution in a more flexible way.
  • Loading branch information
wenleix committed Dec 2, 2018
1 parent b06eb0e commit fd35b98
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;

import java.util.Optional;
import java.util.function.ToIntFunction;

import static java.util.Objects.requireNonNull;

public abstract class BucketNodeMap
{
private final ToIntFunction<Split> splitToBucket;

public BucketNodeMap(ToIntFunction<Split> splitToBucket)
{
this.splitToBucket = requireNonNull(splitToBucket, "splitToBucket is null");
}

public abstract int getBucketCount();

public abstract Optional<Node> getAssignedNode(int bucketedId);

public final Optional<Node> getAssignedNode(Split split)
{
return getAssignedNode(splitToBucket.applyAsInt(split));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;

import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.google.common.collect.ImmutableMap;

import java.util.Map;
import java.util.Optional;
import java.util.function.ToIntFunction;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static java.util.Objects.requireNonNull;

// the bucket to node mapping is fixed and pre-assigned
public class FixedBucketNodeMap
extends BucketNodeMap
{
private final Map<Integer, Node> bucketToNode;
private final int bucketCount;

public FixedBucketNodeMap(ToIntFunction<Split> splitToBucket, Map<Integer, Node> bucketToNode)
{
super(splitToBucket);
requireNonNull(bucketToNode, "bucketToNode is null");
this.bucketToNode = ImmutableMap.copyOf(bucketToNode);
bucketCount = bucketToNode.keySet().stream()
.mapToInt(Integer::intValue)
.max()
.getAsInt() + 1;
}

@Override
public Optional<Node> getAssignedNode(int bucketedId)
{
checkArgument(bucketedId >= 0 && bucketedId < bucketCount);
Node node = bucketToNode.get(bucketedId);
verify(node != null);
return Optional.of(node);
}

@Override
public int getBucketCount()
{
return bucketCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.connector.ConnectorPartitionHandle;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
Expand Down Expand Up @@ -55,6 +55,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public class FixedSourcePartitionedScheduler
Expand All @@ -63,7 +64,7 @@ public class FixedSourcePartitionedScheduler
private static final Logger log = Logger.get(FixedSourcePartitionedScheduler.class);

private final SqlStageExecution stage;
private final NodePartitionMap partitioning;
private final List<Node> nodes;
private final List<SourceScheduler> sourceSchedulers;
private final List<ConnectorPartitionHandle> partitionHandles;
private boolean scheduledTasks;
Expand All @@ -74,30 +75,32 @@ public FixedSourcePartitionedScheduler(
Map<PlanNodeId, SplitSource> splitSources,
StageExecutionStrategy stageExecutionStrategy,
List<PlanNodeId> schedulingOrder,
NodePartitionMap partitioning,
List<Node> nodes,
BucketNodeMap bucketNodeMap,
int splitBatchSize,
OptionalInt concurrentLifespansPerTask,
NodeSelector nodeSelector,
List<ConnectorPartitionHandle> partitionHandles)
{
requireNonNull(stage, "stage is null");
requireNonNull(splitSources, "splitSources is null");
requireNonNull(partitioning, "partitioning is null");
requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
requireNonNull(partitionHandles, "partitionHandles is null");

this.stage = stage;
this.partitioning = partitioning;
this.nodes = nodes;
this.partitionHandles = ImmutableList.copyOf(partitionHandles);

checkArgument(splitSources.keySet().equals(ImmutableSet.copyOf(schedulingOrder)));

FixedSplitPlacementPolicy splitPlacementPolicy = new FixedSplitPlacementPolicy(nodeSelector, partitioning, stage::getAllTasks);
BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stage::getAllTasks);

ArrayList<SourceScheduler> sourceSchedulers = new ArrayList<>();
checkArgument(
partitionHandles.equals(ImmutableList.of(NOT_PARTITIONED)) != stageExecutionStrategy.isAnyScanGroupedExecution(),
"PartitionHandles should be [NOT_PARTITIONED] if and only if all scan nodes use ungrouped execution strategy");
int nodeCount = partitioning.getPartitionToNode().size();
int nodeCount = nodes.size();
int concurrentLifespans;
if (concurrentLifespansPerTask.isPresent() && concurrentLifespansPerTask.getAsInt() * nodeCount <= partitionHandles.size()) {
concurrentLifespans = concurrentLifespansPerTask.getAsInt() * nodeCount;
Expand Down Expand Up @@ -129,7 +132,7 @@ public FixedSourcePartitionedScheduler(
sourceScheduler.noMoreLifespans();
}
else {
LifespanScheduler lifespanScheduler = new LifespanScheduler(partitioning, partitionHandles, concurrentLifespansPerTask);
LifespanScheduler lifespanScheduler = new LifespanScheduler(bucketNodeMap, partitionHandles, concurrentLifespansPerTask);
// Schedule the first few lifespans
lifespanScheduler.scheduleInitial(sourceScheduler);
// Schedule new lifespans for finished ones
Expand All @@ -156,9 +159,10 @@ public ScheduleResult schedule()
// schedule a task on every node in the distribution
List<RemoteTask> newTasks = ImmutableList.of();
if (!scheduledTasks) {
OptionalInt totalPartitions = OptionalInt.of(partitioning.getPartitionToNode().size());
newTasks = partitioning.getPartitionToNode().entrySet().stream()
.map(entry -> stage.scheduleTask(entry.getValue(), entry.getKey(), totalPartitions))
OptionalInt totalPartitions = OptionalInt.of(nodes.size());
newTasks = Streams.mapWithIndex(
nodes.stream(),
(node, id) -> stage.scheduleTask(node, toIntExact(id), totalPartitions))
.collect(toImmutableList());
scheduledTasks = true;
}
Expand Down Expand Up @@ -236,26 +240,30 @@ public void close()
sourceSchedulers.clear();
}

public static class FixedSplitPlacementPolicy
public static class BucketedSplitPlacementPolicy
implements SplitPlacementPolicy
{
private final NodeSelector nodeSelector;
private final NodePartitionMap partitioning;
private final List<Node> allNodes;
private final BucketNodeMap bucketNodeMap;
private final Supplier<? extends List<RemoteTask>> remoteTasks;

public FixedSplitPlacementPolicy(NodeSelector nodeSelector,
NodePartitionMap partitioning,
public BucketedSplitPlacementPolicy(
NodeSelector nodeSelector,
List<Node> allNodes,
BucketNodeMap bucketNodeMap,
Supplier<? extends List<RemoteTask>> remoteTasks)
{
this.nodeSelector = nodeSelector;
this.partitioning = partitioning;
this.remoteTasks = remoteTasks;
this.nodeSelector = requireNonNull(nodeSelector, "nodeSelector is null");
this.allNodes = ImmutableList.copyOf(requireNonNull(allNodes, "allNodes is null"));
this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null");
this.remoteTasks = requireNonNull(remoteTasks, "remoteTasks is null");
}

@Override
public SplitPlacementResult computeAssignments(Set<Split> splits)
{
return nodeSelector.computeAssignments(splits, remoteTasks.get(), partitioning);
return nodeSelector.computeAssignments(splits, remoteTasks.get(), bucketNodeMap);
}

@Override
Expand All @@ -266,12 +274,12 @@ public void lockDownNodes()
@Override
public List<Node> allNodes()
{
return ImmutableList.copyOf(partitioning.getPartitionToNode().values());
return allNodes;
}

public Node getNodeForBucket(int bucketId)
{
return partitioning.getPartitionToNode().get(partitioning.getBucketToPartition()[bucketId]);
return bucketNodeMap.getAssignedNode(bucketId).get();
}
}

Expand All @@ -294,15 +302,15 @@ private static class LifespanScheduler
private final List<Lifespan> recentlyCompletedDriverGroups = new ArrayList<>();
private int totalDriverGroupsScheduled;

public LifespanScheduler(NodePartitionMap nodePartitionMap, List<ConnectorPartitionHandle> partitionHandles, OptionalInt concurrentLifespansPerTask)
public LifespanScheduler(BucketNodeMap bucketNodeMap, List<ConnectorPartitionHandle> partitionHandles, OptionalInt concurrentLifespansPerTask)
{
checkArgument(!partitionHandles.equals(ImmutableList.of(NOT_PARTITIONED)));
checkArgument(partitionHandles.size() == bucketNodeMap.getBucketCount());

Map<Node, IntList> nodeToDriverGroupMap = new HashMap<>();
Int2ObjectMap<Node> driverGroupToNodeMap = new Int2ObjectOpenHashMap<>();
int[] bucketToPartition = nodePartitionMap.getBucketToPartition();
Map<Integer, Node> partitionToNode = nodePartitionMap.getPartitionToNode();
for (int bucket = 0; bucket < bucketToPartition.length; bucket++) {
int partition = bucketToPartition[bucket];
Node node = partitionToNode.get(partition);
for (int bucket = 0; bucket < bucketNodeMap.getBucketCount(); bucket++) {
Node node = bucketNodeMap.getAssignedNode(bucket).get();
nodeToDriverGroupMap.computeIfAbsent(node, key -> new IntArrayList()).add(bucket);
driverGroupToNodeMap.put(bucket, node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.HostAddress;
import com.facebook.presto.spi.Node;
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
Expand Down Expand Up @@ -266,15 +265,15 @@ public static SplitPlacementResult selectDistributionNodes(
int maxPendingSplitsPerTask,
Set<Split> splits,
List<RemoteTask> existingTasks,
NodePartitionMap partitioning)
BucketNodeMap bucketNodeMap)
{
Multimap<Node, Split> assignments = HashMultimap.create();
NodeAssignmentStats assignmentStats = new NodeAssignmentStats(nodeTaskMap, nodeMap, existingTasks);

Set<Node> blockedNodes = new HashSet<>();
for (Split split : splits) {
// node placement is forced by the partitioning
Node node = partitioning.getNode(split);
// node placement is forced by the bucket to node map
Node node = bucketNodeMap.getAssignedNode(split).get();

// if node is full, don't schedule now, which will push back on the scheduling of splits
if (assignmentStats.getTotalSplitCount(node) < maxSplitsPerNode ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.google.common.collect.ImmutableSet;

import java.util.List;
Expand Down Expand Up @@ -55,5 +54,5 @@ default List<Node> selectRandomNodes(int limit)
* If we cannot find an assignment for a split, it is not included in the map. Also returns a future indicating when
* to reattempt scheduling of this batch of splits, if some of them could not be scheduled.
*/
SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, NodePartitionMap partitioning);
SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, BucketNodeMap bucketNodeMap);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
Expand Down Expand Up @@ -170,8 +169,8 @@ else if (!splitWaitingForAnyNode) {
}

@Override
public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, NodePartitionMap partitioning)
public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, BucketNodeMap bucketNodeMap)
{
return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsPerNode, maxPendingSplitsPerTask, splits, existingTasks, partitioning);
return selectDistributionNodes(nodeMap.get().get(), nodeTaskMap, maxSplitsPerNode, maxPendingSplitsPerTask, splits, existingTasks, bucketNodeMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.scheduler.FixedSourcePartitionedScheduler.FixedSplitPlacementPolicy;
import com.facebook.presto.execution.scheduler.FixedSourcePartitionedScheduler.BucketedSplitPlacementPolicy;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.connector.ConnectorPartitionHandle;
Expand Down Expand Up @@ -285,7 +285,7 @@ else if (pendingSplits.isEmpty()) {
if (pendingSplits.isEmpty() && scheduleGroup.state == ScheduleGroupState.NO_MORE_SPLITS) {
scheduleGroup.state = ScheduleGroupState.DONE;
if (!lifespan.isTaskWide()) {
Node node = ((FixedSplitPlacementPolicy) splitPlacementPolicy).getNodeForBucket(lifespan.getId());
Node node = ((BucketedSplitPlacementPolicy) splitPlacementPolicy).getNodeForBucket(lifespan.getId());
noMoreSplitsNotification = ImmutableMultimap.of(node, lifespan);
}
}
Expand Down
Loading

0 comments on commit fd35b98

Please sign in to comment.