[SPARK-16929] Improve performance when check speculatable tasks.
## What changes were proposed in this pull request?
1. Use a MedianHeap to record the durations of successful tasks. When checking for speculatable tasks, we can then get the median duration in O(1) time (see the first sketch below).

2. `checkSpeculatableTasks` synchronizes on `TaskSchedulerImpl`. If `checkSpeculatableTasks` doesn't finish within 100ms, the speculation thread can release the lock and then immediately re-acquire it. Change `scheduleAtFixedRate` to `scheduleWithFixedDelay` when scheduling the calls to `checkSpeculatableTasks` (see the second sketch below).
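
A minimal usage sketch of the new `MedianHeap` (illustrative values, not from the patch; the class itself is added below in this commit):

```scala
import org.apache.spark.util.collection.MedianHeap

val durations = new MedianHeap()
Seq(3.0, 1.0, 4.0, 1.0, 5.0).foreach(durations.insert)
// insert is O(log n); median is O(1), so the speculation check no longer
// sorts the full list of successful durations on every pass.
assert(durations.median === 3.0)
```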
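And a sketch (not from the patch) of the `ScheduledExecutorService` semantics behind item 2:

```scala
import java.util.concurrent.{Executors, TimeUnit}

val pool = Executors.newSingleThreadScheduledExecutor()
val check = new Runnable {
  override def run(): Unit = { /* checkSpeculatableTasks() */ }
}
// scheduleAtFixedRate targets a fixed cadence: if a run overruns the 100ms
// period, the next run fires as soon as the previous one returns, so a slow
// check can release and immediately re-acquire the TaskSchedulerImpl lock.
// scheduleWithFixedDelay instead waits the full 100ms after each run
// completes, guaranteeing a gap between consecutive lock acquisitions.
pool.scheduleWithFixedDelay(check, 100, 100, TimeUnit.MILLISECONDS)
```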
## How was this patch tested?
Added MedianHeapSuite.

Author: jinxing <[email protected]>

Closes apache#16867 from jinxing64/SPARK-16929.
jinxing authored and kayousterhout committed Mar 24, 2017
1 parent bb823ca commit 19596c2
Showing 5 changed files with 176 additions and 6 deletions.
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -174,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
     if (!isLocal && conf.getBoolean("spark.speculation", false)) {
       logInfo("Starting speculative execution thread")
-      speculationScheduler.scheduleAtFixedRate(new Runnable {
+      speculationScheduler.scheduleWithFixedDelay(new Runnable {
         override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
           checkSpeculatableTasks()
         }
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -19,18 +19,18 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.nio.ByteBuffer
-import java.util.Arrays
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.math.{max, min}
+import scala.math.max
 import scala.util.control.NonFatal
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.SchedulingMode._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils}
+import org.apache.spark.util.collection.MedianHeap
 
 /**
  * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -63,6 +63,8 @@ private[spark] class TaskSetManager(
   // Limit of bytes for total size of results (default is 1GB)
   val maxResultSize = Utils.getMaxResultSize(conf)
 
+  val speculationEnabled = conf.getBoolean("spark.speculation", false)
+
   // Serializer for closures and tasks.
   val env = SparkEnv.get
   val ser = env.closureSerializer.newInstance()
@@ -141,6 +143,11 @@ private[spark] class TaskSetManager(
   // Task index, start and finish time for each task attempt (indexed by task ID)
   private val taskInfos = new HashMap[Long, TaskInfo]
 
+  // Use a MedianHeap to record durations of successful tasks so we know when to launch
+  // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
+  // of inserting into the heap when the heap won't be used.
+  val successfulTaskDurations = new MedianHeap()
+
   // How frequently to reprint duplicate exceptions in full, in milliseconds
   val EXCEPTION_PRINT_INTERVAL =
     conf.getLong("spark.logging.exceptionPrintInterval", 10000)
@@ -698,6 +705,9 @@ private[spark] class TaskSetManager(
     val info = taskInfos(tid)
     val index = info.index
     info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
+    if (speculationEnabled) {
+      successfulTaskDurations.insert(info.duration)
+    }
     removeRunningTask(tid)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
@@ -919,11 +929,10 @@ private[spark] class TaskSetManager(
     var foundTasks = false
     val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
     logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+
     if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
       val time = clock.getTimeMillis()
-      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
-      Arrays.sort(durations)
-      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
+      var medianDuration = successfulTaskDurations.median
       val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation)
       // TODO: Threshold should also look at standard deviation of task durations and have a lower
       // bound based on that.
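
To spell out what the new threshold computation does — a sketch with hypothetical numbers, assuming Spark's defaults of `spark.speculation.multiplier = 1.5` and `spark.speculation.quantile = 0.75`, and the 100 ms `minTimeToSpeculation` floor that `TaskSchedulerImpl` passes in:

```scala
// On a hypothetical 100-task stage, speculation only engages once
// 0.75 * 100 = 75 tasks have succeeded. If the median successful duration
// is 10000 ms, a task still running after
//   max(1.5 * 10000, 100) = 15000 ms
// becomes speculatable. The MedianHeap makes this median lookup O(1);
// the old code re-sorted every successful duration on each check.
val medianDuration = 10000.0                          // hypothetical median (ms)
val threshold = math.max(1.5 * medianDuration, 100L)  // 15000.0 ms
```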
core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala (new file)
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import scala.collection.mutable.PriorityQueue

/**
 * MedianHeap is designed to be used to quickly track the median of a group of numbers
 * that may contain duplicates. Inserting a new number has O(log n) time complexity and
 * determining the median has O(1) time complexity.
 * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
 * stores the smaller half of all numbers while the largerHalf stores the larger half.
 * The two heaps are rebalanced on every insertion so that their sizes never differ by more
 * than 1. Therefore each time median() is called, we check whether the two heaps have the
 * same size. If they do, we return the average of the two heaps' top values; otherwise we
 * return the top of the heap that has one more element.
 */
private[spark] class MedianHeap(implicit val ord: Ordering[Double]) {

  /**
   * Stores all the numbers less than the current median in a max-heap,
   * i.e. the largest number (the one closest to the median) is at the root.
   */
  private[this] var smallerHalf = PriorityQueue.empty[Double](ord)

  /**
   * Stores all the numbers greater than the current median in a min-heap,
   * i.e. the smallest number (the one closest to the median) is at the root.
   */
  private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse)

  def isEmpty(): Boolean = {
    smallerHalf.isEmpty && largerHalf.isEmpty
  }

  def size(): Int = {
    smallerHalf.size + largerHalf.size
  }

  def insert(x: Double): Unit = {
    // If both heaps are empty, arbitrarily insert the new number into the largerHalf.
    if (isEmpty) {
      largerHalf.enqueue(x)
    } else {
      // If the number is larger than the current median, it belongs in the largerHalf;
      // otherwise it belongs in the smallerHalf.
      if (x > median) {
        largerHalf.enqueue(x)
      } else {
        smallerHalf.enqueue(x)
      }
    }
    rebalance()
  }

  private[this] def rebalance(): Unit = {
    if (largerHalf.size - smallerHalf.size > 1) {
      smallerHalf.enqueue(largerHalf.dequeue())
    }
    if (smallerHalf.size - largerHalf.size > 1) {
      largerHalf.enqueue(smallerHalf.dequeue())
    }
  }

  def median: Double = {
    if (isEmpty) {
      throw new NoSuchElementException("MedianHeap is empty.")
    }
    if (largerHalf.size == smallerHalf.size) {
      (largerHalf.head + smallerHalf.head) / 2.0
    } else if (largerHalf.size > smallerHalf.size) {
      largerHalf.head
    } else {
      smallerHalf.head
    }
  }
}
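
To make the balancing invariant concrete, here is a hand trace (not part of the patch) of three insertions into an empty heap:

```scala
val h = new MedianHeap()
h.insert(1.0) // both heaps empty: largerHalf = {1}               -> median = 1.0
h.insert(2.0) // 2 > 1, so largerHalf = {1, 2}; rebalance moves its
              // min across: smallerHalf = {1}, largerHalf = {2}  -> median = 1.5
h.insert(3.0) // 3 > 1.5, so largerHalf = {2, 3}; sizes 1 and 2
              // differ by only 1, so no rebalance                -> median = 2.0
```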
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskSet = FakeTask.createTaskSet(4)
     // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
     sc.conf.set("spark.speculation.multiplier", "0.0")
+    sc.conf.set("spark.speculation", "true")
     val clock = new ManualClock()
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
     val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
@@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
     sc.conf.set("spark.speculation.multiplier", "0.0")
     sc.conf.set("spark.speculation.quantile", "0.6")
+    sc.conf.set("spark.speculation", "true")
     val clock = new ManualClock()
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
     val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
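
Worth noting, though the diff doesn't say so explicitly: these two tests must now enable `spark.speculation` because `TaskSetManager` only feeds `successfulTaskDurations` when `speculationEnabled` is true; with the flag off, the speculation check would call `median` on an empty `MedianHeap` and hit the `NoSuchElementException` exercised in the suite below.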
core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala (new file)
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import java.util.NoSuchElementException

import org.apache.spark.SparkFunSuite

class MedianHeapSuite extends SparkFunSuite {

  test("If no numbers in MedianHeap, NoSuchElementException is thrown.") {
    val medianHeap = new MedianHeap()
    intercept[NoSuchElementException] {
      medianHeap.median
    }
  }

  test("Median should be correct when size of MedianHeap is even") {
    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    val medianHeap = new MedianHeap()
    array.foreach(medianHeap.insert(_))
    assert(medianHeap.size() === 10)
    assert(medianHeap.median === 4.5)
  }

  test("Median should be correct when size of MedianHeap is odd") {
    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
    val medianHeap = new MedianHeap()
    array.foreach(medianHeap.insert(_))
    assert(medianHeap.size() === 9)
    assert(medianHeap.median === 4)
  }

  test("Median should be correct though there are duplicated numbers inside.") {
    val array = Array(0, 0, 1, 1, 2, 3, 4)
    val medianHeap = new MedianHeap()
    array.foreach(medianHeap.insert(_))
    assert(medianHeap.size() === 7)
    assert(medianHeap.median === 1)
  }

  test("Median should be correct when input data is skewed.") {
    val medianHeap = new MedianHeap()
    (0 until 10).foreach(_ => medianHeap.insert(5))
    assert(medianHeap.median === 5)
    (0 until 100).foreach(_ => medianHeap.insert(10))
    assert(medianHeap.median === 10)
    (0 until 1000).foreach(_ => medianHeap.insert(0))
    assert(medianHeap.median === 0)
  }
}
