diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java index 50474325ed2c..d38f28da7520 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java @@ -39,12 +39,7 @@ public class ScaledWriterScheduler implements StageScheduler { - private interface TaskScheduler - { - RemoteTask scheduleTask(Node node, int partition, OptionalInt totalPartitions); - } - - private final TaskScheduler taskScheduler; + private final SqlStageExecution stage; private final Supplier> sourceTasksProvider; private final Supplier> writerTasksProvider; private final NodeSelector nodeSelector; @@ -62,8 +57,7 @@ public ScaledWriterScheduler( ScheduledExecutorService executor, DataSize writerMinSize) { - requireNonNull(stage, "stage is null"); - this.taskScheduler = stage::scheduleTask; + this.stage = requireNonNull(stage, "stage is null"); this.sourceTasksProvider = requireNonNull(sourceTasksProvider, "sourceTasksProvider is null"); this.writerTasksProvider = requireNonNull(writerTasksProvider, "writerTasksProvider is null"); this.nodeSelector = requireNonNull(nodeSelector, "nodeSelector is null"); @@ -125,7 +119,7 @@ private List scheduleTasks(int count) ImmutableList.Builder tasks = ImmutableList.builder(); for (Node node : nodes) { - tasks.add(taskScheduler.scheduleTask(node, scheduledNodes.size(), OptionalInt.empty())); + tasks.add(stage.scheduleTask(node, scheduledNodes.size(), OptionalInt.empty())); scheduledNodes.add(node); }