Refactor stage output data size estimation
Extract an OutputDataSizeEstimator interface and model the current logic
as a series of implementations of it, wrapped in
CompositeOutputDataSizeEstimator.
losipiuk committed Oct 17, 2023
1 parent 6cc1843 commit d17ae7a
Showing 9 changed files with 506 additions and 205 deletions.
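The OutputDataSizeEstimator interface and its factory are among the changed files that are not expanded in this view. From the method and type names used at the call sites below, their shape can be reconstructed roughly as follows; this is a minimal sketch inferred from the visible implementations, not the exact committed code:

package io.trino.execution.scheduler.faulttolerant;

import io.trino.Session;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler.StageExecution;

import java.util.Optional;
import java.util.function.Function;

// Sketch of the extracted abstraction (presumably OutputDataSizeEstimator.java): a single
// estimation hook that may decline to produce an estimate by returning Optional.empty().
public interface OutputDataSizeEstimator
{
    Optional<OutputDataSizeEstimateResult> getEstimatedOutputDataSize(
            StageExecution stageExecution,
            Function<StageId, StageExecution> stageExecutionLookup,
            boolean parentEager);
}

// Sketch of the companion factory (presumably OutputDataSizeEstimatorFactory.java):
// session-aware so each estimator can read its configuration from session properties.
public interface OutputDataSizeEstimatorFactory
{
    OutputDataSizeEstimator create(Session session);
}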
SqlQueryExecution.java
@@ -38,6 +38,7 @@
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler;
import io.trino.execution.scheduler.faulttolerant.EventDrivenTaskSourceFactory;
import io.trino.execution.scheduler.faulttolerant.NodeAllocatorService;
import io.trino.execution.scheduler.faulttolerant.OutputDataSizeEstimatorFactory;
import io.trino.execution.scheduler.faulttolerant.PartitionMemoryEstimatorFactory;
import io.trino.execution.scheduler.faulttolerant.TaskDescriptorStorage;
import io.trino.execution.scheduler.policy.ExecutionPolicy;
@@ -112,6 +113,7 @@ public class SqlQueryExecution
private final NodeScheduler nodeScheduler;
private final NodeAllocatorService nodeAllocatorService;
private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory;
private final OutputDataSizeEstimatorFactory outputDataSizeEstimatorFactory;
private final TaskExecutionStats taskExecutionStats;
private final List<PlanOptimizer> planOptimizers;
private final PlanFragmenter planFragmenter;
@@ -150,6 +152,7 @@ private SqlQueryExecution(
NodeScheduler nodeScheduler,
NodeAllocatorService nodeAllocatorService,
PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory,
OutputDataSizeEstimatorFactory outputDataSizeEstimatorFactory,
TaskExecutionStats taskExecutionStats,
List<PlanOptimizer> planOptimizers,
PlanFragmenter planFragmenter,
@@ -182,6 +185,7 @@ private SqlQueryExecution(
this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null");
this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null");
this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null");
this.outputDataSizeEstimatorFactory = requireNonNull(outputDataSizeEstimatorFactory, "outputDataSizeEstimatorFactory is null");
this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null");
this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
@@ -552,6 +556,7 @@ private void planDistribution(PlanRoot plan)
tracer,
schedulerStats,
partitionMemoryEstimatorFactory,
outputDataSizeEstimatorFactory,
nodePartitioningManager,
exchangeManagerRegistry.getExchangeManager(),
nodeAllocatorService,
@@ -747,6 +752,7 @@ public static class SqlQueryExecutionFactory
private final NodeScheduler nodeScheduler;
private final NodeAllocatorService nodeAllocatorService;
private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory;
private final OutputDataSizeEstimatorFactory outputDataSizeEstimatorFactory;
private final TaskExecutionStats taskExecutionStats;
private final List<PlanOptimizer> planOptimizers;
private final PlanFragmenter planFragmenter;
@@ -777,6 +783,7 @@ public static class SqlQueryExecutionFactory
NodeScheduler nodeScheduler,
NodeAllocatorService nodeAllocatorService,
PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory,
OutputDataSizeEstimatorFactory outputDataSizeEstimatorFactory,
TaskExecutionStats taskExecutionStats,
PlanOptimizersFactory planOptimizersFactory,
PlanFragmenter planFragmenter,
@@ -807,6 +814,7 @@ public static class SqlQueryExecutionFactory
this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null");
this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null");
this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null");
this.outputDataSizeEstimatorFactory = requireNonNull(outputDataSizeEstimatorFactory, "outputDataSizeEstimatorFactory is null");
this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
@@ -851,6 +859,7 @@ public QueryExecution createQueryExecution(
nodeScheduler,
nodeAllocatorService,
partitionMemoryEstimatorFactory,
outputDataSizeEstimatorFactory,
taskExecutionStats,
planOptimizers,
planFragmenter,
ByEagerParentOutputDataSizeEstimator.java (new file)
@@ -0,0 +1,54 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.primitives.ImmutableLongArray;
import io.trino.Session;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler.StageExecution;

import java.util.Optional;
import java.util.function.Function;

public class ByEagerParentOutputDataSizeEstimator
implements OutputDataSizeEstimator
{
public static class Factory
implements OutputDataSizeEstimatorFactory
{
@Override
public OutputDataSizeEstimator create(Session session)
{
return new ByEagerParentOutputDataSizeEstimator();
}
}

@Override
public Optional<OutputDataSizeEstimateResult> getEstimatedOutputDataSize(StageExecution stageExecution, Function<StageId, StageExecution> stageExecutionLookup, boolean parentEager)
{
if (!parentEager) {
return Optional.empty();
}

// Use an empty (all-zero) estimate as the fallback for eager parents. This matches the current logic for assessing whether a node should be processed eagerly.
// Currently, eager task execution is used only for stages with a small FINAL LIMIT, which implies small input from child stages (child stages
// enforce the small input via a PARTIAL LIMIT).
int outputPartitionsCount = stageExecution.getSinkPartitioningScheme().getPartitionCount();
ImmutableLongArray.Builder estimateBuilder = ImmutableLongArray.builder(outputPartitionsCount);
for (int i = 0; i < outputPartitionsCount; ++i) {
estimateBuilder.add(0);
}
return Optional.of(new OutputDataSizeEstimateResult(estimateBuilder.build(), OutputDataSizeEstimateStatus.ESTIMATED_FOR_EAGER_PARENT));
}
}
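OutputDataSizeEstimateResult and OutputDataSizeEstimateStatus are also introduced by this commit but not expanded in this view. Judging by how they are constructed here and in the other two estimators (one constructor takes a raw ImmutableLongArray, another an OutputDataSizeEstimate, and the result exposes an outputDataSizeEstimate() accessor), they presumably look roughly like the sketch below; the status values listed are only the ones visible in this diff, and the record shape is an assumption:

package io.trino.execution.scheduler.faulttolerant;

import com.google.common.primitives.ImmutableLongArray;
import io.trino.execution.scheduler.OutputDataSizeEstimate;

import static java.util.Objects.requireNonNull;

// Sketch (likely two separate files in the actual commit): how an estimate was produced.
enum OutputDataSizeEstimateStatus
{
    ESTIMATED_BY_PROGRESS,
    ESTIMATED_BY_SMALL_INPUT,
    ESTIMATED_FOR_EAGER_PARENT
}

// Sketch: an estimate paired with its status; the extra constructor mirrors the call sites
// that pass an ImmutableLongArray directly.
record OutputDataSizeEstimateResult(OutputDataSizeEstimate outputDataSizeEstimate, OutputDataSizeEstimateStatus status)
{
    OutputDataSizeEstimateResult
    {
        requireNonNull(outputDataSizeEstimate, "outputDataSizeEstimate is null");
        requireNonNull(status, "status is null");
    }

    OutputDataSizeEstimateResult(ImmutableLongArray partitionDataSizes, OutputDataSizeEstimateStatus status)
    {
        this(new OutputDataSizeEstimate(partitionDataSizes), status);
    }
}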
BySmallStageOutputDataSizeEstimator.java (new file)
@@ -0,0 +1,147 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.primitives.ImmutableLongArray;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler.StageExecution;
import io.trino.spi.QueryId;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.RemoteSourceNode;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static io.trino.SystemSessionProperties.getFaultTolerantExecutionArbitraryDistributionComputeTaskTargetSizeMin;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionSmallStageEstimationThreshold;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionSmallStageSourceSizeMultiplier;
import static io.trino.SystemSessionProperties.isFaultTolerantExecutionSmallStageEstimationEnabled;
import static io.trino.SystemSessionProperties.isFaultTolerantExecutionSmallStageRequireNoMorePartitions;
import static java.util.Objects.requireNonNull;

public class BySmallStageOutputDataSizeEstimator
implements OutputDataSizeEstimator
{
public static class Factory
implements OutputDataSizeEstimatorFactory
{
@Override
public OutputDataSizeEstimator create(Session session)
{
return new BySmallStageOutputDataSizeEstimator(
session.getQueryId(),
isFaultTolerantExecutionSmallStageEstimationEnabled(session),
getFaultTolerantExecutionSmallStageEstimationThreshold(session),
getFaultTolerantExecutionSmallStageSourceSizeMultiplier(session),
getFaultTolerantExecutionArbitraryDistributionComputeTaskTargetSizeMin(session),
isFaultTolerantExecutionSmallStageRequireNoMorePartitions(session));
}
}

private final QueryId queryId;
private final boolean smallStageEstimationEnabled;
private final DataSize smallStageEstimationThreshold;
private final double smallStageSourceSizeMultiplier;
private final DataSize smallSizePartitionSizeEstimate;
private final boolean smallStageRequireNoMorePartitions;

private BySmallStageOutputDataSizeEstimator(
QueryId queryId,
boolean smallStageEstimationEnabled,
DataSize smallStageEstimationThreshold,
double smallStageSourceSizeMultiplier,
DataSize smallSizePartitionSizeEstimate,
boolean smallStageRequireNoMorePartitions)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.smallStageEstimationEnabled = smallStageEstimationEnabled;
this.smallStageEstimationThreshold = requireNonNull(smallStageEstimationThreshold, "smallStageEstimationThreshold is null");
this.smallStageSourceSizeMultiplier = smallStageSourceSizeMultiplier;
this.smallSizePartitionSizeEstimate = requireNonNull(smallSizePartitionSizeEstimate, "smallSizePartitionSizeEstimate is null");
this.smallStageRequireNoMorePartitions = smallStageRequireNoMorePartitions;
}

@Override
public Optional<OutputDataSizeEstimateResult> getEstimatedOutputDataSize(StageExecution stageExecution, Function<StageId, StageExecution> stageExecutionLookup, boolean parentEager)
{
if (!smallStageEstimationEnabled) {
return Optional.empty();
}

if (smallStageRequireNoMorePartitions && !stageExecution.isNoMorePartitions()) {
return Optional.empty();
}

long[] currentOutputDataSize = stageExecution.currentOutputDataSize();
long totalOutputDataSize = 0;
for (long partitionOutputDataSize : currentOutputDataSize) {
totalOutputDataSize += partitionOutputDataSize;
}
if (totalOutputDataSize > smallStageEstimationThreshold.toBytes()) {
// our output is too big already
return Optional.empty();
}

PlanFragment planFragment = stageExecution.getStageInfo().getPlan();
boolean hasPartitionedSources = !planFragment.getPartitionedSources().isEmpty();
List<RemoteSourceNode> remoteSourceNodes = planFragment.getRemoteSourceNodes();

long partitionedInputSizeEstimate = 0;
if (hasPartitionedSources) {
if (!stageExecution.isNoMorePartitions()) {
// stage is reading directly from table
// for leaf stages require all tasks to be enumerated
return Optional.empty();
}
// estimate partitioned input based on number of task partitions
partitionedInputSizeEstimate += stageExecution.getPartitionsCount() * smallSizePartitionSizeEstimate.toBytes();
}

long remoteInputSizeEstimate = 0;
for (RemoteSourceNode remoteSourceNode : remoteSourceNodes) {
for (PlanFragmentId sourceFragmentId : remoteSourceNode.getSourceFragmentIds()) {
StageId sourceStageId = StageId.create(queryId, sourceFragmentId);

StageExecution sourceStage = stageExecutionLookup.apply(sourceStageId);
requireNonNull(sourceStage, "sourceStage is null");
Optional<OutputDataSizeEstimateResult> sourceStageOutputDataSize = sourceStage.getOutputDataSize(stageExecutionLookup, false);

if (sourceStageOutputDataSize.isEmpty()) {
// can't estimate the size of one of the sources; should not happen in practice
return Optional.empty();
}

remoteInputSizeEstimate += sourceStageOutputDataSize.orElseThrow().outputDataSizeEstimate().getTotalSizeInBytes();
}
}

long inputSizeEstimate = (long) ((partitionedInputSizeEstimate + remoteInputSizeEstimate) * smallStageSourceSizeMultiplier);
if (inputSizeEstimate > smallStageEstimationThreshold.toBytes()) {
return Optional.empty();
}

int outputPartitionsCount = stageExecution.getSinkPartitioningScheme().getPartitionCount();
ImmutableLongArray.Builder estimateBuilder = ImmutableLongArray.builder(outputPartitionsCount);
for (int i = 0; i < outputPartitionsCount; ++i) {
// assume uniform distribution
// TODO: should we use the distribution from this.outputDataSize if we already have some data there?
estimateBuilder.add(inputSizeEstimate / outputPartitionsCount);
}
return Optional.of(new OutputDataSizeEstimateResult(estimateBuilder.build(), OutputDataSizeEstimateStatus.ESTIMATED_BY_SMALL_INPUT));
}
}
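For intuition, here is a purely illustrative walk-through of the arithmetic above. The threshold, multiplier, and per-partition size are hypothetical stand-ins, not Trino defaults; in the estimator they come from the fault-tolerant-execution session properties read in the Factory:

// Illustrative walk-through of BySmallStageOutputDataSizeEstimator's arithmetic.
// All numbers are hypothetical; the real values come from session properties.
public class SmallStageEstimateExample
{
    public static void main(String[] args)
    {
        long smallStageEstimationThresholdBytes = 500L * 1024 * 1024; // assumed 500 MB threshold
        double smallStageSourceSizeMultiplier = 1.2;                  // assumed 1.2x safety multiplier
        long smallSizePartitionSizeEstimateBytes = 4L * 1024 * 1024;  // assumed 4 MB per leaf task partition

        // A leaf-reading stage with 10 enumerated task partitions plus 100 MB already reported by remote sources
        long partitionedInputSizeEstimate = 10 * smallSizePartitionSizeEstimateBytes; // 40 MB
        long remoteInputSizeEstimate = 100L * 1024 * 1024;                            // 100 MB

        long inputSizeEstimate = (long) ((partitionedInputSizeEstimate + remoteInputSizeEstimate) * smallStageSourceSizeMultiplier);
        System.out.println("input estimate (bytes): " + inputSizeEstimate); // ~168 MB, under the assumed threshold

        if (inputSizeEstimate <= smallStageEstimationThresholdBytes) {
            int outputPartitionsCount = 50;
            long perPartition = inputSizeEstimate / outputPartitionsCount; // uniform split, as in the estimator
            System.out.println("per output partition (bytes): " + perPartition);
        }
    }
}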
ByTaskProgressOutputDataSizeEstimator.java (new file)
@@ -0,0 +1,76 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.primitives.ImmutableLongArray;
import io.trino.Session;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.OutputDataSizeEstimate;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler.StageExecution;

import java.util.Optional;
import java.util.function.Function;

import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinSourceStageProgress;

public class ByTaskProgressOutputDataSizeEstimator
implements OutputDataSizeEstimator
{
public static class Factory
implements OutputDataSizeEstimatorFactory
{
@Override
public OutputDataSizeEstimator create(Session session)
{
return new ByTaskProgressOutputDataSizeEstimator(getFaultTolerantExecutionMinSourceStageProgress(session));
}
}

private final double minSourceStageProgress;

private ByTaskProgressOutputDataSizeEstimator(double minSourceStageProgress)
{
this.minSourceStageProgress = minSourceStageProgress;
}

@Override
public Optional<OutputDataSizeEstimateResult> getEstimatedOutputDataSize(StageExecution stageExecution, Function<StageId, StageExecution> stageExecutionLookup, boolean parentEager)
{
if (!stageExecution.isNoMorePartitions()) {
return Optional.empty();
}

int allPartitionsCount = stageExecution.getPartitionsCount();
int remainingPartitionsCount = stageExecution.getRemainingPartitionsCount();

if (remainingPartitionsCount == allPartitionsCount) {
return Optional.empty();
}

double progress = (double) (allPartitionsCount - remainingPartitionsCount) / allPartitionsCount;

if (progress < minSourceStageProgress) {
return Optional.empty();
}

long[] currentOutputDataSize = stageExecution.currentOutputDataSize();

ImmutableLongArray.Builder estimateBuilder = ImmutableLongArray.builder(currentOutputDataSize.length);

for (long partitionSize : currentOutputDataSize) {
estimateBuilder.add((long) (partitionSize / progress));
}
return Optional.of(new OutputDataSizeEstimateResult(new OutputDataSizeEstimate(estimateBuilder.build()), OutputDataSizeEstimateStatus.ESTIMATED_BY_PROGRESS));
}
}
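To make the extrapolation concrete, a tiny illustrative example (the numbers are hypothetical):

// Illustrative numbers for ByTaskProgressOutputDataSizeEstimator's extrapolation.
// Only applied when progress >= getFaultTolerantExecutionMinSourceStageProgress(session).
public class ProgressExtrapolationExample
{
    public static void main(String[] args)
    {
        int allPartitionsCount = 50;
        int remainingPartitionsCount = 40;
        double progress = (double) (allPartitionsCount - remainingPartitionsCount) / allPartitionsCount; // 0.2

        long partitionOutputSoFar = 10L * 1024 * 1024;                 // 10 MB written for one output partition so far
        long extrapolated = (long) (partitionOutputSoFar / progress);  // 50 MB expected at completion
        System.out.println("progress=" + progress + ", extrapolated bytes=" + extrapolated);
    }
}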
(The remaining changed files in this commit are not expanded in this view.)
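CompositeOutputDataSizeEstimator, named in the commit message, is among the files not expanded here. A natural reading of the refactoring is that it asks each delegate in turn and returns the first estimate that is present; the sketch below is written under that assumption — only the class name comes from the commit message, and the constructor shape and wiring are guesses:

package io.trino.execution.scheduler.faulttolerant;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler.StageExecution;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

// Sketch of the composite named in the commit message: try each estimator in order and
// return the first estimate that is present.
public class CompositeOutputDataSizeEstimator
        implements OutputDataSizeEstimator
{
    public static class Factory
            implements OutputDataSizeEstimatorFactory
    {
        private final List<OutputDataSizeEstimatorFactory> delegateFactories;

        public Factory(List<OutputDataSizeEstimatorFactory> delegateFactories)
        {
            this.delegateFactories = ImmutableList.copyOf(requireNonNull(delegateFactories, "delegateFactories is null"));
        }

        @Override
        public OutputDataSizeEstimator create(Session session)
        {
            return new CompositeOutputDataSizeEstimator(delegateFactories.stream()
                    .map(factory -> factory.create(session))
                    .collect(ImmutableList.toImmutableList()));
        }
    }

    private final List<OutputDataSizeEstimator> estimators;

    private CompositeOutputDataSizeEstimator(List<OutputDataSizeEstimator> estimators)
    {
        this.estimators = ImmutableList.copyOf(estimators);
    }

    @Override
    public Optional<OutputDataSizeEstimateResult> getEstimatedOutputDataSize(StageExecution stageExecution, Function<StageId, StageExecution> stageExecutionLookup, boolean parentEager)
    {
        for (OutputDataSizeEstimator estimator : estimators) {
            Optional<OutputDataSizeEstimateResult> result = estimator.getEstimatedOutputDataSize(stageExecution, stageExecutionLookup, parentEager);
            if (result.isPresent()) {
                return result;
            }
        }
        return Optional.empty();
    }
}

A plausible delegation order — not visible in this view — would be ByTaskProgressOutputDataSizeEstimator, then BySmallStageOutputDataSizeEstimator, then ByEagerParentOutputDataSizeEstimator, keeping the all-zero eager-parent estimate as the last-resort fallback.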
