[SPARK-6394][Core] cleanup BlockManager companion object and improve the getCacheLocs method in DAGScheduler

The current implementation searches a HashMap many times; we can avoid this.
If you look into `BlockManager.blockIdsToBlockManagers`, the core function call is [this](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/BlockManager.scala#L1258), so we can call `blockManagerMaster.getLocations` directly and avoid building an intermediate HashMap.
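
To see the shape of the change before reading the diff, here is a minimal, self-contained sketch. `BlockId`, `BlockManagerId`, `TaskLocation`, and `getLocations` below are simplified stand-ins, not the real Spark types: the point is only that `getLocations` already answers positionally (one location list per requested id, in request order), so the intermediate `Map[BlockId, ...]` re-derives what the input order already encodes.

```scala
object GetCacheLocsSketch extends App {
  // Simplified stand-ins for the real Spark types.
  case class BlockId(rddId: Int, partition: Int)
  case class BlockManagerId(host: String, executorId: String)
  case class TaskLocation(host: String, executorId: String)

  // Stand-in for BlockManagerMaster.getLocations: one Seq[BlockManagerId]
  // per requested block id, in the same order as the request.
  def getLocations(blockIds: Seq[BlockId]): Seq[Seq[BlockManagerId]] =
    blockIds.map(id => Seq(BlockManagerId(s"host${id.partition}", s"exec${id.partition}")))

  val blockIds = (0 until 3).map(p => BlockId(rddId = 1, partition = p))

  // Before: wrap the positional answer in a Map, then immediately look
  // every block id back up -- one hash lookup per partition.
  val byId: Map[BlockId, Seq[BlockManagerId]] =
    blockIds.zip(getLocations(blockIds)).toMap
  val before = blockIds.map { id =>
    byId.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
  }

  // After: the answer is already aligned with the request, so map it directly.
  val after = getLocations(blockIds).map(_.map(bm => TaskLocation(bm.host, bm.executorId)))

  assert(before == after)
}
```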

Author: Wenchen Fan <[email protected]>

Closes apache#5043 from cloud-fan/small and squashes the following commits:

e959d12 [Wenchen Fan] fix style
203c493 [Wenchen Fan] some cleanup in BlockManager companion object
d409099 [Wenchen Fan] address rxin's comment
faec999 [Wenchen Fan] add regression test
2fb57aa [Wenchen Fan] improve the getCacheLocs method
cloud-fan authored and JoshRosen committed Mar 19, 2015
1 parent 3db1387 commit 540b2a4
Showing 3 changed files with 22 additions and 23 deletions.
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

```diff
@@ -104,7 +104,7 @@ class DAGScheduler(
    *
    * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
    */
-  private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
+  private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
   // every task. When we detect a node failing, we note the current epoch number and failed
@@ -188,14 +188,15 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+  private[scheduler]
+  def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
     // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
     if (!cacheLocs.contains(rdd.id)) {
       val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
-      val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
-        locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
+      val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
+        bms.map(bm => TaskLocation(bm.host, bm.executorId))
       }
+      cacheLocs(rdd.id) = locs
     }
     cacheLocs(rdd.id)
   }
```
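Two details are worth noting in the new `getCacheLocs` above: it is widened from `private` to `private[scheduler]` so the new regression test can call it, and it still memoizes one result per RDD id under `cacheLocs.synchronized`. Here is a standalone sketch of that check-then-insert memoization pattern; `LocationCache` and `compute` are illustrative names, not Spark API:

```scala
import scala.collection.mutable.HashMap

// Check-then-insert memoization under a single lock, the same shape as
// DAGScheduler.getCacheLocs: synchronize, check for a cached value,
// compute and store on a miss, then return the cached entry.
class LocationCache[K, V](compute: K => V) {
  private val cache = new HashMap[K, V]

  def get(key: K): V = cache.synchronized {
    if (!cache.contains(key)) {
      cache(key) = compute(key)
    }
    cache(key)
  }
}

// Example: cache host lists per RDD id.
object LocationCacheDemo extends App {
  val cache = new LocationCache[Int, Seq[String]](id => Seq(s"host$id"))
  assert(cache.get(1) eq cache.get(1)) // second call returns the memoized value
}
```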
core/src/main/scala/org/apache/spark/storage/BlockManager.scala (22 changes: 4 additions & 18 deletions)

```diff
@@ -1245,10 +1245,10 @@ private[spark] object BlockManager extends Logging {
     }
   }
 
-  def blockIdsToBlockManagers(
+  def blockIdsToHosts(
       blockIds: Array[BlockId],
       env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = {
+      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
 
     // blockManagerMaster != null is used in tests
     assert(env != null || blockManagerMaster != null)
@@ -1258,24 +1258,10 @@ private[spark] object BlockManager extends Logging {
       blockManagerMaster.getLocations(blockIds)
     }
 
-    val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
+    val blockManagers = new HashMap[BlockId, Seq[String]]
     for (i <- 0 until blockIds.length) {
-      blockManagers(blockIds(i)) = blockLocations(i)
+      blockManagers(blockIds(i)) = blockLocations(i).map(_.host)
     }
     blockManagers.toMap
   }
-
-  def blockIdsToExecutorIds(
-      blockIds: Array[BlockId],
-      env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
-    blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
-  }
-
-  def blockIdsToHosts(
-      blockIds: Array[BlockId],
-      env: SparkEnv,
-      blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
-    blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
-  }
 }
```
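With `blockIdsToBlockManagers` and `blockIdsToExecutorIds` gone, `blockIdsToHosts` is the only remaining helper, and it now builds the host map in one pass. Any caller that still needs executor ids can derive them from the positional `getLocations` answer directly; a sketch of that derivation with stub types (none of these are the real Spark classes):

```scala
object ExecutorIdsSketch extends App {
  case class BlockId(name: String)
  case class BlockManagerId(host: String, executorId: String)

  // Stand-in for BlockManagerMaster.getLocations.
  def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] =
    blockIds.map(id => Seq(BlockManagerId("hostA", s"exec-${id.name}"))).toSeq

  val blockIds = Array(BlockId("rdd_1_0"), BlockId("rdd_1_1"))

  // One-pass equivalent of the removed blockIdsToExecutorIds: zip the
  // request with the positional answer instead of building an
  // intermediate Map[BlockId, Seq[BlockManagerId]] first.
  val executorsByBlock: Map[BlockId, Seq[String]] =
    blockIds.zip(getLocations(blockIds)).map { case (id, bms) =>
      id -> bms.map(_.executorId)
    }.toMap

  println(executorsByBlock)
}
```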
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

```diff
@@ -322,6 +322,18 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     assertDataStructuresEmpty
   }
 
+  test("regression test for getCacheLocs") {
+    val rdd = new MyRDD(sc, 3, Nil)
+    cacheLocations(rdd.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    cacheLocations(rdd.id -> 1) =
+      Seq(makeBlockManagerId("hostB"), makeBlockManagerId("hostC"))
+    cacheLocations(rdd.id -> 2) =
+      Seq(makeBlockManagerId("hostC"), makeBlockManagerId("hostD"))
+    val locs = scheduler.getCacheLocs(rdd).map(_.map(_.host))
+    assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD")))
+  }
+
   test("avoid exponential blowup when getting preferred locs list") {
     // Build up a complex dependency graph with repeated zip operations, without preferred locations.
     var rdd: RDD[_] = new MyRDD(sc, 1, Nil)
```
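The regression test above works because DAGSchedulerSuite stubs out `BlockManagerMaster` so that `getLocations` answers from the suite's `cacheLocations` map. A rough sketch of such a positional stub, with a simplified `(rddId, partition)` key; the suite's actual wiring is assumed here, not quoted:

```scala
import scala.collection.mutable.HashMap

object StubMasterSketch extends App {
  case class BlockManagerId(host: String, executorId: String)

  // Keyed by (rddId, partition), like the suite's cacheLocations map.
  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]

  // Positional stub: one answer per requested block, in request order,
  // Nil for blocks with no registered locations -- exactly the contract
  // that getCacheLocs's direct mapping relies on.
  def getLocations(blockIds: Seq[(Int, Int)]): Seq[Seq[BlockManagerId]] =
    blockIds.map(id => cacheLocations.getOrElse(id, Nil))

  cacheLocations((1, 0)) = Seq(BlockManagerId("hostA", "exec-1"))
  println(getLocations(Seq((1, 0), (1, 1)))) // List(List(BlockManagerId(hostA,exec-1)), List())
}
```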
