Skip to content

Commit

Permalink
Avoid Option while generating call site
Browse files Browse the repository at this point in the history
This is an update on apache#180, which changes the solution from blacklisting "Option.scala" to avoiding the Option code path while generating the call site.

Also includes a unit test to prevent this issue in the future, and some minor refactoring.

Thanks @witgo for reporting this issue and working on the initial solution!

Author: witgo <[email protected]>
Author: Aaron Davidson <[email protected]>

Closes apache#222 from aarondav/180 and squashes the following commits:

f74aad1 [Aaron Davidson] Avoid Option while generating call site & add unit tests
d2b4980 [witgo] Modify the position of the filter
1bc22d7 [witgo] Fix Stage.name return "apply at Option.scala:120"
  • Loading branch information
witgo authored and pwendell committed Mar 25, 2014
1 parent f8111ea commit 8237df8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,8 @@ class SparkContext(
* has overridden the call site, this will return the user's version.
*/
private[spark] def getCallSite(): String = {
  // Compute the default call site eagerly, *before* wrapping anything in Option.
  // Building it inside the getOrElse thunk would put Option.scala frames on the
  // stack and misreport the call site (e.g. "apply at Option.scala:120").
  val defaultCallSite = Utils.getCallSiteInfo
  // A user-supplied "externalCallSite" local property, when set, takes precedence.
  Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
}

/**
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ abstract class RDD[T: ClassTag](

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo

/**
 * Returns the call site where this RDD was created, rendered for logs via
 * CallSiteInfo.toString (e.g. "textFile at Foo.scala:10").
 */
private[spark] def getCreationSite: String = creationSiteInfo.toString

private[spark] def elementClassTag: ClassTag[T] = classTag[T]

Expand Down
18 changes: 9 additions & 9 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,16 +679,22 @@ private[spark] object Utils extends Logging {
private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r

/**
 * Describes where user code called into Spark: the last (shallowest) Spark method
 * invoked, plus the first file, line, and class outside the Spark package on the stack.
 *
 * @param lastSparkMethod the Spark API method the user invoked (e.g. "makeRDD")
 * @param firstUserFile   source file of the first non-Spark stack frame
 * @param firstUserLine   line number within that file
 * @param firstUserClass  fully-qualified class name of that frame
 */
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
    val firstUserLine: Int, val firstUserClass: String) {

  /** Returns a printable version of the call site info suitable for logs. */
  override def toString: String = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}

/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
val trace = Thread.currentThread.getStackTrace()
.filterNot(_.getMethodName.contains("getStackTrace"))

// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
Expand Down Expand Up @@ -721,12 +727,6 @@ private[spark] object Utils extends Logging {
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}

/** Renders a call site as "&lt;method&gt; at &lt;file&gt;:&lt;line&gt;" for log output. */
def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) =
  s"${callSiteInfo.lastSparkMethod} at ${callSiteInfo.firstUserFile}:${callSiteInfo.firstUserLine}"

/** Return a string containing part of a file from byte 'start' to 'end'. */
def offsetBytes(path: String, start: Long, end: Long): String = {
val file = new File(path)
Expand Down
36 changes: 35 additions & 1 deletion core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark

import org.scalatest.FunSuite
import org.scalatest.{Assertions, FunSuite}

class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
test("getPersistentRDDs only returns RDDs that are marked as cached") {
Expand Down Expand Up @@ -56,4 +56,38 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
rdd.collect()
assert(sc.getRDDStorageInfo.size === 1)
}

// Delegates to testPackage (declared below in this file), which lives outside the
// org.apache.spark package namespace so that the reported call sites point at
// "user" code rather than at Spark-internal stack frames.
test("call sites report correct locations") {
sc = new SparkContext("local", "test")
testPackage.runCallSiteTest(sc)
}
}

/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
package object testPackage extends Assertions {
  // Matches call-site strings such as "makeRDD at SparkContextInfoSuite.scala:123".
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  /**
   * Asserts that both an RDD's creation site and SparkContext.getCallSite report
   * this (user-code) file and consistent line numbers.
   *
   * NOTE: the "+ 2" assertion below depends on "curCallSite" being captured exactly
   * two source lines after the definition of "rdd" — keep that spacing intact.
   */
  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt // pattern group already validated as digits; safe to convert
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        // rddCreationLine is already an Int; no redundant .toInt needed.
        assert(line.toInt === rddCreationLine + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
}

0 comments on commit 8237df8

Please sign in to comment.