Skip to content

Commit

Permalink
Add NVTX ranges to identify Spark stages and tasks (#10826)
Browse files Browse the repository at this point in the history
* Add NVTX ranges to identify Spark stages and tasks

Signed-off-by: Jason Lowe <[email protected]>

* scalastyle

---------

Signed-off-by: Jason Lowe <[email protected]>
  • Loading branch information
jlowe authored May 20, 2024
1 parent 3a807d6 commit 6921dac
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@ import java.lang.reflect.InvocationTargetException
import java.net.URL
import java.time.ZoneId
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.sys.process._
import scala.util.Try

import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner}
import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner, NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars
import com.nvidia.spark.rapids.RapidsPluginUtils.buildInfoEvent
import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg}
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SparkListenerEvent
Expand Down Expand Up @@ -494,6 +495,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
var rapidsShuffleHeartbeatEndpoint: RapidsShuffleHeartbeatEndpoint = null
private lazy val extraExecutorPlugins =
RapidsPluginUtils.extraPlugins.map(_.executorPlugin()).filterNot(_ == null)
private val activeTaskNvtx = new ConcurrentHashMap[Thread, NvtxRange]()

override def init(
pluginContext: PluginContext,
Expand Down Expand Up @@ -684,14 +686,32 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
logDebug(s"Executor onTaskFailed: ${other.toString}")
}
extraExecutorPlugins.foreach(_.onTaskFailed(failureReason))
endTaskNvtx()
}

override def onTaskStart(): Unit = {
startTaskNvtx(TaskContext.get)
extraExecutorPlugins.foreach(_.onTaskStart())
}

override def onTaskSucceeded(): Unit = {
extraExecutorPlugins.foreach(_.onTaskSucceeded())
endTaskNvtx()
}

private def startTaskNvtx(taskCtx: TaskContext): Unit = {
val stageId = taskCtx.stageId()
val taskAttemptId = taskCtx.taskAttemptId()
val attemptNumber = taskCtx.attemptNumber()
activeTaskNvtx.put(Thread.currentThread(),
new NvtxRange(s"Stage $stageId Task $taskAttemptId-$attemptNumber", NvtxColor.DARK_GREEN))
}

private def endTaskNvtx(): Unit = {
val nvtx = activeTaskNvtx.remove(Thread.currentThread())
if (nvtx != null) {
nvtx.close()
}
}
}

Expand Down

0 comments on commit 6921dac

Please sign in to comment.