Skip to content

Commit

Permalink
[CELEBORN-1167] Avoid calling parmap when destroy slots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
As title

### Why are the changes needed?
One user reported that LifecycleManager's parmap can create a huge number of threads and cause OOM.

![image](https://github.com/apache/incubator-celeborn/assets/948245/1e9a0b83-32fe-40d5-8739-2b370e030fc8)

There are four places where parmap is called:

1. When LifecycleManager commits files
2. When LifecycleManager reserves slots
3. When LifecycleManager sets up connections to workers
4. When LifecycleManager calls destroy slots

This PR fixes the fourth one. In more detail, this PR eliminates `parmap` when destroying slots, and also replaces `askSync` with `ask`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Manual test and GA.

Closes apache#2156 from waitinfuture/1167.

Lead-authored-by: zky.zhoukeyong <[email protected]>
Co-authored-by: cxzl25 <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
  • Loading branch information
3 people committed Dec 15, 2023
1 parent 41df4eb commit 01feb93
Show file tree
Hide file tree
Showing 13 changed files with 371 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
private val rpcMaxRetires = conf.clientRpcMaxRetries

private val excludedWorkersFilter = conf.registerShuffleFilterExcludedWorkerEnabled

Expand All @@ -106,6 +107,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]

private val mockDestroyFailure = conf.testMockDestroySlotsFailure

@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)
Expand Down Expand Up @@ -1293,6 +1296,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
slots
}

/**
 * Mutable bookkeeping record for one in-flight DestroyWorkerSlots request.
 *
 * @param future     response future of the current attempt; reassigned each time the request is
 *                   re-sent
 * @param message    the DestroyWorkerSlots request that is (re)sent to the worker
 * @param endpoint   RPC endpoint the request is sent to
 * @param retryTimes number of the current attempt; callers in this file start it at 1 and
 *                   increment it on every retry
 * @param startTime  wall-clock time in milliseconds at which the current attempt was issued;
 *                   compared against the ask timeout to detect a timed-out attempt
 */
case class DestroyFutureWithStatus(
var future: Future[DestroyWorkerSlotsResponse],
message: DestroyWorkerSlots,
endpoint: RpcEndpointRef,
var retryTimes: Int,
var startTime: Long)

/**
* For the slots that need to be destroyed, LifecycleManager will ask the corresponding worker
* to destroy related FileWriter.
Expand All @@ -1305,24 +1315,91 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId: Int,
slotsToDestroy: WorkerResource): Unit = {
val shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId)
val parallelism = Math.min(Math.max(1, slotsToDestroy.size()), conf.clientRpcMaxParallelism)
ThreadUtils.parmap(
slotsToDestroy.asScala,
"DestroySlot",
parallelism) { case (workerInfo, (primaryLocations, replicaLocations)) =>
val destroy = DestroyWorkerSlots(
shuffleKey,
primaryLocations.asScala.map(_.getUniqueId).asJava,
replicaLocations.asScala.map(_.getUniqueId).asJava)
var res = requestWorkerDestroySlots(workerInfo.endpoint, destroy)
if (res.status != StatusCode.SUCCESS) {
logDebug(s"Request $destroy return ${res.status} for $shuffleKey, " +
s"will retry request destroy.")
res = requestWorkerDestroySlots(
workerInfo.endpoint,
DestroyWorkerSlots(shuffleKey, res.failedPrimarys, res.failedReplicas))

/**
 * Re-sends a destroy request and records the new attempt in `status`.
 *
 * Bumps the attempt counter, restarts the per-attempt clock at `currentTime`, and replaces the
 * tracked future with the one returned by the fresh `ask`.
 */
def retryDestroy(status: DestroyFutureWithStatus, currentTime: Long): Unit = {
  status.retryTimes = status.retryTimes + 1
  status.startTime = currentTime

  // Keep mocking a failure only while this is not the final allowed attempt, so that a
  // mocked request can succeed on its last retry.
  val request = status.message
  val isFinalAttempt = status.retryTimes == rpcMaxRetires
  request.mockFailure = request.mockFailure && !isFinalAttempt

  status.future = status.endpoint.ask[DestroyWorkerSlotsResponse](request)
}

