LIVY-215. Support statement canceling in interactive sessions. (cloud…
jerryshao authored and alex-the-man committed Dec 16, 2016
1 parent 2ae091d commit d3cc09b
Showing 16 changed files with 236 additions and 39 deletions.
repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala
@@ -50,11 +50,9 @@ abstract class ProcessInterpreter(process: Process)
override def start(): SparkContext = {
waitUntilReady()

// At this point there should be an already active SparkContext that can be retrieved
// using SparkContext.getOrCreate. But we don't really support running "pre-compiled"
// jobs against pyspark or sparkr, so just return null here.
null
SparkContext.getOrCreate()
}

override def execute(code: String): Interpreter.ExecuteResponse = {
try {
sendExecuteRequest(code)
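The change above has ProcessInterpreter.start() return the live SparkContext obtained through SparkContext.getOrCreate() instead of null, which is what later lets the driver cancel a statement's Spark jobs by job group. A minimal standalone sketch of that getOrCreate / setJobGroup / cancelJobGroup pattern (illustrative only, not Livy code; the app name, master, and group id are made up):

import org.apache.spark.{SparkConf, SparkContext}

object JobGroupCancelSketch {
  def main(args: Array[String]): Unit = {
    // Reuse the already-running context if one exists, otherwise create one.
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("cancel-sketch").setMaster("local[2]"))

    // Another thread cancels the job group after a short delay, much like
    // Session.cancel() does from its dedicated cancelExecutor.
    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(2000)
        sc.cancelJobGroup("statement-0")
      }
    }).start()

    // Tag work submitted from this thread with the group id, then run a slow job;
    // the cancel above aborts it and count() throws a SparkException.
    sc.setJobGroup("statement-0", "Job group for statement 0")
    try {
      sc.parallelize(1 to 1000, 4).map { i => Thread.sleep(100); i }.count()
    } catch {
      case e: Exception => println(s"Job was cancelled: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}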
8 changes: 6 additions & 2 deletions repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaSparkContext

import com.cloudera.livy.Logging
import com.cloudera.livy.rsc.BaseProtocol.ReplState
import com.cloudera.livy.rsc.{BaseProtocol, RSCConf, ReplJobResults}
import com.cloudera.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf}
import com.cloudera.livy.rsc.driver._
import com.cloudera.livy.rsc.rpc.Rpc
import com.cloudera.livy.sessions._
@@ -49,7 +49,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
case Spark() => new SparkInterpreter(conf)
case SparkR() => SparkRInterpreter(conf)
}
session = new Session(interpreter, { s => broadcast(new ReplState(s.toString)) })
session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) })

Option(Await.result(session.start(), Duration.Inf))
.map(new JavaSparkContext(_))
@@ -70,6 +70,10 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
session.execute(msg.code)
}

def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.CancelReplJobRequest): Unit = {
session.cancel(msg.id)
}

/**
* Return statement results. Results are sorted by statement id.
*/
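The new handle() above forwards a CancelReplJobRequest to Session.cancel. The message class itself is defined in BaseProtocol, one of the 16 changed files not included in this excerpt; the only thing visible here is its id field (via msg.id). A purely hypothetical sketch of such a message, just to make the handler self-explanatory (the real definition, types, and constructors may differ):

// Hypothetical sketch only -- the real class lives in com.cloudera.livy.rsc.BaseProtocol
// and is not shown in this excerpt. The id field is inferred from msg.id in the handler above.
class CancelReplJobRequest(var id: Int) {
  def this() = this(-1)  // RPC codecs usually need a no-arg constructor
}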
105 changes: 91 additions & 14 deletions repl/src/main/scala/com/cloudera/livy/repl/Session.scala
@@ -23,13 +23,15 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.concurrent.TrieMap
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.SparkContext
import org.json4s.jackson.JsonMethods.{compact, render}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._

import com.cloudera.livy.Logging
import com.cloudera.livy.rsc.RSCConf
import com.cloudera.livy.rsc.driver.{Statement, StatementState}
import com.cloudera.livy.sessions._

@@ -44,15 +46,23 @@ object Session {
val TRACEBACK = "traceback"
}

class Session(interpreter: Interpreter, stateChangedCallback: SessionState => Unit = { _ => } )
extends Logging
{
class Session(
livyConf: RSCConf,
interpreter: Interpreter,
stateChangedCallback: SessionState => Unit = { _ => })
extends Logging {
import Session._

private implicit val executor = ExecutionContext.fromExecutorService(
Executors.newSingleThreadExecutor())

private val cancelExecutor = ExecutionContext.fromExecutorService(
Executors.newSingleThreadExecutor())

private implicit val formats = DefaultFormats

@volatile private[repl] var _sc: Option[SparkContext] = None

private var _state: SessionState = SessionState.NotStarted()
private val _statements = TrieMap[Int, Statement]()

@@ -64,12 +74,12 @@ class Session(interpreter: Interpreter, stateChangedCallback: SessionState => Un
val future = Future {
changeState(SessionState.Starting())
val sc = interpreter.start()
_sc = Option(sc)
changeState(SessionState.Idle())
sc
}
future.onFailure { case _ =>
changeState(SessionState.Error())
}

future.onFailure { case _ => changeState(SessionState.Error()) }
future
}

@@ -82,32 +92,78 @@ class Session(interpreter: Interpreter, stateChangedCallback: SessionState => Un
def execute(code: String): Int = {
val statementId = newStatementId.getAndIncrement()
_statements(statementId) = new Statement(statementId, StatementState.Waiting, null)

Future {
_statements(statementId) = new Statement(statementId, StatementState.Running, null)
setJobGroup(statementId)
_statements(statementId).state.compareAndSet(StatementState.Waiting, StatementState.Running)

val executeResult = if (_statements(statementId).state.get() == StatementState.Running) {
executeCode(statementId, code)
} else {
null
}

_statements(statementId) =
new Statement(statementId, StatementState.Available, executeCode(statementId, code))
_statements(statementId).output = executeResult
_statements(statementId).state.compareAndSet(StatementState.Running, StatementState.Available)
_statements(statementId).state.compareAndSet(
StatementState.Cancelling, StatementState.Cancelled)
}

statementId
}

def cancel(statementId: Int): Unit = {
if (!_statements.contains(statementId)) {
return
}

if (_statements(statementId).state.get() == StatementState.Available ||
_statements(statementId).state.get() == StatementState.Cancelled ||
_statements(statementId).state.get() == StatementState.Cancelling) {
return
} else {
// Statement 1 is running and statement 2 is waiting. If the user cancels
// statement 2 and then statement 1, the cancel of the waiting statement would
// loop and block the later cancel call, since cancelExecutor is single threaded.
// To avoid this, move a waiting statement straight to Cancelled when it is cancelled.
_statements(statementId).state.compareAndSet(StatementState.Waiting, StatementState.Cancelled)
_statements(statementId).state.compareAndSet(
StatementState.Running, StatementState.Cancelling)
}

info(s"Cancelling statement $statementId...")

Future {
val deadline = livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TIMEOUT).millis.fromNow
while (_statements(statementId).state.get() == StatementState.Cancelling) {
if (deadline.isOverdue()) {
info(s"Failed to cancel statement $statementId.")
_statements(statementId).state.compareAndSet(
StatementState.Cancelling, StatementState.Cancelled)
} else {
_sc.foreach(_.cancelJobGroup(statementId.toString))
}
Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL))
}
if (_statements(statementId).state.get() == StatementState.Cancelled) {
info(s"Statement $statementId cancelled.")
}
}(cancelExecutor)
}

def close(): Unit = {
executor.shutdown()
interpreter.close()
}

def clearStatements(): Unit = synchronized {
_statements.clear()
}

private def changeState(newState: SessionState): Unit = {
synchronized {
_state = newState
}
stateChangedCallback(newState)
}

private def executeCode(executionCount: Int, code: String): String = synchronized {
private def executeCode(executionCount: Int, code: String): String = {
changeState(SessionState.Busy())

def transitToIdle() = {
@@ -168,4 +224,25 @@ class Session(interpreter: Interpreter, stateChangedCallback: SessionState => Un

compact(render(resultInJson))
}

private def setJobGroup(statementId: Int): String = {
val cmd = Kind(interpreter.kind) match {
case Spark() =>
// A dummy value to avoid automatic value binding in scala REPL.
s"""val _livyJobGroup$statementId = sc.setJobGroup("$statementId",""" +
s""""Job group for statement $statementId")"""
case PySpark() | PySpark3() =>
s"""sc.setJobGroup("$statementId", "Job group for statement $statementId")"""
case SparkR() =>
interpreter.asInstanceOf[SparkRInterpreter].sparkMajorVersion match {
case "1" =>
s"""setJobGroup(sc, "$statementId", "Job group for statement $statementId", """ +
"FALSE)"
case "2" =>
s"""setJobGroup("$statementId", "Job group for statement $statementId", FALSE)"""
}
}
// Set the job group
executeCode(statementId, cmd)
}
}
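Taken together, the Session changes give every statement a small lifecycle (Waiting -> Running -> Available, with Cancelling and Cancelled branching off) driven by compareAndSet transitions, while cancel() polls on its own single-threaded executor and keeps cancelling the statement's Spark job group until the state leaves Cancelling or the JOB_CANCEL_TIMEOUT deadline expires. A rough usage sketch against the API shown above (assumes it sits in the com.cloudera.livy.repl package with a working Interpreter and a default RSCConf; illustrative only, not code from the commit):

import java.util.Properties
import scala.concurrent.Await
import scala.concurrent.duration.Duration

import com.cloudera.livy.rsc.RSCConf
import com.cloudera.livy.rsc.driver.StatementState

object SessionCancelSketch {
  def run(interpreter: Interpreter): Unit = {
    val session = new Session(new RSCConf(new Properties()), interpreter)
    Await.result(session.start(), Duration.Inf)       // wait for the interpreter / SparkContext

    val id = session.execute("Thread.sleep(600000)")  // a deliberately slow statement
    session.cancel(id)                                // moves it to Cancelling and cancels job group `id`

    // Poll until the statement settles: Cancelled, or Available if it happened
    // to finish before the cancel took effect.
    while (!Set(StatementState.Available, StatementState.Cancelled)
              .contains(session.statements(id).state.get())) {
      Thread.sleep(100)
    }
    println(s"Statement $id finished as ${session.statements(id).state.get()}")
    session.close()
  }
}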
repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala
@@ -126,14 +126,17 @@ object SparkRInterpreter {
}
}

class SparkRInterpreter(process: Process, backendInstance: Any, backendThread: Thread,
sparkMajorVersion: String) extends ProcessInterpreter(process) {
class SparkRInterpreter(process: Process,
backendInstance: Any,
backendThread: Thread,
val sparkMajorVersion: String)
extends ProcessInterpreter(process) {
import SparkRInterpreter._

implicit val formats = DefaultFormats

private[this] var executionCount = 0
override def kind: String = "sparkR"
override def kind: String = "sparkr"
private[this] val isStarted = new CountDownLatch(1);

final override protected def waitUntilReady(): Unit = {
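Two small but load-bearing tweaks here: sparkMajorVersion becomes a val so that Session.setJobGroup can choose between the Spark 1 and Spark 2 variants of SparkR's setJobGroup, and the kind string is lowercased to "sparkr" so that Kind(interpreter.kind) in setJobGroup resolves to SparkR(). A sketch of the string-to-Kind matching this implies (illustrative; the real factory lives in com.cloudera.livy.sessions and may differ):

import com.cloudera.livy.sessions._

// Illustrative only -- the real Kind companion is in com.cloudera.livy.sessions.
// The point is that matching is on exact lowercase strings, so the old "sparkR"
// value would not resolve and no job group could be set for SparkR statements.
object KindSketch {
  def apply(kind: String): Kind = kind match {
    case "spark"    => Spark()
    case "pyspark"  => PySpark()
    case "pyspark3" => PySpark3()
    case "sparkr"   => SparkR()
    case other      => throw new IllegalArgumentException(s"Invalid session kind: $other")
  }
}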
11 changes: 8 additions & 3 deletions repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala
@@ -18,6 +18,7 @@

package com.cloudera.livy.repl

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.Await
@@ -29,25 +30,29 @@ import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.concurrent.Eventually._

import com.cloudera.livy.LivyBaseUnitTestSuite
import com.cloudera.livy.rsc.RSCConf
import com.cloudera.livy.rsc.driver.{Statement, StatementState}
import com.cloudera.livy.sessions.SessionState

abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite {

implicit val formats = DefaultFormats

private val rscConf = new RSCConf(new Properties())

protected def execute(session: Session)(code: String): Statement = {
val id = session.execute(code)
eventually(timeout(30 seconds), interval(100 millis)) {
val s = session.statements(id)
s.state shouldBe StatementState.Available
s.state.get() shouldBe StatementState.Available
s
}
}

protected def withSession(testCode: Session => Any): Unit = {
val stateChangedCalled = new AtomicInteger()
val session = new Session(createInterpreter(), { _ => stateChangedCalled.incrementAndGet() })
val session =
new Session(rscConf, createInterpreter(), { _ => stateChangedCalled.incrementAndGet() })
try {
// Session's constructor should fire an initial state change event.
stateChangedCalled.intValue() shouldBe 1
@@ -64,7 +69,7 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitT
protected def createInterpreter(): Interpreter

it should "start in the starting or idle state" in {
val session = new Session(createInterpreter())
val session = new Session(rscConf, createInterpreter())
val future = session.start()
try {
eventually(timeout(30 seconds), interval(100 millis)) {
25 changes: 20 additions & 5 deletions repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala
@@ -18,9 +18,9 @@

package com.cloudera.livy.repl

import java.util.Properties
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}


import org.mockito.Mockito.when
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
@@ -32,18 +32,27 @@ import org.scalatest.time._

import com.cloudera.livy.LivyBaseUnitTestSuite
import com.cloudera.livy.repl.Interpreter.ExecuteResponse
import com.cloudera.livy.rsc.RSCConf

class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite {
override implicit val patienceConfig =
PatienceConfig(timeout = scaled(Span(10, Seconds)), interval = scaled(Span(100, Millis)))

private val rscConf = new RSCConf(new Properties())

describe("Session") {
it("should call state changed callbacks in happy path") {
val expectedStateTransitions = Array("not_started", "starting", "idle", "busy", "idle")
val expectedStateTransitions =
Array("not_started", "starting", "idle", "busy", "idle", "busy", "idle")
val actualStateTransitions = new ConcurrentLinkedQueue[String]()

val interpreter = mock[Interpreter]
val session = new Session(interpreter, { s => actualStateTransitions.add(s.toString) })
when(interpreter.kind).thenAnswer(new Answer[String] {
override def answer(invocationOnMock: InvocationOnMock): String = "spark"
})

val session =
new Session(rscConf, interpreter, { s => actualStateTransitions.add(s.toString) })

session.start()

@@ -55,18 +64,24 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite {
}

it("should not transit to idle if there're any pending statements.") {
val expectedStateTransitions = Array("not_started", "busy", "busy", "idle")
val expectedStateTransitions =
Array("not_started", "busy", "busy", "busy", "idle", "busy", "idle")
val actualStateTransitions = new ConcurrentLinkedQueue[String]()

val interpreter = mock[Interpreter]
when(interpreter.kind).thenAnswer(new Answer[String] {
override def answer(invocationOnMock: InvocationOnMock): String = "spark"
})

val blockFirstExecuteCall = new CountDownLatch(1)
when(interpreter.execute("")).thenAnswer(new Answer[Interpreter.ExecuteResponse] {
override def answer(invocation: InvocationOnMock): ExecuteResponse = {
blockFirstExecuteCall.await(10, TimeUnit.SECONDS)
null
}
})
val session = new Session(interpreter, { s => actualStateTransitions.add(s.toString) })
val session =
new Session(rscConf, interpreter, { s => actualStateTransitions.add(s.toString) })

for (_ <- 1 to 2) {
session.execute("")
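The extra "busy"/"idle" pairs in the expected transitions above come from setJobGroup, which now runs as an additional executeCode call for every statement. A hypothetical companion test for the cancel path, written against the same API and mock style (not part of this commit's actual test changes); it would sit inside the describe("Session") block:

it("should ignore cancel calls for unknown statement ids") {
  val interpreter = mock[Interpreter]
  when(interpreter.kind).thenAnswer(new Answer[String] {
    override def answer(invocationOnMock: InvocationOnMock): String = "spark"
  })

  val session = new Session(rscConf, interpreter, { _ => })
  // cancel() returns immediately when the id was never handed out by execute().
  session.cancel(42)
}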