Skip to content

Commit

Permalink
[SPARK-8167] Make tasks that fail from YARN preemption not fail job
Browse files Browse the repository at this point in the history
The architecture is that, in YARN mode, if the driver detects that an executor has disconnected, it asks the ApplicationMaster why the executor died. If the ApplicationMaster is aware that the executor died because of preemption, all tasks associated with that executor are not marked as failed. The executor
is still removed from the driver's list of available executors, however.

There's a few open questions:
1. Should standalone mode have a similar "get executor loss reason" as well? I localized this change as much as possible to affect only YARN, but there could be a valid case to differentiate executor losses in standalone mode as well.
2. I make a pretty strong assumption in YarnAllocator that getExecutorLossReason(executorId) will only be called once per executor id; I do this so that I can remove the metadata from the in-memory map to avoid object accumulation. It's not clear if I'm being overly zealous to save space, however.

cc vanzin specifically for review because it collided with some earlier YARN scheduling work.
cc JoshRosen because it's similar to output commit coordination we did in the past
cc andrewor14 for our discussion on how to get executor exit codes and loss reasons

Author: mcheah <[email protected]>

Closes apache#8007 from mccheah/feature/preemption-handling.
  • Loading branch information
mccheah authored and Andrew Or committed Sep 10, 2015
1 parent a76bde9 commit af3bc59
Show file tree
Hide file tree
Showing 17 changed files with 261 additions and 79 deletions.
18 changes: 16 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ case object Success extends TaskEndReason
sealed trait TaskFailedReason extends TaskEndReason {
/** Error message displayed in the web UI. */
def toErrorString: String

def shouldEventuallyFailJob: Boolean = true
}

/**
Expand Down Expand Up @@ -194,6 +196,12 @@ case object TaskKilled extends TaskFailedReason {
case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
/**
* If a task failed because its attempt to commit was denied, do not count this failure
* towards failing the stage. This is intended to prevent spurious stage failures in cases
* where many speculative tasks are launched and denied to commit.
*/
override def shouldEventuallyFailJob: Boolean = false
}

/**
Expand All @@ -202,8 +210,14 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend
* the task crashed the JVM.
*/
@DeveloperApi
case class ExecutorLostFailure(execId: String) extends TaskFailedReason {
override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)"
case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false)
extends TaskFailedReason {
override def toErrorString: String = {
val exitBehavior = if (isNormalExit) "normally" else "abnormally"
s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})"
}

override def shouldEventuallyFailJob: Boolean = !isNormalExit
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ import org.apache.spark.executor.ExecutorExitCode
* Represents an explanation for a executor or whole slave failing or exiting.
*/
private[spark]
class ExecutorLossReason(val message: String) {
class ExecutorLossReason(val message: String) extends Serializable {
override def toString: String = message
}

private[spark]
case class ExecutorExited(val exitCode: Int)
extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String)
extends ExecutorLossReason(reason)

private[spark] object ExecutorExited {
def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = {
ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode))
}
}

private[spark]
case class SlaveLost(_message: String = "Slave lost")
extends ExecutorLossReason(_message) {
}
extends ExecutorLossReason(_message)
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Pool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ private[spark] class Pool(
null
}

override def executorLost(executorId: String, host: String) {
schedulableQueue.asScala.foreach(_.executorLost(executorId, host))
override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) {
schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason))
}