// Fire one asynchronous DestroyWorkerSlots request per worker without blocking, and track each
// request so it can be polled and retried afterwards. All first attempts share the same
// start timestamp.
val startTime = System.currentTimeMillis()
val futures = new util.LinkedList[DestroyFutureWithStatus]()
slotsToDestroy.asScala.foreach { case (workerInfo, (primaryLocations, replicaLocations)) =>
// Only the unique ids of the locations are sent to the worker.
val primaryIds = primaryLocations.asScala.map(_.getUniqueId).asJava
val replicaIds = replicaLocations.asScala.map(_.getUniqueId).asJava
val destroy = DestroyWorkerSlots(shuffleKey, primaryIds, replicaIds, mockDestroyFailure)
val future = workerInfo.endpoint.ask[DestroyWorkerSlotsResponse](destroy)
// The attempt counter starts at 1 for the initial request.
futures.add(DestroyFutureWithStatus(future, destroy, workerInfo.endpoint, 1, startTime))
}

// Per-attempt timeout in milliseconds: an attempt still pending after this long is treated as
// timed out and is either retried or dropped.
val timeout = conf.rpcAskTimeout.duration.toMillis
// Overall polling budget: one full timeout per allowed attempt. Only the sleep intervals below
// are charged against this budget.
var remainingTime = timeout * rpcMaxRetires
// Polling interval in milliseconds.
val delta = 50
while (remainingTime > 0 && !futures.isEmpty) {
  val currentTime = System.currentTimeMillis()
  val iter = futures.iterator()
  while (iter.hasNext) {
    val futureWithStatus = iter.next()
    val message = futureWithStatus.message
    val retryTimes = futureWithStatus.retryTimes
    if (futureWithStatus.future.isCompleted) {
      // The future is completed, so `value` is guaranteed to be defined here.
      futureWithStatus.future.value.get match {
        case scala.util.Success(res) =>
          if (res.status != StatusCode.SUCCESS && retryTimes < rpcMaxRetires) {
            logError(
              s"Request $message to ${futureWithStatus.endpoint} return ${res.status} for $shuffleKey $retryTimes/$rpcMaxRetires, " +
                "will retry.")
            retryDestroy(futureWithStatus, currentTime)
          } else {
            // Either the worker reported success or the retries are exhausted; in both cases
            // the request is finished and is removed from the tracking list.
            if (res.status != StatusCode.SUCCESS && retryTimes == rpcMaxRetires) {
              logError(
                s"Request $message to ${futureWithStatus.endpoint} return ${res.status} for $shuffleKey $retryTimes/$rpcMaxRetires, " +
                  "will not retry.")
            }
            iter.remove()
          }
        case scala.util.Failure(e) =>
          if (retryTimes < rpcMaxRetires) {
            logError(
              s"Request $message to ${futureWithStatus.endpoint} failed $retryTimes/$rpcMaxRetires for $shuffleKey, reason: $e, " +
                "will retry.")
            retryDestroy(futureWithStatus, currentTime)
          } else {
            if (retryTimes == rpcMaxRetires) {
              logError(
                s"Request $message to ${futureWithStatus.endpoint} failed $retryTimes/$rpcMaxRetires for $shuffleKey, reason: $e, " +
                  "will not retry.")
            }
            iter.remove()
          }
      }
    } else if (currentTime - futureWithStatus.startTime > timeout) {
      // The current attempt has been pending longer than the per-attempt timeout.
      if (retryTimes < rpcMaxRetires) {
        logError(
          s"Request $message to ${futureWithStatus.endpoint} failed $retryTimes/$rpcMaxRetires for $shuffleKey, reason: Timeout, " +
            "will retry.")
        retryDestroy(futureWithStatus, currentTime)
      } else {
        if (retryTimes == rpcMaxRetires) {
          // Retries are exhausted and the request is dropped below, so the message must say
          // "will not retry" like the other give-up branches.
          logError(
            s"Request $message to ${futureWithStatus.endpoint} failed $retryTimes/$rpcMaxRetires for $shuffleKey, reason: Timeout, " +
              "will not retry.")
        }
        iter.remove()
      }
    }
  }

  // Sleep only when there is still something to wait for, so the loop exits promptly once
  // every request has been resolved.
  if (!futures.isEmpty) {
    Thread.sleep(delta)
    remainingTime -= delta
  }
}
// Requests still pending once the polling budget is used up are abandoned.
futures.clear()
}

private def removeExpiredShuffle(): Unit = {
Expand Down Expand Up @@ -1403,23 +1480,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
}

