From b4642b193956b287bb1c1b32bafe4c5a15e71eac Mon Sep 17 00:00:00 2001 From: Alex Man Date: Thu, 10 Nov 2016 21:37:56 -0800 Subject: [PATCH] LIVY-213. Implemented interactive statement results recovery. - Removed statement results history cache in livy-server. - Proxy statement results history requests to livy-repl. - Modified RSC repl API: - ReplJobRequest to be synchronous and return the id of new statement. - GetReplJobResults to return all or some statements in sorted order. Also return current repl state. - Moved definition of Statement and StatementState to repl-rsc to avoid unnecessary parsing. - livy-server polls only repl state but not statement results from repl. - New StatementState "waiting" to show a statement that's submitted to repl but hasn't been executed. - Send statement results to server in JSON instead of JValue to avoid Kryo issue. --- .../livy/test/framework/LivyRestClient.scala | 9 ++ .../cloudera/livy/test/InteractiveIT.scala | 9 +- .../com/cloudera/livy/repl/ReplDriver.scala | 42 +++--- .../com/cloudera/livy/repl/Session.scala | 53 ++++--- .../cloudera/livy/repl/BaseSessionSpec.scala | 16 +- .../livy/repl/PythonSessionSpec.scala | 38 ++--- .../cloudera/livy/repl/ReplDriverSuite.scala | 9 +- .../livy/repl/SparkRSessionSpec.scala | 35 ++--- .../cloudera/livy/repl/SparkSessionSpec.scala | 42 +++--- .../com/cloudera/livy/rsc/BaseProtocol.java | 26 ++-- .../java/com/cloudera/livy/rsc/RSCClient.java | 14 +- .../com/cloudera/livy/rsc/ReplJobResults.java | 33 ++++ .../cloudera/livy/rsc/driver/RSCDriver.java | 2 +- .../cloudera/livy/rsc/driver/Statement.java | 37 +++++ .../livy/rsc/driver/StatementState.java | 38 +++++ .../interactive/InteractiveSession.scala | 141 ++++++++++-------- .../InteractiveSessionServlet.scala | 31 +--- .../livy/server/interactive/Statement.scala | 66 -------- .../server/interactive/StatementState.scala | 35 ----- .../InteractiveSessionServletSpec.scala | 8 +- .../interactive/InteractiveSessionSpec.scala | 25 ++-- .../server/interactive/StatementSpec.scala | 77 ---------- 22 files changed, 382 insertions(+), 404 deletions(-) create mode 100644 rsc/src/main/java/com/cloudera/livy/rsc/ReplJobResults.java create mode 100644 rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java create mode 100644 rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java delete mode 100644 server/src/main/scala/com/cloudera/livy/server/interactive/Statement.scala delete mode 100644 server/src/main/scala/com/cloudera/livy/server/interactive/StatementState.scala delete mode 100644 server/src/test/scala/com/cloudera/livy/server/interactive/StatementSpec.scala diff --git a/integration-test/src/main/scala/com/cloudera/livy/test/framework/LivyRestClient.scala b/integration-test/src/main/scala/com/cloudera/livy/test/framework/LivyRestClient.scala index 58037a016..8a1e4e66d 100644 --- a/integration-test/src/main/scala/com/cloudera/livy/test/framework/LivyRestClient.scala +++ b/integration-test/src/main/scala/com/cloudera/livy/test/framework/LivyRestClient.scala @@ -186,6 +186,15 @@ class LivyRestClient(val httpClient: AsyncHttpClient, val livyEndpoint: String) def run(code: String): Statement = { new Statement(code) } + def runFatalStatement(code: String): Unit = { + val requestBody = Map("code" -> code) + val r = httpClient.preparePost(s"$url/statements") + .setBody(mapper.writeValueAsString(requestBody)) + .execute() + + verifySessionState(SessionState.Dead()) + } + def verifySessionIdle(): Unit = { verifySessionState(SessionState.Idle()) } diff --git a/integration-test/src/test/scala/com/cloudera/livy/test/InteractiveIT.scala b/integration-test/src/test/scala/com/cloudera/livy/test/InteractiveIT.scala index 8ef63e808..c5a056917 100644 --- a/integration-test/src/test/scala/com/cloudera/livy/test/InteractiveIT.scala +++ b/integration-test/src/test/scala/com/cloudera/livy/test/InteractiveIT.scala @@ -97,8 +97,7 @@ class InteractiveIT extends BaseIntegrationTestSuite { test("application kills session") { withNewSession(Spark()) { s => - s.run("System.exit(0)") - s.verifySessionState(SessionState.Dead()) + s.runFatalStatement("System.exit(0)") } } @@ -134,7 +133,8 @@ class InteractiveIT extends BaseIntegrationTestSuite { test("recover interactive session") { withNewSession(Spark()) { s => - s.run("1").verifyResult("res0: Int = 1") + val stmt1 = s.run("1") + stmt1.verifyResult("res0: Int = 1") // Restart Livy. cluster.stopLivy() @@ -143,7 +143,8 @@ class InteractiveIT extends BaseIntegrationTestSuite { // Verify session still exists. s.verifySessionIdle() s.run("2").verifyResult("res1: Int = 2") - // TODO, verify previous statement results still exist. + // Verify statement result is preserved. + stmt1.verifyResult("res0: Int = 1") s.stop() diff --git a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala index 45946625c..32e70901c 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala @@ -18,20 +18,15 @@ package com.cloudera.livy.repl -import scala.collection.mutable -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Await import scala.concurrent.duration.Duration import io.netty.channel.ChannelHandlerContext import org.apache.spark.SparkConf import org.apache.spark.api.java.JavaSparkContext -import org.json4s.DefaultFormats -import org.json4s.JsonAST.JValue -import org.json4s.jackson.JsonMethods._ import com.cloudera.livy.Logging -import com.cloudera.livy.rsc.{BaseProtocol, RSCConf} +import com.cloudera.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf} import com.cloudera.livy.rsc.driver._ import com.cloudera.livy.sessions._ @@ -39,11 +34,6 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) extends RSCDriver(conf, livyConf) with Logging { - // Add here to make it compatible with json4s-jackson 3.2.11 JsonMethods#render API. - private implicit def formats = DefaultFormats - - private val jobFutures = mutable.Map[String, JValue]() - private[repl] var session: Session = _ private val kind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) @@ -73,19 +63,29 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) } } - def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Unit = { - Future { - jobFutures(msg.id) = session.execute(msg.code).result - } + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Int = { + session.execute(msg.code) } - def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResult): String = { - val result = jobFutures.getOrElse(msg.id, null) - Option(result).map { r => compact(render(r)) }.orNull - } + /** + * Return statement results. Results are sorted by statement id. + */ + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults = + session.synchronized { + val stmts = if (msg.allResults) { + session.statements.values.toArray + } else { + assert(msg.from != null) + assert(msg.size != null) + val until = msg.from + msg.size + session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray + } + val state = session.state.toString + new ReplJobResults(stmts.sortBy(_.id), state) + } def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplState): String = { - return session.state.toString + session.state.toString } override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = { diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala index 5c41259a5..04a7e85e0 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala @@ -19,15 +19,18 @@ package com.cloudera.livy.repl import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicInteger -import scala.concurrent.{ExecutionContext, Future, TimeoutException} -import scala.concurrent.duration.Duration +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkContext -import org.json4s.{DefaultFormats, JValue} +import org.json4s.jackson.JsonMethods.{compact, render} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import com.cloudera.livy.{Logging, Utils} +import com.cloudera.livy.Logging +import com.cloudera.livy.rsc.driver.{Statement, StatementState} import com.cloudera.livy.sessions._ object Session { @@ -51,7 +54,9 @@ class Session(interpreter: Interpreter) private implicit val formats = DefaultFormats private var _state: SessionState = SessionState.NotStarted() - private var _history = IndexedSeq[Statement]() + private val _statements = mutable.Map[Int, Statement]() + + private val newStatementId = new AtomicInteger(0) def start(): Future[SparkContext] = { val future = Future { @@ -70,13 +75,26 @@ class Session(interpreter: Interpreter) def state: SessionState = _state - def history: IndexedSeq[Statement] = _history + def statements: mutable.Map[Int, Statement] = _statements + + def execute(code: String): Int = { + val statementId = newStatementId.getAndIncrement() + synchronized { + _statements(statementId) = new Statement(statementId, StatementState.Waiting, null) + } + Future { + synchronized { + _statements(statementId) = new Statement(statementId, StatementState.Running, null) + } + + val statement = + new Statement(statementId, StatementState.Available, executeCode(statementId, code)) - def execute(code: String): Statement = synchronized { - val executionCount = _history.length - val statement = Statement(executionCount, executeCode(executionCount, code)) - _history :+= statement - statement + synchronized { + _statements(statementId) = statement + } + } + statementId } def close(): Unit = { @@ -84,15 +102,14 @@ class Session(interpreter: Interpreter) interpreter.close() } - def clearHistory(): Unit = synchronized { - _history = IndexedSeq() + def clearStatements(): Unit = synchronized { + _statements.clear() } - private def executeCode(executionCount: Int, code: String) = { + private def executeCode(executionCount: Int, code: String): String = synchronized { _state = SessionState.Busy() - try { - + val resultInJson = try { interpreter.execute(code) match { case Interpreter.ExecuteSuccess(data) => _state = SessionState.Idle() @@ -140,7 +157,7 @@ class Session(interpreter: Interpreter) (EVALUE -> e.getMessage) ~ (TRACEBACK -> List()) } + + compact(render(resultInJson)) } } - -case class Statement(id: Int, result: JValue) diff --git a/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala index d5e9a9b14..a484a20ca 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala @@ -22,18 +22,28 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.json4s.DefaultFormats +import org.json4s._ import org.scalatest.{FlatSpec, Matchers} import org.scalatest.concurrent.Eventually._ import com.cloudera.livy.LivyBaseUnitTestSuite +import com.cloudera.livy.rsc.driver.{Statement, StatementState} import com.cloudera.livy.sessions.SessionState abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite { implicit val formats = DefaultFormats - def withSession(testCode: Session => Any): Unit = { + protected def execute(session: Session)(code: String): Statement = { + val id = session.execute(code) + eventually(timeout(30 seconds), interval(100 millis)) { + val s = session.statements(id) + s.state shouldBe StatementState.Available + s + } + } + + protected def withSession(testCode: Session => Any): Unit = { val session = new Session(createInterpreter()) try { Await.ready(session.start(), 30 seconds) @@ -44,7 +54,7 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitT } } - def createInterpreter(): Interpreter + protected def createInterpreter(): Interpreter it should "start in the starting or idle state" in { val session = new Session(createInterpreter()) diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala index ec5e31ba5..4582acdf1 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala @@ -20,6 +20,7 @@ package com.cloudera.livy.repl import org.apache.spark.SparkConf import org.json4s.Extraction +import org.json4s.jackson.JsonMethods.parse import org.scalatest._ import com.cloudera.livy.sessions._ @@ -27,10 +28,10 @@ import com.cloudera.livy.sessions._ abstract class PythonSessionSpec extends BaseSessionSpec { it should "execute `1 + 2` == 3" in withSession { session => - val statement = session.execute("1 + 2") + val statement = execute(session)("1 + 2") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -43,10 +44,11 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } it should "execute `x = 1`, then `y = 2`, then `x + y`" in withSession { session => - var statement = session.execute("x = 1") + val executeWithSession = execute(session)(_) + var statement = executeWithSession("x = 1") statement.id should equal (0) - var result = statement.result + var result = parse(statement.output) var expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -57,10 +59,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("y = 2") + statement = executeWithSession("y = 2") statement.id should equal (1) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 1, @@ -71,10 +73,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("x + y") + statement = executeWithSession("x + y") statement.id should equal (2) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 2, @@ -87,10 +89,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } it should "do table magic" in withSession { session => - val statement = session.execute("x = [[1, 'a'], [3, 'b']]\n%table x") + val statement = execute(session)("x = [[1, 'a'], [3, 'b']]\n%table x") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -108,10 +110,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } it should "capture stdout" in withSession { session => - val statement = session.execute("""print('Hello World')""") + val statement = execute(session)("""print('Hello World')""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -124,10 +126,10 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } it should "report an error if accessing an unknown variable" in withSession { session => - val statement = session.execute("""x""") + val statement = execute(session)("""x""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "error", "execution_count" -> 0, @@ -143,7 +145,7 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } it should "report an error if exception is thrown" in withSession { session => - val statement = session.execute( + val statement = execute(session)( """def func1(): | raise Exception("message") |def func2(): @@ -152,7 +154,7 @@ abstract class PythonSessionSpec extends BaseSessionSpec { """.stripMargin) statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "error", "execution_count" -> 0, @@ -184,13 +186,13 @@ class Python3SessionSpec extends PythonSessionSpec { override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) it should "check python version is 3.x" in withSession { session => - val statement = session.execute( + val statement = execute(session)( """import sys |sys.version >= '3' """.stripMargin) statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, diff --git a/repl/src/test/scala/com/cloudera/livy/repl/ReplDriverSuite.scala b/repl/src/test/scala/com/cloudera/livy/repl/ReplDriverSuite.scala index 989af8f0c..79cfeb35d 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/ReplDriverSuite.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/ReplDriverSuite.scala @@ -56,11 +56,12 @@ class ReplDriverSuite extends FunSuite with LivyBaseUnitTestSuite { assert(client.getReplState().get(10, TimeUnit.SECONDS) === "idle") - val statementId = client.submitReplCode("1 + 1") + val statementId = client.submitReplCode("1 + 1").get eventually(timeout(30 seconds), interval(100 millis)) { - val rawResult = client.getReplJobResult(statementId).get(10, TimeUnit.SECONDS) - val result = parse(rawResult) - assert((result \ Session.STATUS).extract[String] === Session.OK) + val rawResult = + client.getReplJobResults(statementId, 1).get(10, TimeUnit.SECONDS).statements(0) + val result = rawResult.output + assert((parse(result) \ Session.STATUS).extract[String] === Session.OK) } } finally { client.stop(true) diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala index 7741aa149..b42b4f3ff 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala @@ -18,13 +18,9 @@ package com.cloudera.livy.repl -import scala.concurrent.Await -import scala.concurrent.duration.Duration - import org.apache.spark.SparkConf import org.json4s.Extraction -import org.json4s.JsonAST.JValue -import org.scalatest._ +import org.json4s.jackson.JsonMethods.parse class SparkRSessionSpec extends BaseSessionSpec { @@ -36,10 +32,10 @@ class SparkRSessionSpec extends BaseSessionSpec { override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withSession { session => - val statement = session.execute("1 + 2") + val statement = execute(session)("1 + 2") statement.id should equal(0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -52,10 +48,11 @@ class SparkRSessionSpec extends BaseSessionSpec { } it should "execute `x = 1`, then `y = 2`, then `x + y`" in withSession { session => - var statement = session.execute("x = 1") + val executeWithSession = execute(session)(_) + var statement = executeWithSession("x = 1") statement.id should equal (0) - var result = statement.result + var result = parse(statement.output) var expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -66,10 +63,10 @@ class SparkRSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("y = 2") + statement = executeWithSession("y = 2") statement.id should equal (1) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 1, @@ -80,10 +77,10 @@ class SparkRSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("x + y") + statement = executeWithSession("x + y") statement.id should equal (2) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 2, @@ -96,10 +93,10 @@ class SparkRSessionSpec extends BaseSessionSpec { } it should "capture stdout from print" in withSession { session => - val statement = session.execute("""print('Hello World')""") + val statement = execute(session)("""print('Hello World')""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -112,10 +109,10 @@ class SparkRSessionSpec extends BaseSessionSpec { } it should "capture stdout from cat" in withSession { session => - val statement = session.execute("""cat(3)""") + val statement = execute(session)("""cat(3)""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -128,10 +125,10 @@ class SparkRSessionSpec extends BaseSessionSpec { } it should "report an error if accessing an unknown variable" in withSession { session => - val statement = session.execute("""x""") + val statement = execute(session)("""x""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala index ed48eb004..b36bb1a59 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala @@ -20,6 +20,7 @@ package com.cloudera.livy.repl import org.apache.spark.SparkConf import org.json4s.Extraction +import org.json4s.jackson.JsonMethods.parse import org.json4s.JsonAST.JValue class SparkSessionSpec extends BaseSessionSpec { @@ -27,10 +28,10 @@ class SparkSessionSpec extends BaseSessionSpec { override def createInterpreter(): Interpreter = new SparkInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withSession { session => - val statement = session.execute("1 + 2") + val statement = execute(session)("1 + 2") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -43,10 +44,11 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "execute `x = 1`, then `y = 2`, then `x + y`" in withSession { session => - var statement = session.execute("val x = 1") + val executeWithSession = execute(session)(_) + var statement = executeWithSession("val x = 1") statement.id should equal (0) - var result = statement.result + var result = parse(statement.output) var expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -57,10 +59,10 @@ class SparkSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("val y = 2") + statement = executeWithSession("val y = 2") statement.id should equal (1) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 1, @@ -71,10 +73,10 @@ class SparkSessionSpec extends BaseSessionSpec { result should equal (expectedResult) - statement = session.execute("x + y") + statement = executeWithSession("x + y") statement.id should equal (2) - result = statement.result + result = parse(statement.output) expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 2, @@ -87,10 +89,10 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "capture stdout" in withSession { session => - val statement = session.execute("""println("Hello World")""") + val statement = execute(session)("""println("Hello World")""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -103,10 +105,10 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "report an error if accessing an unknown variable" in withSession { session => - val statement = session.execute("""x""") + val statement = execute(session)("""x""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) def extract(key: String): String = (result \ key).extract[String] @@ -117,14 +119,14 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "report an error if exception is thrown" in withSession { session => - val statement = session.execute( + val statement = execute(session)( """def func1() { |throw new Exception() |} |func1()""".stripMargin) statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val resultMap = result.extract[Map[String, JValue]] // Manually extract the values since the line numbers in the exception could change. @@ -138,10 +140,10 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "access the spark context" in withSession { session => - val statement = session.execute("""sc""") + val statement = execute(session)("""sc""") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val resultMap = result.extract[Map[String, JValue]] // Manually extract the values since the line numbers in the exception could change. @@ -154,11 +156,11 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "execute spark commands" in withSession { session => - val statement = session.execute( + val statement = execute(session)( """sc.parallelize(0 to 1).map{i => i+1}.collect""".stripMargin) statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", @@ -172,10 +174,10 @@ class SparkSessionSpec extends BaseSessionSpec { } it should "do table magic" in withSession { session => - val statement = session.execute("val x = List((1, \"a\"), (3, \"b\"))\n%table x") + val statement = execute(session)("val x = List((1, \"a\"), (3, \"b\"))\n%table x") statement.id should equal (0) - val result = statement.result + val result = parse(statement.output) val expectedResult = Extraction.decompose(Map( "status" -> "ok", diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/BaseProtocol.java b/rsc/src/main/java/com/cloudera/livy/rsc/BaseProtocol.java index f21266ac5..3139a9d61 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/BaseProtocol.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/BaseProtocol.java @@ -18,6 +18,7 @@ package com.cloudera.livy.rsc; import com.cloudera.livy.Job; +import com.cloudera.livy.rsc.driver.Statement; import com.cloudera.livy.rsc.rpc.RpcDispatcher; public abstract class BaseProtocol extends RpcDispatcher { @@ -171,30 +172,31 @@ public RemoteDriverAddress() { public static class ReplJobRequest { public final String code; - public final String id; - public ReplJobRequest(String code, String id) { + public ReplJobRequest(String code) { this.code = code; - this.id = id; } public ReplJobRequest() { - this(null, null); + this(null); } } - public static class GetReplJobResult { + public static class GetReplJobResults { + public boolean allResults; + public Integer from, size; - public final String id; - - public GetReplJobResult(String id) { - this.id = id; + public GetReplJobResults(Integer from, Integer size) { + this.allResults = false; + this.from = from; + this.size = size; } - public GetReplJobResult() { - this(null); + public GetReplJobResults() { + this.allResults = true; + from = null; + size = null; } - } public static class GetReplState { diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java b/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java index 525c9a6c1..b638266b3 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java @@ -274,14 +274,16 @@ ContextInfo getContextInfo() { return contextInfo; } - public String submitReplCode(String code) throws Exception { - String id = UUID.randomUUID().toString(); - deferredCall(new BaseProtocol.ReplJobRequest(code, id), Void.class); - return id; + public Future submitReplCode(String code) throws Exception { + return deferredCall(new BaseProtocol.ReplJobRequest(code), Integer.class); } - public Future getReplJobResult(String id) throws Exception { - return deferredCall(new BaseProtocol.GetReplJobResult(id), String.class); + public Future getReplJobResults(Integer from, Integer size) throws Exception { + return deferredCall(new BaseProtocol.GetReplJobResults(from, size), ReplJobResults.class); + } + + public Future getReplJobResults() throws Exception { + return deferredCall(new BaseProtocol.GetReplJobResults(), ReplJobResults.class); } public Future getReplState() { diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/ReplJobResults.java b/rsc/src/main/java/com/cloudera/livy/rsc/ReplJobResults.java new file mode 100644 index 000000000..c3c6df0d0 --- /dev/null +++ b/rsc/src/main/java/com/cloudera/livy/rsc/ReplJobResults.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.cloudera.livy.rsc; + +import com.cloudera.livy.rsc.driver.Statement; + +public class ReplJobResults { + public final Statement[] statements; + public final String replState; + + public ReplJobResults(Statement[] statements, String replState) { + this.statements = statements; + this.replState = replState; + } + + public ReplJobResults() { + this(null, null); + } +} diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/RSCDriver.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/RSCDriver.java index ca1b8e2a3..bc7b77cbc 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/driver/RSCDriver.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/RSCDriver.java @@ -80,7 +80,7 @@ public class RSCDriver extends BaseProtocol { // Used to queue up requests while the SparkContext is being created. private final List> jobQueue; // Keeps track of connected clients. - private final Collection clients; + protected final Collection clients; final Map> activeJobs; private final Collection bypassJobs; diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java new file mode 100644 index 000000000..6f5235e9e --- /dev/null +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.cloudera.livy.rsc.driver; + +import com.fasterxml.jackson.annotation.JsonRawValue; + +public class Statement { + public final Integer id; + public final StatementState state; + @JsonRawValue + public final String output; + + public Statement(Integer id, StatementState state, String output) { + this.id = id; + this.state = state; + this.output = output; + } + + public Statement() { + this(null, null, null); + } +} diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java new file mode 100644 index 000000000..590a55362 --- /dev/null +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.cloudera.livy.rsc.driver; + +import com.fasterxml.jackson.annotation.JsonValue; + +public enum StatementState { + Waiting("waiting"), + Running("running"), + Available("available"); + + private final String state; + + StatementState(final String text) { + this.state = text; + } + + @JsonValue + @Override + public String toString() { + return state; + } +} diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala index 94f66fcec..c9ee67926 100644 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala @@ -24,21 +24,18 @@ import java.nio.ByteBuffer import java.nio.file.{Files, Paths} import java.util.concurrent.atomic.AtomicLong -import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.{Future, _} -import scala.util.{Failure, Random, Success, Try} +import scala.concurrent.Future +import scala.util.Random import com.fasterxml.jackson.annotation.JsonIgnoreProperties import org.apache.spark.launcher.SparkLauncher -import org.json4s._ -import org.json4s.JsonAST.JString -import org.json4s.jackson.JsonMethods._ import com.cloudera.livy._ import com.cloudera.livy.client.common.HttpMessages._ import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf} +import com.cloudera.livy.rsc.driver.Statement import com.cloudera.livy.server.recovery.SessionStore import com.cloudera.livy.sessions._ import com.cloudera.livy.sessions.Session._ @@ -351,6 +348,21 @@ class InteractiveSession( _appId = appIdHint sessionStore.save(RECOVERY_SESSION_TYPE, recoveryMetadata) + // TODO Replace this with a Rpc call from repl to server. + private val stateThread = new Thread(new Runnable { + override def run(): Unit = { + try { + while (_state.isActive) { + // State is also updated when we get statement results from repl, not just here. + setSessionStateFromReplState(client.map(_.getReplState.get())) + Thread.sleep(30000) + } + } catch { + case _: InterruptedException => + } + } + }) + private val app = mockApp.orElse { if (livyConf.isRunningOnYarn()) { // When Livy is running with YARN, SparkYarnApp can provide better YARN integration. @@ -385,6 +397,8 @@ class InteractiveSession( sessionStore.save(RECOVERY_SESSION_TYPE, recoveryMetadata) } transition(SessionState.Idle()) + stateThread.setDaemon(true) + stateThread.start() } private def errorOut(): Unit = { @@ -399,9 +413,6 @@ class InteractiveSession( }) } - private[this] var _executedStatements = 0 - private[this] var _statements = IndexedSeq[Statement]() - override def logLines(): IndexedSeq[String] = app.map(_.log()).getOrElse(sessionLog) override def recoveryMetadata: RecoveryMetadata = @@ -412,6 +423,11 @@ class InteractiveSession( override def stopSession(): Unit = { try { transition(SessionState.ShuttingDown()) + if (stateThread.isAlive) { + stateThread.interrupt() + stateThread.join() + } + sessionStore.remove(RECOVERY_SESSION_TYPE, id) client.foreach { _.stop(true) } } catch { case _: Exception => @@ -424,7 +440,25 @@ class InteractiveSession( } } - def statements: IndexedSeq[Statement] = _statements + def statements: IndexedSeq[Statement] = { + ensureActive() + val r = client.get.getReplJobResults().get() + + setSessionStateFromReplState(Option(r.replState)) + r.statements.toIndexedSeq + } + + def getStatement(stmtId: Int): Option[Statement] = { + ensureActive() + val r = client.get.getReplJobResults(stmtId, 1).get() + + setSessionStateFromReplState(Option(r.replState)) + if (r.statements.length < 1) { + None + } else { + Option(r.statements(0)) + } + } def interrupt(): Future[Unit] = { stop() @@ -432,20 +466,11 @@ class InteractiveSession( def executeStatement(content: ExecuteRequest): Statement = { ensureRunning() - _state = SessionState.Busy() + setSessionStateFromReplState(client.map(_.getReplState.get())) recordActivity() - val future = Future { - val id = client.get.submitReplCode(content.code) - waitForStatement(id) - } - - val statement = new Statement(_executedStatements, content, future) - - _executedStatements += 1 - _statements = _statements :+ statement - - statement + val id = client.get.submitReplCode(content.code).get + client.get.getReplJobResults(id, 1).get().statements(0) } def runJob(job: Array[Byte]): Long = { @@ -465,19 +490,19 @@ class InteractiveSession( } def addFile(uri: URI): Unit = { - ensureRunning() + ensureActive() recordActivity() client.get.addFile(resolveURI(uri, livyConf)).get() } def addJar(uri: URI): Unit = { - ensureRunning() + ensureActive() recordActivity() client.get.addJar(resolveURI(uri, livyConf)).get() } def jobStatus(id: Long): Any = { - ensureRunning() + ensureActive() val clientJobId = operations(id) recordActivity() // TODO: don't block indefinitely? @@ -486,54 +511,29 @@ class InteractiveSession( } def cancelJob(id: Long): Unit = { - ensureRunning() + ensureActive() recordActivity() operations.remove(id).foreach { client.get.cancel } } - @tailrec - private def waitForStatement(id: String): JValue = { - ensureRunning() - Try(client.get.getReplJobResult(id).get()) match { - case Success(null) => - Thread.sleep(1000) - waitForStatement(id) - - case Success(response) => - val result = parse(response) - // If the response errored out, it's possible it took down the interpreter. Check if - // it's still running. - result \ "status" match { - case JString("error") => - val state = client.get.getReplState().get() match { - case "error" => SessionState.Error() - case _ => SessionState.Idle() - } - transition(state) - case _ => transition(SessionState.Idle()) - } - result - - - case Failure(err) => - // If any other error occurs, it probably means the session died. Transition to - // the error state. - transition(SessionState.Error()) - throw err - } - } - - private def transition(state: SessionState) = synchronized { + private def transition(newState: SessionState) = synchronized { // When a statement returns an error, the session should transit to error state. // If the session crashed because of the error, the session should instead go to dead state. // Since these 2 transitions are triggered by different threads, there's a race condition. // Make sure we won't transit from dead to error state. - if (!_state.isInstanceOf[SessionState.Dead] || !state.isInstanceOf[SessionState.Error]) { - debug(s"$this session state change from ${_state} to $state") - _state = state + val areSameStates = _state.getClass() == newState.getClass() + val transitFromInactiveToActive = !_state.isActive && newState.isActive + if (!areSameStates && !transitFromInactiveToActive) { + debug(s"$this session state change from ${_state} to $newState") + _state = newState } } + private def ensureActive(): Unit = synchronized { + require(_state.isActive, "Session isn't active.") + require(client.isDefined, "Session is active but client hasn't been created.") + } + private def ensureRunning(): Unit = synchronized { _state match { case SessionState.Idle() | SessionState.Busy() => @@ -543,7 +543,7 @@ class InteractiveSession( } private def performOperation(job: Array[Byte], sync: Boolean): Long = { - ensureRunning() + ensureActive() recordActivity() val future = client.get.bypass(ByteBuffer.wrap(job), sync) val opId = operationCounter.incrementAndGet() @@ -551,6 +551,21 @@ class InteractiveSession( opId } + private def setSessionStateFromReplState(newStateStr: Option[String]): Unit = { + val newState = newStateStr match { + case Some("starting") => SessionState.Starting() + case Some("idle") => SessionState.Idle() + case Some("busy") => SessionState.Busy() + case Some("error") => SessionState.Error() + case Some(s) => // Should not happen. + warn(s"Unexpected repl state $s") + SessionState.Error() + case None => + SessionState.Dead() + } + transition(newState) + } + override def appIdKnown(appId: String): Unit = { _appId = Option(appId) sessionSaveLock.synchronized { diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala index 26f1dcca9..e6de19946 100644 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala @@ -19,7 +19,6 @@ package com.cloudera.livy.server.interactive import java.net.URI -import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.JavaConverters._ @@ -33,6 +32,7 @@ import org.scalatra.servlet.FileUploadSupport import com.cloudera.livy.{ExecuteRequest, JobHandle, LivyConf, Logging} import com.cloudera.livy.client.common.HttpMessages import com.cloudera.livy.client.common.HttpMessages._ +import com.cloudera.livy.rsc.driver.Statement import com.cloudera.livy.server.SessionServlet import com.cloudera.livy.server.recovery.SessionStore import com.cloudera.livy.sessions._ @@ -84,18 +84,6 @@ class InteractiveSessionServlet( session.state.toString, session.kind.toString, session.appInfo.asJavaMap, logs.asJava) } - private def statementView(statement: Statement): Any = { - val output = try { - Await.result(statement.output(), Duration(100, TimeUnit.MILLISECONDS)) - } catch { - case _: TimeoutException => null - } - Map( - "id" -> statement.id, - "state" -> statement.state.toString, - "output" -> output) - } - post("/:id/stop") { withSession { session => Await.ready(session.stop(), Duration.Inf) @@ -112,12 +100,13 @@ class InteractiveSessionServlet( get("/:id/statements") { withSession { session => + val statements = session.statements val from = params.get("from").map(_.toInt).getOrElse(0) - val size = params.get("size").map(_.toInt).getOrElse(session.statements.length) + val size = params.get("size").map(_.toInt).getOrElse(statements.length) Map( - "total_statements" -> session.statements.length, - "statements" -> session.statements.view(from, from + size).map(statementView) + "total_statements" -> statements.length, + "statements" -> statements.view(from, from + size) ) } } @@ -125,14 +114,8 @@ class InteractiveSessionServlet( val getStatement = get("/:id/statements/:statementId") { withSession { session => val statementId = params("statementId").toInt - val from = params.get("from").map(_.toInt) - val size = params.get("size").map(_.toInt) - session.statements.lift(statementId) match { - case None => NotFound("Statement not found") - case Some(statement) => - statementView(statement) - } + session.getStatement(statementId).getOrElse(NotFound("Statement not found")) } } @@ -140,7 +123,7 @@ class InteractiveSessionServlet( withSession { session => val statement = session.executeStatement(req) - Created(statementView(statement), + Created(statement, headers = Map( "Location" -> url(getStatement, "id" -> session.id.toString, diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/Statement.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/Statement.scala deleted file mode 100644 index c3e6fd36e..000000000 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/Statement.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to Cloudera, Inc. under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. Cloudera, Inc. licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server.interactive - -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} -import scala.util.{Failure, Success} - -import org.json4s.JsonAST.{JArray, JField, JObject, JString} -import org.json4s.JValue - -import com.cloudera.livy.ExecuteRequest - -class Statement(val id: Int, val request: ExecuteRequest, _output: Future[JValue]) { - protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global - - private[this] var _state: StatementState = StatementState.Running() - - def state: StatementState = _state - - def output(from: Option[Int] = None, size: Option[Int] = None): Future[JValue] = { - _output.map { case output => - if (from.isEmpty && size.isEmpty) { - output - } else { - val from_ = from.getOrElse(0) - val size_ = size.getOrElse(100) - val until = from_ + size_ - - output \ "data" match { - case JObject(JField("text/plain", JString(text)) :: Nil) => - val lines = text.split('\n').slice(from_, until) - output.replace( - "data" :: "text/plain" :: Nil, - JString(lines.mkString("\n"))) - case JObject(JField("application/json", JArray(items)) :: Nil) => - output.replace( - "data" :: "application/json" :: Nil, - JArray(items.slice(from_, until))) - case _ => - output - } - } - } - } - - _output.onComplete { - case Success(_) => _state = StatementState.Available() - case Failure(_) => _state = StatementState.Error() - } -} diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/StatementState.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/StatementState.scala deleted file mode 100644 index 3f7b2ec80..000000000 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/StatementState.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to Cloudera, Inc. under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. Cloudera, Inc. licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server.interactive - -sealed trait StatementState - -object StatementState { - case class Running() extends StatementState { - override def toString: String = "running" - } - - case class Available() extends StatementState { - override def toString: String = "available" - } - - case class Error() extends StatementState { - override def toString: String = "error" - } -} diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala index 8c3175610..0a3fb99fc 100644 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala @@ -35,6 +35,7 @@ import org.scalatest.mock.MockitoSugar.mock import com.cloudera.livy.{ExecuteRequest, LivyConf} import com.cloudera.livy.client.common.HttpMessages.SessionInfo +import com.cloudera.livy.rsc.driver.{Statement, StatementState} import com.cloudera.livy.server.recovery.SessionStore import com.cloudera.livy.sessions._ import com.cloudera.livy.utils.AppInfo @@ -69,11 +70,10 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { new Answer[Statement]() { override def answer(args: InvocationOnMock): Statement = { val id = statementCounter.getAndIncrement - val executeRequest = args.getArguments()(0).asInstanceOf[ExecuteRequest] val statement = new Statement( id, - executeRequest, - Future.successful(JObject(JField("value", JInt(42))))) + StatementState.Available, + "1") statements :+= statement statement @@ -114,7 +114,7 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { jpost[Map[String, Any]]("/0/statements", ExecuteRequest("foo")) { data => data("id") should be (0) - data("output") should be (Map("value" -> 42)) + data("output") shouldBe 1 } jget[Map[String, Any]]("/0/statements") { data => diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala index 077b33024..c5fd5925d 100644 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala @@ -25,7 +25,8 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.launcher.SparkLauncher -import org.json4s.{DefaultFormats, Extraction} +import org.json4s.{DefaultFormats, Extraction, JValue} +import org.json4s.jackson.JsonMethods.parse import org.mockito.{Matchers => MockitoMatchers} import org.mockito.Matchers._ import org.mockito.Mockito.{atLeastOnce, verify, when} @@ -35,6 +36,7 @@ import org.scalatest.mock.MockitoSugar.mock import com.cloudera.livy.{ExecuteRequest, JobHandle, LivyBaseUnitTestSuite, LivyConf} import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf} +import com.cloudera.livy.rsc.driver.StatementState import com.cloudera.livy.server.recovery.SessionStore import com.cloudera.livy.sessions.{PySpark, SessionState, Spark} import com.cloudera.livy.utils.{AppInfo, SparkApp} @@ -70,6 +72,15 @@ class InteractiveSessionSpec extends FunSpec InteractiveSession.create(0, null, None, livyConf, req, sessionStore, mockApp) } + private def executeStatement(code: String): JValue = { + val id = session.executeStatement(ExecuteRequest(code)).id + eventually(timeout(30 seconds), interval(100 millis)) { + val s = session.getStatement(id).get + s.state shouldBe StatementState.Available + parse(s.output) + } + } + override def afterAll(): Unit = { if (session != null) { Await.ready(session.stop(), 30 seconds) @@ -112,9 +123,7 @@ class InteractiveSessionSpec extends FunSpec } withSession("should execute `1 + 2` == 3") { session => - val stmt = session.executeStatement(ExecuteRequest("1 + 2")) - val result = Await.result(stmt.output(), 30 seconds) - + val result = executeStatement("1 + 2") val expectedResult = Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, @@ -127,8 +136,7 @@ class InteractiveSessionSpec extends FunSpec } withSession("should report an error if accessing an unknown variable") { session => - val stmt = session.executeStatement(ExecuteRequest("x")) - val result = Await.result(stmt.output(), 30 seconds) + val result = executeStatement("x") val expectedResult = Extraction.decompose(Map( "status" -> "error", "execution_count" -> 1, @@ -145,12 +153,11 @@ class InteractiveSessionSpec extends FunSpec } withSession("should error out the session if the interpreter dies") { session => - val stmt = session.executeStatement(ExecuteRequest("import os; os._exit(1)")) - val result = Await.result(stmt.output(), 30 seconds) + executeStatement("import os; os._exit(666)") (session.state match { case SessionState.Error(_) => true case _ => false - }) should equal (true) + }) should equal(true) } } diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/StatementSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/StatementSpec.scala deleted file mode 100644 index c98b2bab1..000000000 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/StatementSpec.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to Cloudera, Inc. under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. Cloudera, Inc. licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server.interactive - -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.Duration - -import org.json4s.{DefaultFormats, Extraction} -import org.json4s.JsonAST.JString -import org.scalatest.{FunSpec, Matchers} - -import com.cloudera.livy.{ExecuteRequest, LivyBaseUnitTestSuite} - -class StatementSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite { - - implicit val formats = DefaultFormats - - describe("A statement") { - it("should support paging through text/plain data") { - val lines = List("1", "2", "3", "4", "5") - val rep = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "text/plain" -> lines.mkString("\n") - ) - )) - val stmt = new Statement(0, ExecuteRequest(""), Future.successful(rep)) - var output = Await.result(stmt.output(), Duration.Inf) - output \ "data" \ "text/plain" should equal (JString(lines.mkString("\n"))) - - output = Await.result(stmt.output(Some(2)), Duration.Inf) - output \ "data" \ "text/plain" should equal ( - JString(lines.slice(2, lines.length).mkString("\n"))) - - output = Await.result(stmt.output(Some(2), Some(1)), Duration.Inf) - output \ "data" \ "text/plain" should equal (JString(lines.slice(2, 3).mkString("\n"))) - } - - it("should support paging through application/json arrays") { - val lines = List("1", "2", "3", "4") - val rep = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "application/json" -> List(1, 2, 3, 4) - ) - )) - val stmt = new Statement(0, ExecuteRequest(""), Future.successful(rep)) - var output = Await.result(stmt.output(), Duration.Inf) - (output \ "data" \ "application/json").extract[List[Int]] should equal (List(1, 2, 3, 4)) - - output = Await.result(stmt.output(Some(2)), Duration.Inf) - (output \ "data" \ "application/json").extract[List[Int]] should equal (List(3, 4)) - - output = Await.result(stmt.output(Some(2), Some(1)), Duration.Inf) - (output \ "data" \ "application/json").extract[List[Int]] should equal (List(3)) - } - } - -}