override def checkSpeculatableTasks(): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[spark] trait Schedulable {
def addSchedulable(schedulable: Schedulable): Unit
def removeSchedulable(schedulable: Schedulable): Unit
def getSchedulableByName(name: String): Schedulable
def executorLost(executorId: String, host: String): Unit
def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit
def checkSpeculatableTasks(): Boolean
def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ private[spark] class TaskSchedulerImpl(
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)
if (activeExecutorIds.contains(execId)) {
removeExecutor(execId)
removeExecutor(execId,
SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
failedExecutor = Some(execId)
}
}
Expand Down Expand Up @@ -464,7 +465,7 @@ private[spark] class TaskSchedulerImpl(
if (activeExecutorIds.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
removeExecutor(executorId, reason)
failedExecutor = Some(executorId)
} else {
// We may get multiple executorLost() calls with different loss reasons. For example, one
Expand All @@ -482,7 +483,7 @@ private[spark] class TaskSchedulerImpl(
}

/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
Expand All @@ -497,7 +498,7 @@ private[spark] class TaskSchedulerImpl(
}
}
executorIdToHost -= executorId
rootPool.executorLost(executorId, host)
rootPool.executorLost(executorId, host, reason)
}

def executorAdded(execId: String, host: String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,11 @@ private[spark] class TaskSetManager(
}
ef.exception

case e: ExecutorLostFailure if e.isNormalExit =>
logInfo(s"Task $tid failed because while it was being computed, its executor" +
s" exited normally. Not marking the task as failed.")
None

case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
logWarning(failureReason)
None
Expand All @@ -722,10 +727,9 @@ private[spark] class TaskSetManager(
put(info.executorId, clock.getTimeMillis())
sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
addPendingTask(index)
if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) {
// If a task failed because its attempt to commit was denied, do not count this failure
// towards failing the stage. This is intended to prevent spurious stage failures in cases
// where many speculative tasks are launched and denied to commit.
if (!isZombie && state != TaskState.KILLED
&& reason.isInstanceOf[TaskFailedReason]
&& reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) {
assert (null != failureReason)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
Expand Down Expand Up @@ -778,7 +782,7 @@ private[spark] class TaskSetManager(
}

/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
override def executorLost(execId: String, host: String) {
override def executorLost(execId: String, host: String, reason: ExecutorLossReason) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)

// Re-enqueue pending tasks for this host based on the status of the cluster. Note
Expand Down Expand Up @@ -809,9 +813,12 @@ private[spark] class TaskSetManager(
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId))
val isNormalExit: Boolean = reason match {
case exited: ExecutorExited => exited.isNormalExit
case _ => false
}
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import org.apache.spark.TaskState.TaskState
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorLossReason
import org.apache.spark.util.{SerializableBuffer, Utils}

private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
Expand Down Expand Up @@ -70,7 +71,8 @@ private[spark] object CoarseGrainedClusterMessages {

case object StopExecutors extends CoarseGrainedClusterMessage

case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
case class RemoveExecutor(executorId: String, reason: ExecutorLossReason)
extends CoarseGrainedClusterMessage

case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage

Expand All @@ -92,6 +94,10 @@ private[spark] object CoarseGrainedClusterMessages {
hostToLocalTaskCount: Map[String, Int])
extends CoarseGrainedClusterMessage

// Check if an executor was force-killed but for a normal reason.
// This could be the case if the executor is preempted, for instance.
case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage

case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.rpc._
import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME
import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils}

/**
Expand Down Expand Up @@ -82,7 +83,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp

override protected def log = CoarseGrainedSchedulerBackend.this.log

private val addressToExecutorId = new HashMap[RpcAddress, String]
protected val addressToExecutorId = new HashMap[RpcAddress, String]

private val reviveThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread")
Expand Down Expand Up @@ -128,6 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorDataMap.contains(executorId)) {
Expand Down Expand Up @@ -185,8 +187,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}

override def onDisconnected(remoteAddress: RpcAddress): Unit = {
addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_,
"remote Rpc client disassociated"))
addressToExecutorId
.get(remoteAddress)
.foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated")))
}

// Make fake resource offers on just one executor
Expand Down Expand Up @@ -227,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}

// Remove a disconnected slave from the cluster
def removeExecutor(executorId: String, reason: String): Unit = {
def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
// This must be synchronized because variables mutated
Expand All @@ -239,9 +242,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
totalCoreCount.addAndGet(-executorInfo.totalCores)
totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
scheduler.executorLost(executorId, reason)
listenerBus.post(
SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason))
SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString))
case None => logInfo(s"Asked to remove non-existent executor $executorId")
}
}
Expand All @@ -263,8 +266,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}

// TODO (prashant) send conf instead of properties
driverEndpoint = rpcEnv.setupEndpoint(
CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties))
driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
}

protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
new DriverEndpoint(rpcEnv, properties)
}

def stopExecutors() {
Expand Down Expand Up @@ -304,7 +310,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}

// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
def removeExecutor(executorId: String, reason: ExecutorLossReason) {
try {
driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason))
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils

private[spark] class SparkDeploySchedulerBackend(
Expand Down Expand Up @@ -135,11 +135,11 @@ private[spark] class SparkDeploySchedulerBackend(

override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) {
val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code)
case Some(code) => ExecutorExited(code, isNormalExit = true, message)
case None => SlaveLost(message)
}
logInfo("Executor %s removed: %s".format(fullId, message))
removeExecutor(fullId.split("/")(1), reason.toString)
removeExecutor(fullId.split("/")(1), reason)
}

override def sufficientResourcesRegistered(): Boolean = {
Expand Down
Loading

0 comments on commit af3bc59

Please sign in to comment.