LIVY-213. Implemented interactive statement results recovery.
- Removed statement results history cache in livy-server.
- Proxy statement results history requests to livy-repl.
- Modified RSC repl API:
 - ReplJobRequest is now synchronous and returns the id of the new statement.
 - GetReplJobResults returns all or a range of statements in sorted order, along with the current repl state.
- Moved the definitions of Statement and StatementState to repl-rsc to avoid unnecessary parsing (a sketch of these types follows below).
- livy-server now polls only the repl state from the repl, not statement results.
- New StatementState "waiting" marks a statement that has been submitted to the repl but not yet executed.
- Send statement results to the server as JSON strings instead of JValue to avoid a Kryo serialization issue.
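A minimal Scala sketch of what the relocated Statement and StatementState types could look like, based only on how this diff uses them (constructor order id/state/output, the waiting/running/available states, and a JSON-string output field); the actual definitions in the rsc module may differ, for example by being plain Java classes for RPC serialization:

package com.cloudera.livy.rsc.driver

// Lifecycle of a statement as seen by livy-server. "Waiting" is the new
// state for statements queued in the repl but not yet executed;
// "Available" means the JSON result is ready to be read.
sealed abstract class StatementState(val name: String)
object StatementState {
  case object Waiting extends StatementState("waiting")
  case object Running extends StatementState("running")
  case object Available extends StatementState("available")
}

// The result is carried as a JSON string rather than a JValue so it
// serializes cleanly over the RSC RPC channel (the Kryo issue noted above).
class Statement(val id: Int, val state: StatementState, val output: String)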
alex-the-man authored Nov 11, 2016
1 parent 805dcf2 commit b4642b1
Showing 22 changed files with 382 additions and 404 deletions.
@@ -186,6 +186,15 @@ class LivyRestClient(val httpClient: AsyncHttpClient, val livyEndpoint: String)

def run(code: String): Statement = { new Statement(code) }

def runFatalStatement(code: String): Unit = {
val requestBody = Map("code" -> code)
val r = httpClient.preparePost(s"$url/statements")
.setBody(mapper.writeValueAsString(requestBody))
.execute()

verifySessionState(SessionState.Dead())
}

def verifySessionIdle(): Unit = {
verifySessionState(SessionState.Idle())
}
@@ -97,8 +97,7 @@ class InteractiveIT extends BaseIntegrationTestSuite {

test("application kills session") {
withNewSession(Spark()) { s =>
s.run("System.exit(0)")
s.verifySessionState(SessionState.Dead())
s.runFatalStatement("System.exit(0)")
}
}

@@ -134,7 +133,8 @@ class InteractiveIT extends BaseIntegrationTestSuite {

test("recover interactive session") {
withNewSession(Spark()) { s =>
s.run("1").verifyResult("res0: Int = 1")
val stmt1 = s.run("1")
stmt1.verifyResult("res0: Int = 1")

// Restart Livy.
cluster.stopLivy()
@@ -143,7 +143,8 @@ class InteractiveIT extends BaseIntegrationTestSuite {
// Verify session still exists.
s.verifySessionIdle()
s.run("2").verifyResult("res1: Int = 2")
// TODO, verify previous statement results still exist.
// Verify statement result is preserved.
stmt1.verifyResult("res0: Int = 1")

s.stop()

42 changes: 21 additions & 21 deletions repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala
@@ -18,32 +18,22 @@

package com.cloudera.livy.repl

import scala.collection.mutable
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Await
import scala.concurrent.duration.Duration

import io.netty.channel.ChannelHandlerContext
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods._

import com.cloudera.livy.Logging
import com.cloudera.livy.rsc.{BaseProtocol, RSCConf}
import com.cloudera.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf}
import com.cloudera.livy.rsc.driver._
import com.cloudera.livy.sessions._

class ReplDriver(conf: SparkConf, livyConf: RSCConf)
extends RSCDriver(conf, livyConf)
with Logging {

// Add here to make it compatible with json4s-jackson 3.2.11 JsonMethods#render API.
private implicit def formats = DefaultFormats

private val jobFutures = mutable.Map[String, JValue]()

private[repl] var session: Session = _

private val kind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND))
@@ -73,19 +63,29 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
}
}

def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Unit = {
Future {
jobFutures(msg.id) = session.execute(msg.code).result
}
def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Int = {
session.execute(msg.code)
}

def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResult): String = {
val result = jobFutures.getOrElse(msg.id, null)
Option(result).map { r => compact(render(r)) }.orNull
}
/**
* Return statement results. Results are sorted by statement id.
*/
def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults =
session.synchronized {
val stmts = if (msg.allResults) {
session.statements.values.toArray
} else {
assert(msg.from != null)
assert(msg.size != null)
val until = msg.from + msg.size
session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray
}
val state = session.state.toString
new ReplJobResults(stmts.sortBy(_.id), state)
}

def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplState): String = {
return session.state.toString
session.state.toString
}

override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = {
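For reference, a rough sketch of the two RPC message shapes the GetReplJobResults handler above exchanges, inferred from this hunk alone; the field names, constructors, and nullability handling are assumptions, and the real classes live in the rsc module's BaseProtocol:

import com.cloudera.livy.rsc.driver.Statement

// Request: either "give me everything" or a [from, from + size) window.
// from and size are boxed Integers so they can be null when allResults
// is true, matching the null asserts in the handler above.
class GetReplJobResults(val allResults: Boolean, val from: Integer, val size: Integer) {
  def this() = this(true, null, null)                              // all statements
  def this(from: Integer, size: Integer) = this(false, from, size) // one sorted page
}

// Reply: statements sorted by id plus the repl's current state, so the
// server refreshes both in a single round trip instead of polling twice.
class ReplJobResults(val statements: Array[Statement], val replState: String)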
53 changes: 35 additions & 18 deletions repl/src/main/scala/com/cloudera/livy/repl/Session.scala
@@ -19,15 +19,18 @@
package com.cloudera.livy.repl

import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, Future, TimeoutException}
import scala.concurrent.duration.Duration
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkContext
import org.json4s.{DefaultFormats, JValue}
import org.json4s.jackson.JsonMethods.{compact, render}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._

import com.cloudera.livy.{Logging, Utils}
import com.cloudera.livy.Logging
import com.cloudera.livy.rsc.driver.{Statement, StatementState}
import com.cloudera.livy.sessions._

object Session {
@@ -51,7 +54,9 @@ class Session(interpreter: Interpreter)
private implicit val formats = DefaultFormats

private var _state: SessionState = SessionState.NotStarted()
private var _history = IndexedSeq[Statement]()
private val _statements = mutable.Map[Int, Statement]()

private val newStatementId = new AtomicInteger(0)

def start(): Future[SparkContext] = {
val future = Future {
@@ -70,29 +75,41 @@ class Session(interpreter: Interpreter)

def state: SessionState = _state

def history: IndexedSeq[Statement] = _history
def statements: mutable.Map[Int, Statement] = _statements

def execute(code: String): Int = {
val statementId = newStatementId.getAndIncrement()
synchronized {
_statements(statementId) = new Statement(statementId, StatementState.Waiting, null)
}
Future {
synchronized {
_statements(statementId) = new Statement(statementId, StatementState.Running, null)
}

val statement =
new Statement(statementId, StatementState.Available, executeCode(statementId, code))

def execute(code: String): Statement = synchronized {
val executionCount = _history.length
val statement = Statement(executionCount, executeCode(executionCount, code))
_history :+= statement
statement
synchronized {
_statements(statementId) = statement
}
}
statementId
}

def close(): Unit = {
executor.shutdown()
interpreter.close()
}

def clearHistory(): Unit = synchronized {
_history = IndexedSeq()
def clearStatements(): Unit = synchronized {
_statements.clear()
}

private def executeCode(executionCount: Int, code: String) = {
private def executeCode(executionCount: Int, code: String): String = synchronized {
_state = SessionState.Busy()

try {

val resultInJson = try {
interpreter.execute(code) match {
case Interpreter.ExecuteSuccess(data) =>
_state = SessionState.Idle()
@@ -140,7 +157,7 @@ class Session(interpreter: Interpreter)
(EVALUE -> e.getMessage) ~
(TRACEBACK -> List())
}

compact(render(resultInJson))
}
}

case class Statement(id: Int, result: JValue)
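With execution now asynchronous, a caller keeps the returned statement id and polls session.statements until the entry reaches Available, at which point its output field holds the JSON result. A minimal sketch of that pattern (pollResult is a made-up helper name; the specs below do the same thing with ScalaTest's eventually):

import com.cloudera.livy.repl.Session
import com.cloudera.livy.rsc.driver.{Statement, StatementState}

object StatementPolling {
  // Hypothetical helper: submit code and block until the statement completes.
  def pollResult(session: Session, code: String): Statement = {
    val id = session.execute(code)      // returns immediately with the new id
    var stmt = session.statements(id)   // starts in Waiting, then Running
    while (stmt.state != StatementState.Available) {
      Thread.sleep(100)
      stmt = session.statements(id)
    }
    stmt                                // stmt.output holds the JSON result
  }
}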
16 changes: 13 additions & 3 deletions repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala
@@ -22,18 +22,28 @@ import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s._
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.concurrent.Eventually._

import com.cloudera.livy.LivyBaseUnitTestSuite
import com.cloudera.livy.rsc.driver.{Statement, StatementState}
import com.cloudera.livy.sessions.SessionState

abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite {

implicit val formats = DefaultFormats

def withSession(testCode: Session => Any): Unit = {
protected def execute(session: Session)(code: String): Statement = {
val id = session.execute(code)
eventually(timeout(30 seconds), interval(100 millis)) {
val s = session.statements(id)
s.state shouldBe StatementState.Available
s
}
}

protected def withSession(testCode: Session => Any): Unit = {
val session = new Session(createInterpreter())
try {
Await.ready(session.start(), 30 seconds)
@@ -44,7 +54,7 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite {
}
}

def createInterpreter(): Interpreter
protected def createInterpreter(): Interpreter

it should "start in the starting or idle state" in {
val session = new Session(createInterpreter())
38 changes: 20 additions & 18 deletions repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala
@@ -20,17 +20,18 @@ package com.cloudera.livy.repl

import org.apache.spark.SparkConf
import org.json4s.Extraction
import org.json4s.jackson.JsonMethods.parse
import org.scalatest._

import com.cloudera.livy.sessions._

abstract class PythonSessionSpec extends BaseSessionSpec {

it should "execute `1 + 2` == 3" in withSession { session =>
val statement = session.execute("1 + 2")
val statement = execute(session)("1 + 2")
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 0,
@@ -43,10 +44,11 @@ abstract class PythonSessionSpec extends BaseSessionSpec {
}

it should "execute `x = 1`, then `y = 2`, then `x + y`" in withSession { session =>
var statement = session.execute("x = 1")
val executeWithSession = execute(session)(_)
var statement = executeWithSession("x = 1")
statement.id should equal (0)

var result = statement.result
var result = parse(statement.output)
var expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 0,
@@ -57,10 +59,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec {

result should equal (expectedResult)

statement = session.execute("y = 2")
statement = executeWithSession("y = 2")
statement.id should equal (1)

result = statement.result
result = parse(statement.output)
expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 1,
@@ -71,10 +73,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec {

result should equal (expectedResult)

statement = session.execute("x + y")
statement = executeWithSession("x + y")
statement.id should equal (2)

result = statement.result
result = parse(statement.output)
expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 2,
@@ -87,10 +89,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec {
}

it should "do table magic" in withSession { session =>
val statement = session.execute("x = [[1, 'a'], [3, 'b']]\n%table x")
val statement = execute(session)("x = [[1, 'a'], [3, 'b']]\n%table x")
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 0,
@@ -108,10 +110,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec {
}

it should "capture stdout" in withSession { session =>
val statement = session.execute("""print('Hello World')""")
val statement = execute(session)("""print('Hello World')""")
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 0,
@@ -124,10 +126,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec {
}

it should "report an error if accessing an unknown variable" in withSession { session =>
val statement = session.execute("""x""")
val statement = execute(session)("""x""")
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "error",
"execution_count" -> 0,
@@ -143,7 +145,7 @@ abstract class PythonSessionSpec extends BaseSessionSpec {
}

it should "report an error if exception is thrown" in withSession { session =>
val statement = session.execute(
val statement = execute(session)(
"""def func1():
| raise Exception("message")
|def func2():
@@ -152,7 +154,7 @@
""".stripMargin)
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "error",
"execution_count" -> 0,
@@ -184,13 +186,13 @@ class Python3SessionSpec extends PythonSessionSpec {
override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3())

it should "check python version is 3.x" in withSession { session =>
val statement = session.execute(
val statement = execute(session)(
"""import sys
|sys.version >= '3'
""".stripMargin)
statement.id should equal (0)

val result = statement.result
val result = parse(statement.output)
val expectedResult = Extraction.decompose(Map(
"status" -> "ok",
"execution_count" -> 0,