Commit 0ddcaf6

LIVY-355. Refactor statement progress tracker to fix binary compatibility issue (apache#323)

* Refactor statement progress tracker to fix binary compatibility issue

Change-Id: Ie91fd77472aeebe138bd6711a0baa82269a6b247

* Refactor again to simplify the code

Change-Id: I9380bcb8dd2b594250783633a3c68e290ac7ea28

* Isolate the statementId-to-job-group logic

Change-Id: If554aee2c0b3d96b54804f94cbb8df9af7843ab4

jerryshao authored and zjffdu committed May 9, 2017
1 parent f5ef489 · commit 0ddcaf6
Showing 18 changed files with 93 additions and 449 deletions.
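
For context on the commit title: adding a parameter to a public constructor is a source-compatible but binary-incompatible change on the JVM, because already-compiled callers link to a constructor by its exact descriptor. Restoring the original one-argument constructors, as the diffs below do, keeps jars compiled against older Livy working. A minimal sketch of the hazard, using hypothetical names rather than Livy's actual classes:

```scala
// v1 of a hypothetical library class; a caller compiled against it links to
// the constructor descriptor <init>(Ljava/lang/String;)V.
class Interp(conf: String)

// If v2 replaced it with
//   class Interp(conf: String, listener: AnyRef)
// that descriptor would disappear, and the already-compiled caller below
// would fail at runtime with java.lang.NoSuchMethodError, even though
// recompiling it against v2 would merely produce a compile error.
object Caller {
  def main(args: Array[String]): Unit = {
    val interp = new Interp("local") // resolved by exact descriptor at link time
    println(interp)
  }
}
```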

```diff
@@ -33,8 +33,7 @@ import org.apache.spark.repl.SparkIMain
 /**
  * This represents a Spark interpreter. It is not thread safe.
  */
-class SparkInterpreter(conf: SparkConf,
-    override val statementProgressListener: StatementProgressListener)
+class SparkInterpreter(conf: SparkConf)
   extends AbstractSparkInterpreter with SparkContextInitializer {
 
   private var sparkIMain: SparkIMain = _
@@ -108,7 +107,6 @@ class SparkInterpreter(conf: SparkConf,
       createSparkContext(conf)
     }
 
-    sparkContext.addSparkListener(statementProgressListener)
     sparkContext
   }
```

```diff
@@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite
 
 class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite {
   describe("SparkInterpreter") {
-    val interpreter = new SparkInterpreter(null, null)
+    val interpreter = new SparkInterpreter(null)
 
     it("should parse Scala compile error.") {
       // Regression test for LIVY-260.
```

```diff
@@ -33,8 +33,7 @@ import org.apache.spark.repl.SparkILoop
 /**
  * Scala 2.11 version of SparkInterpreter
  */
-class SparkInterpreter(conf: SparkConf,
-    override val statementProgressListener: StatementProgressListener)
+class SparkInterpreter(conf: SparkConf)
   extends AbstractSparkInterpreter with SparkContextInitializer {
 
   protected var sparkContext: SparkContext = _
@@ -94,7 +93,6 @@ class SparkInterpreter(conf: SparkConf,
       createSparkContext(conf)
     }
 
-    sparkContext.addSparkListener(statementProgressListener)
     sparkContext
   }
```

```diff
@@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite
 
 class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite {
   describe("SparkInterpreter") {
-    val interpreter = new SparkInterpreter(null, null)
+    val interpreter = new SparkInterpreter(null)
 
     it("should parse Scala compile error.") {
       // Regression test for LIVY-260.
```

repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala (10 changes: 0 additions & 10 deletions)

```diff
@@ -37,23 +37,13 @@ trait Interpreter {
 
   def kind: String
 
-  def statementProgressListener: StatementProgressListener
-
   /**
    * Start the Interpreter.
    *
    * @return A SparkContext
    */
   def start(): SparkContext
 
-  /**
-   * Execute the code and return the result.
-   */
-  def execute(statementId: Int, code: String): ExecuteResponse = {
-    statementProgressListener.setCurrentStatementId(statementId)
-    execute(code)
-  }
-
   /**
    * Execute the code and return the result, it may
    * take some time to execute.
```

```diff
@@ -41,8 +41,7 @@ private case class ShutdownRequest(promise: Promise[Unit]) extends Request
  *
  * @param process
  */
-abstract class ProcessInterpreter(process: Process,
-    override val statementProgressListener: StatementProgressListener)
+abstract class ProcessInterpreter(process: Process)
   extends Interpreter with Logging {
   protected[this] val stdin = new PrintWriter(process.getOutputStream)
   protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
@@ -53,9 +52,7 @@ abstract class ProcessInterpreter(process: Process,
     if (ClientConf.TEST_MODE) {
       null.asInstanceOf[SparkContext]
     } else {
-      val sc = SparkContext.getOrCreate()
-      sc.addSparkListener(statementProgressListener)
-      sc
+      SparkContext.getOrCreate()
     }
   }
```

```diff
@@ -45,7 +45,7 @@ import com.cloudera.livy.sessions._
 // scalastyle:off println
 object PythonInterpreter extends Logging {
 
-  def apply(conf: SparkConf, kind: Kind, listener: StatementProgressListener): Interpreter = {
+  def apply(conf: SparkConf, kind: Kind): Interpreter = {
     val pythonExec = kind match {
       case PySpark() => sys.env.getOrElse("PYSPARK_PYTHON", "python")
       case PySpark3() => sys.env.getOrElse("PYSPARK3_PYTHON", "python3")
@@ -72,7 +72,7 @@
     env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1"))
     builder.redirectError(Redirect.PIPE)
     val process = builder.start()
-    new PythonInterpreter(process, gatewayServer, kind.toString, listener)
+    new PythonInterpreter(process, gatewayServer, kind.toString)
   }
 
   private def findPySparkArchives(): Seq[String] = {
@@ -190,9 +190,8 @@ object PythonInterpreter extends Logging {
 private class PythonInterpreter(
     process: Process,
     gatewayServer: GatewayServer,
-    pyKind: String,
-    listener: StatementProgressListener)
-  extends ProcessInterpreter(process, listener)
+    pyKind: String)
+  extends ProcessInterpreter(process)
   with Logging
 {
   implicit val formats = DefaultFormats
```

repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala (10 changes: 5 additions & 5 deletions)

```diff
@@ -44,11 +44,11 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
 
   override protected def initializeContext(): JavaSparkContext = {
     interpreter = kind match {
-      case PySpark() => PythonInterpreter(conf, PySpark(), new StatementProgressListener(livyConf))
+      case PySpark() => PythonInterpreter(conf, PySpark())
       case PySpark3() =>
-        PythonInterpreter(conf, PySpark3(), new StatementProgressListener(livyConf))
-      case Spark() => new SparkInterpreter(conf, new StatementProgressListener(livyConf))
-      case SparkR() => SparkRInterpreter(conf, new StatementProgressListener(livyConf))
+        PythonInterpreter(conf, PySpark3())
+      case Spark() => new SparkInterpreter(conf)
+      case SparkR() => SparkRInterpreter(conf)
     }
     session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) })
 
@@ -94,7 +94,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
 
     // Update progress of statements when queried
     statements.foreach { s =>
-      s.updateProgress(interpreter.statementProgressListener.progressOfStatement(s.id))
+      s.updateProgress(session.progressOfStatement(s.id))
    }
 
     new ReplJobResults(statements.sortBy(_.id))
```
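
The second hunk above switches progress reporting from a push model (a SparkListener feeding a per-statement tracker) to a pull model: the driver recomputes progress from the session only when statements are queried. A simplified sketch of that pattern, with illustrative class shapes rather than Livy's actual API:

```scala
// Statement carries a mutable progress field, refreshed on demand.
case class Statement(id: Int, var progress: Double = 0.0) {
  def updateProgress(p: Double): Unit = progress = p
}

// Stand-in for Session; in Livy the real method aggregates
// SparkStatusTracker task counts (see the Session.scala diff below).
class SessionLike {
  def progressOfStatement(id: Int): Double = 0.5 // placeholder value
}

object DriverLike {
  // Refreshing at query time means no listener state has to be kept in
  // sync while statements run.
  def replJobResults(session: SessionLike, statements: Seq[Statement]): Seq[Statement] = {
    statements.foreach(s => s.updateProgress(session.progressOfStatement(s.id)))
    statements.sortBy(_.id)
  }
}
```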

repl/src/main/scala/com/cloudera/livy/repl/Session.scala (40 changes: 34 additions & 6 deletions)

```diff
@@ -170,6 +170,29 @@ class Session(
     interpreter.close()
   }
 
+  /**
+   * Get the current progress of given statement id.
+   */
+  def progressOfStatement(stmtId: Int): Double = {
+    val jobGroup = statementIdToJobGroup(stmtId)
+
+    _sc.map { sc =>
+      val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup)
+      val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) }
+      val stages = jobs.flatMap { job =>
+        job.stageIds().flatMap(sc.statusTracker.getStageInfo)
+      }
+
+      val taskCount = stages.map(_.numTasks).sum
+      val completedTaskCount = stages.map(_.numCompletedTasks).sum
+      if (taskCount == 0) {
+        0.0
+      } else {
+        completedTaskCount.toDouble / taskCount
+      }
+    }.getOrElse(0.0)
+  }
+
   private def changeState(newState: SessionState): Unit = {
     synchronized {
       _state = newState
@@ -188,7 +211,7 @@ class Session(
     }
 
     val resultInJson = try {
-      interpreter.execute(executionCount, code) match {
+      interpreter.execute(code) match {
         case Interpreter.ExecuteSuccess(data) =>
           transitToIdle()
 
@@ -240,23 +263,28 @@ class Session(
   }
 
   private def setJobGroup(statementId: Int): String = {
+    val jobGroup = statementIdToJobGroup(statementId)
     val cmd = Kind(interpreter.kind) match {
       case Spark() =>
         // A dummy value to avoid automatic value binding in scala REPL.
-        s"""val _livyJobGroup$statementId = sc.setJobGroup("$statementId",""" +
-          s""""Job group for statement $statementId")"""
+        s"""val _livyJobGroup$jobGroup = sc.setJobGroup("$jobGroup",""" +
+          s""""Job group for statement $jobGroup")"""
       case PySpark() | PySpark3() =>
-        s"""sc.setJobGroup("$statementId", "Job group for statement $statementId")"""
+        s"""sc.setJobGroup("$jobGroup", "Job group for statement $jobGroup")"""
       case SparkR() =>
         interpreter.asInstanceOf[SparkRInterpreter].sparkMajorVersion match {
           case "1" =>
-            s"""setJobGroup(sc, "$statementId", "Job group for statement $statementId", """ +
+            s"""setJobGroup(sc, "$jobGroup", "Job group for statement $jobGroup", """ +
               "FALSE)"
           case "2" =>
-            s"""setJobGroup("$statementId", "Job group for statement $statementId", FALSE)"""
+            s"""setJobGroup("$jobGroup", "Job group for statement $jobGroup", FALSE)"""
         }
     }
     // Set the job group
     executeCode(statementId, cmd)
   }
+
+  private def statementIdToJobGroup(statementId: Int): String = {
+    statementId.toString
+  }
 }
```
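
The new progressOfStatement derives a completion ratio from Spark's public status API instead of a custom listener, which is what makes the removed StatementProgressListener unnecessary. A standalone sketch of the same aggregation against a local SparkContext; the app name and the group id "0" are arbitrary choices for the example:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ProgressSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("progress-sketch"))

    // Livy tags each statement's jobs with jobGroup = statementId.toString.
    sc.setJobGroup("0", "Job group for statement 0")
    sc.parallelize(1 to 1000, 8).map(_ * 2).count()

    // Same walk as Session.progressOfStatement: group -> jobs -> stages ->
    // task counts. Livy polls this while a statement runs; here the job has
    // already finished, so the printed ratio is 1.0.
    val tracker = sc.statusTracker
    val jobs = tracker.getJobIdsForGroup("0").flatMap(id => tracker.getJobInfo(id))
    val stages = jobs.flatMap(job => job.stageIds().flatMap(sid => tracker.getStageInfo(sid)))
    val total = stages.map(_.numTasks).sum
    val done = stages.map(_.numCompletedTasks).sum
    println(if (total == 0) 0.0 else done.toDouble / total)

    sc.stop()
  }
}
```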

```diff
@@ -68,7 +68,7 @@ object SparkRInterpreter {
       ")"
     ).r.unanchored
 
-  def apply(conf: SparkConf, listener: StatementProgressListener): SparkRInterpreter = {
+  def apply(conf: SparkConf): SparkRInterpreter = {
     val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt
     val mirror = universe.runtimeMirror(getClass.getClassLoader)
     val sparkRBackendClass = mirror.classLoader.loadClass("org.apache.spark.api.r.RBackend")
@@ -121,8 +121,7 @@
       val process = builder.start()
       new SparkRInterpreter(process, backendInstance, backendThread,
         conf.get("spark.livy.spark_major_version", "1"),
-        conf.getBoolean("spark.repl.enableHiveContext", false),
-        listener)
+        conf.getBoolean("spark.repl.enableHiveContext", false))
     } catch {
       case e: Exception =>
         if (backendThread != null) {
@@ -137,9 +136,8 @@
 class SparkRInterpreter(process: Process,
     backendInstance: Any,
     backendThread: Thread,
     val sparkMajorVersion: String,
-    hiveEnabled: Boolean,
-    statementProgressListener: StatementProgressListener)
-  extends ProcessInterpreter(process, statementProgressListener) {
+    hiveEnabled: Boolean)
+  extends ProcessInterpreter(process) {
   import SparkRInterpreter._
 
   implicit val formats = DefaultFormats
```
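
Tying the pieces together: the job-group commands that Session.setJobGroup generates differ per interpreter kind. For a hypothetical statement id 7, the string templates in the Session.scala diff above expand to the following (shown as comments, since the PySpark and SparkR lines are Python and R, not Scala):

```scala
// Spark (Scala REPL) -- the val prevents the REPL from echoing a binding:
//   val _livyJobGroup7 = sc.setJobGroup("7", "Job group for statement 7")

// PySpark / PySpark3:
//   sc.setJobGroup("7", "Job group for statement 7")

// SparkR against Spark 1.x:
//   setJobGroup(sc, "7", "Job group for statement 7", FALSE)

// SparkR against Spark 2.x:
//   setJobGroup("7", "Job group for statement 7", FALSE)
```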

(Diffs for the remaining changed files were not loaded in this view.)