diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 3e3f1ad031..67446da0a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -93,10 +93,12 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Number of tasks running on each executor - private val executorIdToTaskCount = new HashMap[String, Int] + // IDs of the tasks running on each executor + private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] - def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap + def runningTasksByExecutors(): Map[String, Int] = { + executorIdToRunningTaskIds.toMap.mapValues(_.size) + } // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -264,7 +266,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId - executorIdToTaskCount(execId) += 1 + executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) launchedTask = true @@ -294,11 +296,11 @@ private[spark] class TaskSchedulerImpl( if (!hostToExecutors.contains(o.host)) { hostToExecutors(o.host) = new HashSet[String]() } - if (!executorIdToTaskCount.contains(o.executorId)) { + if (!executorIdToRunningTaskIds.contains(o.executorId)) { hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host - executorIdToTaskCount(o.executorId) = 0 + executorIdToRunningTaskIds(o.executorId) = HashSet[Long]() newExecAvail = true } for (rack <- getRackForHost(o.host)) { @@ -349,38 +351,34 @@ private[spark] class TaskSchedulerImpl( var reason: Option[ExecutorLossReason] = None synchronized { try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - - if (executorIdToTaskCount.contains(execId)) { - reason = Some( - SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) - removeExecutor(execId, reason.get) - failedExecutor = Some(execId) - } - } taskIdToTaskSetManager.get(tid) match { case Some(taskSet) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid).foreach { execId => - if (executorIdToTaskCount.contains(execId)) { - executorIdToTaskCount(execId) -= 1 - } + if (state == TaskState.LOST) { + // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, + // where each executor corresponds to a single task, so mark the executor as failed. 
+ val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException( + "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)")) + if (executorIdToRunningTaskIds.contains(execId)) { + reason = Some( + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) + failedExecutor = Some(execId) } } - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + if (TaskState.isFinished(state)) { + cleanupTaskState(tid) taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + if (state == TaskState.FINISHED) { + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") + "likely the result of receiving duplicate task finished status updates) or its " + + "executor has been marked as failed.") .format(state, tid)) } } catch { @@ -491,7 +489,7 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (executorIdToTaskCount.contains(executorId)) { + if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) @@ -533,13 +531,31 @@ private[spark] class TaskSchedulerImpl( logError(s"Lost executor $executorId on $hostPort: $reason") } + /** + * Cleans up the TaskScheduler's state for tracking the given task. + */ + private def cleanupTaskState(tid: Long): Unit = { + taskIdToTaskSetManager.remove(tid) + taskIdToExecutorId.remove(tid).foreach { executorId => + executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) } + } + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - executorIdToTaskCount -= executorId + // The tasks on the lost executor may not send any more status updates (because the executor + // has been lost), so they should be cleaned up here. + executorIdToRunningTaskIds.remove(executorId).foreach { taskIds => + logDebug("Cleaning up TaskScheduler state for tasks " + + s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId") + // We do not notify the TaskSetManager of the task failures because that will + // happen below in the rootPool.executorLost() call. 
+ taskIds.foreach(cleanupTaskState) + } val host = executorIdToHost(executorId) val execs = hostToExecutors.getOrElse(host, new HashSet) @@ -577,11 +593,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - executorIdToTaskCount.contains(execId) + executorIdToRunningTaskIds.contains(execId) } def isExecutorBusy(execId: String): Boolean = synchronized { - executorIdToTaskCount.getOrElse(execId, -1) > 0 + executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty) } // By default, rack is unknown diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index e29eb8552e..05dad7a4b8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -433,10 +433,11 @@ class StandaloneDynamicAllocationSuite assert(executors.size === 2) // simulate running a task on the executor - val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount) + val getMap = + PrivateMethod[mutable.HashMap[String, mutable.HashSet[Long]]]('executorIdToRunningTaskIds) val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - val executorIdToTaskCount = taskScheduler invokePrivate getMap() - executorIdToTaskCount(executors.head) = 1 + val executorIdToRunningTaskIds = taskScheduler invokePrivate getMap() + executorIdToRunningTaskIds(executors.head) = mutable.HashSet(1L) // kill the busy executor without force; this should fail assert(killExecutor(sc, executors.head, force = false).isEmpty) apps = getApplications() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 5dc7708530..59bea27596 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.nio.ByteBuffer + import scala.collection.mutable.HashMap import org.mockito.Matchers.{anyInt, anyString, eq => meq} @@ -648,4 +650,70 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } + test("if an executor is lost then the state for its running tasks is cleaned up (SPARK-18553)") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // mark executor0 as dead + taskScheduler.executorLost("executor0", SlaveLost()) + assert(!taskScheduler.isExecutorAlive("executor0")) + assert(!taskScheduler.hasExecutorsAliveOnHost("host0")) + assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty) + + + // Check that state associated with the lost task attempt is cleaned up: + assert(taskScheduler.taskIdToExecutorId.isEmpty) + assert(taskScheduler.taskIdToTaskSetManager.isEmpty) + assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty) + } + + test("if a task finishes with TaskState.LOST its executor is marked as dead") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // Report the task as failed with TaskState.LOST + taskScheduler.statusUpdate( + tid = taskDescriptions.head.taskId, + state = TaskState.LOST, + serializedData = ByteBuffer.allocate(0) + ) + + // Check that state associated with the lost task attempt is cleaned up: + assert(taskScheduler.taskIdToExecutorId.isEmpty) + assert(taskScheduler.taskIdToTaskSetManager.isEmpty) + assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty) + + // Check that the executor has been marked as dead + assert(!taskScheduler.isExecutorAlive("executor0")) + assert(!taskScheduler.hasExecutorsAliveOnHost("host0")) + assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty) + } }
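Not part of the patch above: the following is a minimal, self-contained Scala sketch of the bookkeeping scheme this change adopts, included only to illustrate the idea. Instead of maintaining a per-executor running-task count, the scheduler tracks the set of running task IDs per executor, so that when an executor is lost every leftover task's per-task state can be found and dropped. The object and method names here (RunningTaskBookkeepingSketch, registerExecutor, taskLaunched, taskFinished, executorLost) are hypothetical and do not exist in Spark; the real logic lives in TaskSchedulerImpl as shown in the diff.

import scala.collection.mutable.{HashMap, HashSet}

// Illustrative sketch only; it mirrors the patch's data structures, not Spark's actual classes.
object RunningTaskBookkeepingSketch {
  // Replaces a HashMap[String, Int] task counter with the set of running task IDs per executor.
  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
  private val taskIdToExecutorId = new HashMap[Long, String]

  // Counts are derived from the tracked sets (equivalent to the patch's mapValues(_.size)).
  def runningTasksByExecutors(): Map[String, Int] =
    executorIdToRunningTaskIds.map { case (execId, tids) => execId -> tids.size }.toMap

  def registerExecutor(execId: String): Unit = {
    executorIdToRunningTaskIds.getOrElseUpdate(execId, HashSet[Long]())
  }

  def taskLaunched(tid: Long, execId: String): Unit = {
    taskIdToExecutorId(tid) = execId
    executorIdToRunningTaskIds(execId).add(tid)
  }

  // Analogous to cleanupTaskState: forget all per-task bookkeeping for one task.
  def taskFinished(tid: Long): Unit = {
    taskIdToExecutorId.remove(tid).foreach { execId =>
      executorIdToRunningTaskIds.get(execId).foreach(_.remove(tid))
    }
  }

  // Analogous to removeExecutor: tasks on a lost executor may never report back,
  // so their state is cleaned up eagerly once the executor's entry is removed.
  def executorLost(execId: String): Unit = {
    executorIdToRunningTaskIds.remove(execId).foreach(_.foreach(taskFinished))
  }

  def isExecutorBusy(execId: String): Boolean =
    executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    registerExecutor("executor0")
    taskLaunched(0L, "executor0")
    assert(runningTasksByExecutors() == Map("executor0" -> 1))
    assert(isExecutorBusy("executor0"))

    executorLost("executor0")
    assert(taskIdToExecutorId.isEmpty)
    assert(runningTasksByExecutors().get("executor0").isEmpty)
  }
}

The sketch follows the same ordering as the patch: the executor's entry is removed from executorIdToRunningTaskIds first, and only then is each of its task IDs cleaned up, which is why the per-task cleanup can safely no-op on the already-removed set.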