/**
 * Synchronously asks a worker to destroy the slots named in `message`.
 *
 * Any exception raised by the call is logged and converted into a REQUEST_FAILED response that
 * echoes back the primary and replica locations of the request.
 */
private def requestWorkerDestroySlots(
    endpoint: RpcEndpointRef,
    message: DestroyWorkerSlots): DestroyWorkerSlotsResponse =
  try endpoint.askSync[DestroyWorkerSlotsResponse](message)
  catch {
    case cause: Exception =>
      logError(
        s"AskSync worker(${endpoint.address}) Destroy for ${message.shuffleKey} failed.",
        cause)
      DestroyWorkerSlotsResponse(
        StatusCode.REQUEST_FAILED,
        message.primaryLocations,
        message.replicaLocations)
  }

private def requestMasterUnregisterShuffle(message: PbUnregisterShuffle)
: PbUnregisterShuffleResponse = {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ case class CommitFilesParam(
primaryIds: util.List[String],
replicaIds: util.List[String])

case class FutureWithStatus(
case class CommitFutureWithStatus(
var future: Future[CommitFilesResponse],
commitFilesParam: CommitFilesParam,
var retriedTimes: Int)
var retriedTimes: Int,
var startTime: Long)

case class CommitResult(
primaryPartitionLocationMap: ConcurrentHashMap[String, PartitionLocation],
Expand Down Expand Up @@ -204,6 +205,26 @@ abstract class CommitHandler(
params: ArrayBuffer[CommitFilesParam],
commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {

/**
 * Re-issues a CommitFiles request for the worker tracked by `status`.
 *
 * Increments the attempt counter, restarts the per-attempt clock at `currentTime`, and replaces
 * the tracked future with the one returned by the new `commitFiles` call.
 */
def retryCommitFiles(status: CommitFutureWithStatus, currentTime: Long): Unit = {
  status.retriedTimes += 1
  status.startTime = currentTime
  val param = status.commitFilesParam
  status.future =
    commitFiles(appUniqueId, shuffleId, param.worker, param.primaryIds, param.replicaIds)
}

/**
 * Builds a REQUEST_FAILED CommitFilesResponse for the request tracked by `status`: the first
 * two id lists are empty and the last two carry the primary and replica ids that were
 * requested.
 */
def createFailResponse(status: CommitFutureWithStatus): CommitFilesResponse = {
  val param = status.commitFilesParam
  CommitFilesResponse(
    StatusCode.REQUEST_FAILED,
    List.empty.asJava,
    List.empty.asJava,
    param.primaryIds,
    param.replicaIds)
}

def processResponse(res: CommitFilesResponse, worker: WorkerInfo): Unit = {
shuffleCommittedInfo.synchronized {
// record committed partitionIds
Expand Down Expand Up @@ -241,8 +262,9 @@ abstract class CommitHandler(
}
}

val futures = new LinkedBlockingQueue[FutureWithStatus]()
val futures = new LinkedBlockingQueue[CommitFutureWithStatus]()

val startTime = System.currentTimeMillis()
val outFutures = params.filter(param =>
!CollectionUtils.isEmpty(param.primaryIds) ||
!CollectionUtils.isEmpty(param.replicaIds)) map { param =>
Expand All @@ -254,7 +276,7 @@ abstract class CommitHandler(
param.primaryIds,
param.replicaIds)

futures.add(FutureWithStatus(future, param, 1))
futures.add(CommitFutureWithStatus(future, param, 1, startTime))
}(ec)
}
val cbf =
Expand All @@ -264,9 +286,11 @@ abstract class CommitHandler(
awaitResult(futureSeq, Duration.Inf)

val maxRetries = conf.clientRequestCommitFilesMaxRetries
var timeout = conf.rpcAskTimeout.duration.toMillis * maxRetries
val timeout = conf.rpcAskTimeout.duration.toMillis
var remainingTime = timeout * maxRetries
val delta = 50
while (timeout >= 0 && !futures.isEmpty) {
while (remainingTime >= 0 && !futures.isEmpty) {
val currentTime = System.currentTimeMillis()
val iter = futures.iterator()
while (iter.hasNext) {
val status = iter.next()
Expand Down Expand Up @@ -295,44 +319,34 @@ abstract class CommitHandler(
s" (attempt ${status.retriedTimes}/$maxRetries).",
e)
if (status.retriedTimes < maxRetries) {
status.retriedTimes = status.retriedTimes + 1
status.future = commitFiles(
appUniqueId,
shuffleId,
status.commitFilesParam.worker,
status.commitFilesParam.primaryIds,
status.commitFilesParam.replicaIds)
retryCommitFiles(status, currentTime)
} else {
val res = CommitFilesResponse(
StatusCode.REQUEST_FAILED,
List.empty.asJava,
List.empty.asJava,
status.commitFilesParam.primaryIds,
status.commitFilesParam.replicaIds)
val res = createFailResponse(status)
processResponse(res, status.commitFilesParam.worker)
iter.remove()
}
}
} else if (currentTime - status.startTime > timeout) {
if (status.retriedTimes < maxRetries) {
retryCommitFiles(status, currentTime)
} else {
iter.remove()
}
}
}

if (!futures.isEmpty) {
Thread.sleep(delta)
}
timeout = timeout - delta
remainingTime -= delta
}

val iter = futures.iterator()
while (iter.hasNext) {
val status = iter.next()
logError(
s"Ask worker(${status.commitFilesParam.worker}) CommitFiles for $shuffleId timed out")
val res = CommitFilesResponse(
StatusCode.REQUEST_FAILED,
List.empty.asJava,
List.empty.asJava,
status.commitFilesParam.primaryIds,
status.commitFilesParam.replicaIds)
val res = createFailResponse(status)
processResponse(res, status.commitFilesParam.worker)
iter.remove()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ trait WithShuffleClientSuite extends CelebornFunSuite {

private var lifecycleManager: LifecycleManager = _
private var shuffleClient: ShuffleClientImpl = _
private var shuffleId = 0

// Last shuffle id handed out; 0 means none has been issued yet.
var _shuffleId = 0

/** Advances the shuffle id counter by one and returns the new value (1 on the first call). */
def nextShuffleId: Int = {
  _shuffleId = _shuffleId + 1
  _shuffleId
}

override protected def afterEach() {
if (lifecycleManager != null) {
Expand All @@ -56,7 +61,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {

test("test register map partition task") {
prepareService()
shuffleId = 1
val shuffleId = nextShuffleId
var location =
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId, 1)
Assert.assertEquals(location.getId, 1)
Expand Down Expand Up @@ -95,10 +100,10 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
}

test("test batch release partition") {
shuffleId = 2
val shuffleId = nextShuffleId
celebornConf.set(CelebornConf.CLIENT_BATCH_HANDLE_RELEASE_PARTITION_ENABLED.key, "true")
prepareService()
registerAndFinishPartition()
registerAndFinishPartition(shuffleId)

val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala

Expand All @@ -115,11 +120,11 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
}

test("test release single partition") {
shuffleId = 3
val shuffleId = nextShuffleId
celebornConf.set(CelebornConf.CLIENT_BATCH_HANDLE_RELEASE_PARTITION_ENABLED.key, "false")
celebornConf.set(CelebornConf.CLIENT_BATCH_HANDLED_RELEASE_PARTITION_INTERVAL.key, "1s")
prepareService()
registerAndFinishPartition()
registerAndFinishPartition(shuffleId)

val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala

Expand All @@ -134,9 +139,9 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
}

test("test map end & get reducer file group") {
shuffleId = 4
val shuffleId = nextShuffleId
prepareService()
registerAndFinishPartition()
registerAndFinishPartition(shuffleId)

// reduce file group size (for empty partitions)
Assert.assertEquals(shuffleClient.getReduceFileGroupsMap.size(), 0)
Expand All @@ -161,7 +166,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
}

private def registerAndFinishPartition(): Unit = {
private def registerAndFinishPartition(shuffleId: Int): Unit = {
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId, 1)
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 1, attemptId, 2)
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 2, attemptId, 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ public enum StatusCode {
PUSH_DATA_REPLICA_WORKER_EXCLUDED(45),

FETCH_DATA_TIMEOUT(46),
REVIVE_INITIALIZED(47);
REVIVE_INITIALIZED(47),
DESTROY_SLOTS_MOCK_FAILED(48);

private final byte value;

Expand Down
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ message PbDestroyWorkerSlots {
string shuffleKey = 1;
repeated string primaryLocations = 2;
repeated string replicaLocation = 3;
bool mockFailure = 4;
}

message PbDestroyWorkerSlotsResponse {
Expand Down
Loading

0 comments on commit 01feb93

Please sign in to comment.