Skip to content

Commit

Permalink
[SPARK-11256] Mark all Stage/ResultStage/ShuffleMapStage internal sta…
Browse files Browse the repository at this point in the history
…te as private.

Author: Reynold Xin <[email protected]>

Closes apache#9219 from rxin/stage-cleanup1.
  • Loading branch information
rxin committed Nov 3, 2015
1 parent d188a67 commit 57446eb
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 38 deletions.
33 changes: 16 additions & 17 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.Map
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.collection.mutable.{HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
Expand Down Expand Up @@ -535,10 +535,8 @@ class DAGScheduler(
jobIdToActiveJob -= job.jobId
activeJobs -= job
job.finalStage match {
case r: ResultStage =>
r.resultOfJob = None
case m: ShuffleMapStage =>
m.mapStageJobs = m.mapStageJobs.filter(_ != job)
case r: ResultStage => r.removeActiveJob()
case m: ShuffleMapStage => m.removeActiveJob(job)
}
}

Expand Down Expand Up @@ -848,7 +846,7 @@ class DAGScheduler(
val jobSubmissionTime = clock.getTimeMillis()
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.resultOfJob = Some(job)
finalStage.setActiveJob(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
listenerBus.post(
Expand Down Expand Up @@ -880,15 +878,15 @@ class DAGScheduler(
val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
clearCacheLocs()
logInfo("Got map stage job %s (%s) with %d output partitions".format(
jobId, callSite.shortForm, dependency.rdd.partitions.size))
jobId, callSite.shortForm, dependency.rdd.partitions.length))
logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))

val jobSubmissionTime = clock.getTimeMillis()
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.mapStageJobs = job :: finalStage.mapStageJobs
finalStage.addActiveJob(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
listenerBus.post(
Expand Down Expand Up @@ -950,12 +948,12 @@ class DAGScheduler(
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
outputCommitCoordinator.stageStart(stage.id)
val taskIdToLocations = try {
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage =>
partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage =>
val job = s.resultOfJob.get
val job = s.activeJob.get
partitionsToCompute.map { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
Expand Down Expand Up @@ -1016,7 +1014,7 @@ class DAGScheduler(
}

case stage: ResultStage =>
val job = stage.resultOfJob.get
val job = stage.activeJob.get
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
val part = stage.rdd.partitions(p)
Expand Down Expand Up @@ -1132,7 +1130,7 @@ class DAGScheduler(
// Cast to ResultStage here because it's part of the ResultTask
// TODO Refactor this out to a function that accepts a ResultStage
val resultStage = stage.asInstanceOf[ResultStage]
resultStage.resultOfJob match {
resultStage.activeJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
updateAccumulators(event)
Expand Down Expand Up @@ -1187,7 +1185,7 @@ class DAGScheduler(
// we registered these map outputs.
mapOutputTracker.registerMapOutputs(
shuffleStage.shuffleDep.shuffleId,
shuffleStage.outputLocs.map(_.headOption.orNull),
shuffleStage.outputLocInMapOutputTrackerFormat(),
changeEpoch = true)

clearCacheLocs()
Expand All @@ -1197,8 +1195,7 @@ class DAGScheduler(
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
") because some of its tasks had failed: " +
shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty)
.map(_._2).mkString(", "))
shuffleStage.findMissingPartitions().mkString(", "))
submitStage(shuffleStage)
} else {
// Mark any map-stage jobs waiting on this stage as finished
Expand Down Expand Up @@ -1312,8 +1309,10 @@ class DAGScheduler(
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
val locs = stage.outputLocs.map(_.headOption.orNull)
mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
mapOutputTracker.registerMapOutputs(
shuffleId,
stage.outputLocInMapOutputTrackerFormat(),
changeEpoch = true)
}
if (shuffleToMapStage.isEmpty) {
mapOutputTracker.incrementEpoch()
Expand Down
19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,25 @@ private[spark] class ResultStage(
* The active job for this result stage. Will be empty if the job has already finished
* (e.g., because the job was cancelled).
*/
var resultOfJob: Option[ActiveJob] = None
private[this] var _activeJob: Option[ActiveJob] = None

def activeJob: Option[ActiveJob] = _activeJob

/**
 * Records `job` as the active job of this result stage.
 *
 * A null `job` is stored as `None`; any other value is stored as `Some(job)`.
 */
def setActiveJob(job: ActiveJob): Unit = {
_activeJob = if (job == null) None else Some(job)
}

/** Clears the active job of this result stage, so that `activeJob` returns `None`. */
def removeActiveJob(): Unit = {
_activeJob = Option.empty[ActiveJob]
}

/**
* Returns the sequence of partition ids that are missing (i.e. need to be computed).
*
* This can only be called when there is an active job.
*/
override def findMissingPartitions(): Seq[Int] = {
val job = resultOfJob.get
val job = activeJob.get
(0 until job.numPartitions).filter(id => !job.finished(id))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,61 @@ private[spark] class ShuffleMapStage(
val shuffleDep: ShuffleDependency[_, _, _])
extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {

private[this] var _mapStageJobs: List[ActiveJob] = Nil

private[this] var _numAvailableOutputs: Int = 0

/**
* List of [[MapStatus]] for each partition. The index of the array is the map partition id,
* and each value in the array is the list of possible [[MapStatus]] for a partition
* (a single task might run multiple times).
*/
private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)

override def toString: String = "ShuffleMapStage " + id

/** Running map-stage jobs that were submitted to execute this stage independently (if any) */
var mapStageJobs: List[ActiveJob] = Nil
/**
* Returns the list of active jobs,
* i.e. map-stage jobs that were submitted to execute this stage independently (if any).
*/
def mapStageJobs: Seq[ActiveJob] = _mapStageJobs

/** Adds the job to the active job list. */
def addActiveJob(job: ActiveJob): Unit = {
// Prepend, so the most recently added job is at the head of the list.
_mapStageJobs ::= job
}

/** Removes the job from the active job list. */
def removeActiveJob(job: ActiveJob): Unit = {
// Drop every entry equal to `job`; the remaining jobs keep their relative order.
_mapStageJobs = _mapStageJobs.filterNot(_ == job)
}

/**
* Number of partitions that have shuffle outputs.
* When this reaches [[numPartitions]], this map stage is ready.
* This should be kept consistent with `outputLocs.filter(!_.isEmpty).size`.
*/
var numAvailableOutputs: Int = 0
def numAvailableOutputs: Int = _numAvailableOutputs

/**
* Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
* This should be the same as `!outputLocs.contains(Nil)`.
*/
def isAvailable: Boolean = numAvailableOutputs == numPartitions

/**
* List of [[MapStatus]] for each partition. The index of the array is the map partition id,
* and each value in the array is the list of possible [[MapStatus]] for a partition
* (a single task might run multiple times).
*/
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
def isAvailable: Boolean = _numAvailableOutputs == numPartitions

/** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
override def findMissingPartitions(): Seq[Int] = {
val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
assert(missing.size == numPartitions - numAvailableOutputs,
s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}")
assert(missing.size == numPartitions - _numAvailableOutputs,
s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
missing
}

def addOutputLoc(partition: Int, status: MapStatus): Unit = {
val prevList = outputLocs(partition)
outputLocs(partition) = status :: prevList
if (prevList == Nil) {
numAvailableOutputs += 1
_numAvailableOutputs += 1
}
}

Expand All @@ -88,10 +106,19 @@ private[spark] class ShuffleMapStage(
val newList = prevList.filterNot(_.location == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
_numAvailableOutputs -= 1
}
}

/**
* Returns an array of [[MapStatus]] (indexed by partition id). For each partition, the returned
* value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
* that position is filled with null.
*/
def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
outputLocs.map { statuses =>
// Take the first status recorded for the partition, or null if there is none.
if (statuses.isEmpty) null else statuses.head
}
}

/**
* Removes all shuffle outputs associated with this executor. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists), as they are still
Expand All @@ -105,12 +132,12 @@ private[spark] class ShuffleMapStage(
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
numAvailableOutputs -= 1
_numAvailableOutputs -= 1
}
}
if (becameUnavailable) {
logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
this, execId, numAvailableOutputs, numPartitions, isAvailable))
this, execId, _numAvailableOutputs, numPartitions, isAvailable))
}
}
}
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ private[scheduler] abstract class Stage(
/** The ID to use for the next new attempt for this stage. */
private var nextAttemptId: Int = 0

val name = callSite.shortForm
val details = callSite.longForm
val name: String = callSite.shortForm
val details: String = callSite.longForm

private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty

Expand Down Expand Up @@ -134,6 +134,7 @@ private[scheduler] abstract class Stage(
def latestInfo: StageInfo = _latestInfo

override final def hashCode(): Int = id

override final def equals(other: Any): Boolean = other match {
case stage: Stage => stage != null && stage.id == id
case _ => false
Expand Down

0 comments on commit 57446eb

Please sign in to comment.