From c1aafeb6cb87f2bd7f4cb7cf538822b59fb34a9c Mon Sep 17 00:00:00 2001
From: jerryshao
Date: Thu, 31 Aug 2017 21:07:42 +0800
Subject: [PATCH] [LIVY-194][REPL] Add shared language support for Livy interactive session

This is ongoing work; putting it here to leverage Travis for testing.

In current Livy, when we create a new session we need to specify the session kind, which represents the interpreter kind bound to that session. The user can choose `spark`, `pyspark`, or `sparkr`, and Livy internally creates the corresponding interpreter for the session. Also, in Livy each session represents a complete Spark application / SparkContext. This brings the limitation that a user can choose only one language for each Spark application. Furthermore, data generated by the `pyspark` interpreter cannot be shared with the `spark` interpreter. So, to improve the usability of Livy, we propose adding support for multiple interpreters in one session.

Author: jerryshao

Closes #28 from jerryshao/LIVY-194.

Change-Id: I743871c80ccb5c16101236e052d5f31662382667
---
 .../livy/client/common/HttpMessages.java | 6 +-
 .../livy/client/http/JobHandleImpl.java | 2 +-
 .../livy/client/http/HttpClientSpec.scala | 2 +-
 .../src/main/scala/org/apache/livy/msgs.scala | 2 +-
 .../scala/org/apache/livy/sessions/Kind.scala | 10 +-
 integration-test/pom.xml | 2 +-
 python-api/src/main/python/livy/client.py | 4 +-
 .../apache/livy/repl/SparkInterpreter.scala | 28 ++--
 .../apache/livy/repl/SparkInterpreter.scala | 26 ++-
 repl/src/main/resources/fake_shell.py | 34 +++-
 .../livy/repl/AbstractSparkInterpreter.scala | 52 ++++++
 .../apache/livy/repl/BypassPySparkJob.scala | 7 +-
 .../org/apache/livy/repl/Interpreter.scala | 5 +-
 .../apache/livy/repl/ProcessInterpreter.scala | 10 +-
 .../apache/livy/repl/PythonInterpreter.scala | 25 ++-
 .../org/apache/livy/repl/ReplDriver.scala | 48 +++---
 .../scala/org/apache/livy/repl/Session.scala | 120 +++++++++-----
 .../livy/repl/SparkContextInitializer.scala | 126 ---------------
 .../apache/livy/repl/SparkRInterpreter.scala | 66 +++++---
 .../apache/livy/repl/BaseSessionSpec.scala | 20 +--
 .../livy/repl/PythonInterpreterSpec.scala | 24 ++-
 .../apache/livy/repl/PythonSessionSpec.scala | 21 ++-
 .../apache/livy/repl/ReplDriverSuite.scala | 2 +-
 .../org/apache/livy/repl/SessionSpec.scala | 61 +++----
 .../apache/livy/repl/SharedSessionSpec.scala | 126 +++++++++++++++
 .../livy/repl/SparkRInterpreterSpec.scala | 10 +-
 .../apache/livy/repl/SparkRSessionSpec.scala | 7 +-
 .../apache/livy/repl/SparkSessionSpec.scala | 7 +-
 .../org/apache/livy/rsc/BaseProtocol.java | 12 +-
 .../org/apache/livy/rsc/ContextLauncher.java | 8 +-
 .../java/org/apache/livy/rsc/RSCClient.java | 13 +-
 .../livy/rsc/driver/JobContextImpl.java | 65 ++------
 .../org/apache/livy/rsc/driver/RSCDriver.java | 17 +-
 .../apache/livy/rsc/driver/SparkEntries.java | 149 ++++++++++++++++++
 .../org/apache/livy/rsc/TestSparkClient.java | 15 +-
 .../CreateInteractiveRequest.scala | 7 +-
 .../interactive/InteractiveSession.scala | 37 ++---
 .../InteractiveSessionServlet.scala | 9 +-
 server/src/test/resources/log4j.properties | 2 +-
 .../InteractiveSessionServletSpec.scala | 2 +-
 .../interactive/InteractiveSessionSpec.scala | 43 ++---
 .../livy/server/interactive/JobApiSpec.scala | 6 +-
 42 files changed, 738 insertions(+), 500 deletions(-)
 delete mode 100644 repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala
 create mode 100644 repl/src/test/scala/org/apache/livy/repl/SharedSessionSpec.scala
 create mode 100644 
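To make the proposed usage concrete, here is a rough sketch of how a shared session might be driven over Livy's REST API once this patch is in place. The `"kind": "shared"` session kind and the per-statement `kind` field come from the changes in this diff (`Kind.scala` and `ExecuteRequest`); the server address, polling loop, and helper names are assumptions for illustration, not part of the patch.

```python
import time
import requests

LIVY_URL = "http://localhost:8998"  # assumed Livy server address

# One session of kind "shared" can run Scala, Python, and R statements.
session = requests.post(LIVY_URL + "/sessions", json={"kind": "shared"}).json()
session_id = session["id"]

def wait_until_idle():
    # Poll until the session has started and is ready to accept statements.
    while requests.get("%s/sessions/%d" % (LIVY_URL, session_id)).json()["state"] != "idle":
        time.sleep(1)

def run(code, kind):
    # Submit a statement with an explicit language ("spark", "pyspark", or "sparkr")
    # and poll until its result is available.
    st = requests.post("%s/sessions/%d/statements" % (LIVY_URL, session_id),
                       json={"code": code, "kind": kind}).json()
    while st["state"] not in ("available", "error"):
        time.sleep(1)
        st = requests.get("%s/sessions/%d/statements/%d"
                          % (LIVY_URL, session_id, st["id"])).json()
    return st["output"]

wait_until_idle()
# A view registered by a Scala statement is visible to a later PySpark statement,
# because both interpreters now share the same SparkContext / SparkSession.
run('spark.range(10).createOrReplaceTempView("nums")', "spark")
print(run('spark.sql("select sum(id) from nums").collect()', "pyspark"))
```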
rsc/src/main/java/org/apache/livy/rsc/driver/SparkEntries.java diff --git a/client-common/src/main/java/org/apache/livy/client/common/HttpMessages.java b/client-common/src/main/java/org/apache/livy/client/common/HttpMessages.java index 99ce900a1..b1e253fb0 100644 --- a/client-common/src/main/java/org/apache/livy/client/common/HttpMessages.java +++ b/client-common/src/main/java/org/apache/livy/client/common/HttpMessages.java @@ -82,13 +82,15 @@ private SessionInfo() { public static class SerializedJob implements ClientMessage { public final byte[] job; + public final String jobType; - public SerializedJob(byte[] job) { + public SerializedJob(byte[] job, String jobType) { this.job = job; + this.jobType = jobType; } private SerializedJob() { - this(null); + this(null, null); } } diff --git a/client-http/src/main/java/org/apache/livy/client/http/JobHandleImpl.java b/client-http/src/main/java/org/apache/livy/client/http/JobHandleImpl.java index f0ffb59bd..d39dfe994 100644 --- a/client-http/src/main/java/org/apache/livy/client/http/JobHandleImpl.java +++ b/client-http/src/main/java/org/apache/livy/client/http/JobHandleImpl.java @@ -138,7 +138,7 @@ void start(final String command, final ByteBuffer serializedJob) { @Override public void run() { try { - ClientMessage msg = new SerializedJob(BufferUtils.toByteArray(serializedJob)); + ClientMessage msg = new SerializedJob(BufferUtils.toByteArray(serializedJob), "spark"); JobStatus status = conn.post(msg, JobStatus.class, "/%d/%s", sessionId, command); if (isCancelPending) { diff --git a/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala b/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala index 4984eaec2..5ea6f8d40 100644 --- a/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala +++ b/client-http/src/test/scala/org/apache/livy/client/http/HttpClientSpec.scala @@ -217,7 +217,7 @@ class HttpClientSpec extends FunSpecLike with BeforeAndAfterAll with LivyBaseUni private def runJob(sync: Boolean, genStatusFn: Long => Seq[JobStatus]): (Long, JFuture[Int]) = { val jobId = java.lang.Long.valueOf(ID_GENERATOR.incrementAndGet()) - when(session.submitJob(any(classOf[Array[Byte]]))).thenReturn(jobId) + when(session.submitJob(any(classOf[Array[Byte]]), anyString())).thenReturn(jobId) val statuses = genStatusFn(jobId) val first = statuses.head diff --git a/core/src/main/scala/org/apache/livy/msgs.scala b/core/src/main/scala/org/apache/livy/msgs.scala index 0dd0a2609..048aa7f40 100644 --- a/core/src/main/scala/org/apache/livy/msgs.scala +++ b/core/src/main/scala/org/apache/livy/msgs.scala @@ -28,7 +28,7 @@ case class Msg[T <: Content](msg_type: MsgType, content: T) sealed trait Content -case class ExecuteRequest(code: String) extends Content { +case class ExecuteRequest(code: String, kind: Option[String]) extends Content { val msg_type = MsgType.execute_request } diff --git a/core/src/main/scala/org/apache/livy/sessions/Kind.scala b/core/src/main/scala/org/apache/livy/sessions/Kind.scala index bfb166ff4..5e7fa00b8 100644 --- a/core/src/main/scala/org/apache/livy/sessions/Kind.scala +++ b/core/src/main/scala/org/apache/livy/sessions/Kind.scala @@ -30,21 +30,21 @@ case class PySpark() extends Kind { override def toString: String = "pyspark" } -case class PySpark3() extends Kind { - override def toString: String = "pyspark3" -} - case class SparkR() extends Kind { override def toString: String = "sparkr" } +case class Shared() extends Kind { + override def toString: String = "shared" +} + object 
Kind { def apply(kind: String): Kind = kind match { case "spark" | "scala" => Spark() case "pyspark" | "python" => PySpark() - case "pyspark3" | "python3" => PySpark3() case "sparkr" | "r" => SparkR() + case "shared" => Shared() case other => throw new IllegalArgumentException(s"Invalid kind: $other") } diff --git a/integration-test/pom.xml b/integration-test/pom.xml index e15c47099..efb8b386b 100644 --- a/integration-test/pom.xml +++ b/integration-test/pom.xml @@ -286,7 +286,7 @@ ${execution.root} - false + true true diff --git a/python-api/src/main/python/livy/client.py b/python-api/src/main/python/livy/client.py index 4a5a98578..99f88f89f 100644 --- a/python-api/src/main/python/livy/client.py +++ b/python-api/src/main/python/livy/client.py @@ -70,6 +70,7 @@ def __init__(self, url, load_defaults=True, conf_dict=None): uri = urlparse(url) self._config = ConfigParser() self._load_config(load_defaults, conf_dict) + self._job_type = 'pyspark' match = re.match(r'(.*)/sessions/([0-9]+)', uri.path) if match: base = ParseResult(scheme=uri.scheme, netloc=uri.netloc, @@ -395,7 +396,8 @@ def _reconnect_to_existing_session(self): def _send_job(self, command, job): pickled_job = cloudpickle.dumps(job) base64_pickled_job = base64.b64encode(pickled_job).decode('utf-8') - base64_pickled_job_data = {'job': base64_pickled_job} + base64_pickled_job_data = \ + {'job': base64_pickled_job, 'jobType': self._job_type} handle = JobHandle(self._conn, self._session_id, self._executor) handle._start(command, base64_pickled_job_data) diff --git a/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala b/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala index f5b5b32f8..39009b2df 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala @@ -26,20 +26,20 @@ import scala.tools.nsc.interpreter.JPrintWriter import scala.tools.nsc.interpreter.Results.Result import scala.util.{Failure, Success, Try} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.repl.SparkIMain +import org.apache.livy.rsc.driver.SparkEntries + /** * This represents a Spark interpreter. It is not thread safe. 
*/ -class SparkInterpreter(conf: SparkConf) - extends AbstractSparkInterpreter with SparkContextInitializer { +class SparkInterpreter(protected override val conf: SparkConf) extends AbstractSparkInterpreter { private var sparkIMain: SparkIMain = _ - protected var sparkContext: SparkContext = _ - override def start(): SparkContext = { - require(sparkIMain == null && sparkContext == null) + override def start(): Unit = { + require(sparkIMain == null) val settings = new Settings() settings.embeddedDefaults(Thread.currentThread().getContextClassLoader()) @@ -103,23 +103,21 @@ class SparkInterpreter(conf: SparkConf) } } - createSparkContext(conf) + postStart() } - - sparkContext } - protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit = { + override protected def bind(name: String, + tpe: String, + value: Object, + modifier: List[String]): Unit = { sparkIMain.beQuietDuring { sparkIMain.bind(name, tpe, value, modifier) } } override def close(): Unit = synchronized { - if (sparkContext != null) { - sparkContext.stop() - sparkContext = null - } + super.close() if (sparkIMain != null) { sparkIMain.close() @@ -128,7 +126,7 @@ class SparkInterpreter(conf: SparkConf) } override protected def isStarted(): Boolean = { - sparkContext != null && sparkIMain != null + sparkIMain != null } override protected def interpret(code: String): Result = { diff --git a/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala index 94cd241f8..9d19ef30f 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/livy/repl/SparkInterpreter.scala @@ -26,20 +26,20 @@ import scala.tools.nsc.interpreter.JPrintWriter import scala.tools.nsc.interpreter.Results.Result import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.repl.SparkILoop +import org.apache.livy.rsc.driver.SparkEntries + /** * Scala 2.11 version of SparkInterpreter */ -class SparkInterpreter(conf: SparkConf) - extends AbstractSparkInterpreter with SparkContextInitializer { +class SparkInterpreter(protected override val conf: SparkConf) extends AbstractSparkInterpreter { - protected var sparkContext: SparkContext = _ private var sparkILoop: SparkILoop = _ private var sparkHttpServer: Object = _ - override def start(): SparkContext = { + override def start(): Unit = { require(sparkILoop == null) val rootDir = conf.get("spark.repl.classdir", System.getProperty("java.io.tmpdir")) @@ -89,17 +89,12 @@ class SparkInterpreter(conf: SparkConf) } } - createSparkContext(conf) + postStart() } - - sparkContext } override def close(): Unit = synchronized { - if (sparkContext != null) { - sparkContext.stop() - sparkContext = null - } + super.close() if (sparkILoop != null) { sparkILoop.closeInterpreter() @@ -115,7 +110,7 @@ class SparkInterpreter(conf: SparkConf) } override protected def isStarted(): Boolean = { - sparkContext != null && sparkILoop != null + sparkILoop != null } override protected def interpret(code: String): Result = { @@ -127,7 +122,10 @@ class SparkInterpreter(conf: SparkConf) Option(sparkILoop.lastRequest.lineRep.call("$result")) } - protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit = { + override protected def bind(name: String, + tpe: String, + value: Object, + modifier: List[String]): Unit = { sparkILoop.beQuietDuring 
{ sparkILoop.bind(name, tpe, value, modifier) } diff --git a/repl/src/main/resources/fake_shell.py b/repl/src/main/resources/fake_shell.py index 534e0df44..523f7a059 100644 --- a/repl/src/main/resources/fake_shell.py +++ b/repl/src/main/resources/fake_shell.py @@ -531,14 +531,42 @@ def main(): listening_port = 0 if os.environ.get("LIVY_TEST") != "true": #Load spark into the context - exec('from pyspark.shell import sc', global_dict) - exec('from pyspark.shell import sqlContext', global_dict) exec('from pyspark.sql import HiveContext', global_dict) exec('from pyspark.streaming import StreamingContext', global_dict) exec('import pyspark.cloudpickle as cloudpickle', global_dict) + from py4j.java_gateway import java_import, JavaGateway, GatewayClient + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + from pyspark.sql import SQLContext, HiveContext, Row + # Connect to the gateway + gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + + # Import the classes used by PySpark + java_import(gateway.jvm, "org.apache.spark.SparkConf") + java_import(gateway.jvm, "org.apache.spark.api.java.*") + java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") + java_import(gateway.jvm, "scala.Tuple2") + + jsc = gateway.entry_point.sc() + jconf = gateway.entry_point.sc().getConf() + jsqlc = gateway.entry_point.hivectx() if gateway.entry_point.hivectx() is not None \ + else gateway.entry_point.sqlctx() + + conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf) + sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf) + global_dict['sc'] = sc + sqlc = SQLContext(sc, jsqlc) + global_dict['sqlContext'] = sqlc + if spark_major_version >= "2": - exec('from pyspark.shell import spark', global_dict) + from pyspark.sql import SparkSession + spark_session = SparkSession(sc, gateway.entry_point.sparkSession()) + global_dict['spark'] = spark_session else: # LIVY-294, need to check whether HiveContext can work properly, # fallback to SQLContext if HiveContext can not be initialized successfully. 
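The fake_shell.py hunk above is the Python half of the sharing mechanism: instead of importing `pyspark.shell` (which would build its own SparkContext), the Python process now attaches to the context that the driver's Scala side already created and exposes through a Py4J gateway. Stripped of the surrounding plumbing, the attachment pattern looks roughly like the sketch below; the gateway entry point is the shared `SparkEntries` object registered by `PythonInterpreter`, and everything else mirrors the hunk above rather than being a definitive implementation.

```python
import os

from py4j.java_gateway import JavaGateway, GatewayClient
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext

# Port of the Py4J GatewayServer started on the JVM side (exported by the driver).
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

# The entry point is the shared SparkEntries object, so the JVM-side
# JavaSparkContext and SQLContext/HiveContext are fetched, not created.
jsc = gateway.entry_point.sc()
jsqlc = gateway.entry_point.hivectx() \
    if gateway.entry_point.hivectx() is not None else gateway.entry_point.sqlctx()

# Wrap the existing JVM objects with the regular PySpark classes.
conf = SparkConf(_jvm=gateway.jvm, _jconf=jsc.getConf())
sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
sqlContext = SQLContext(sc, jsqlc)
```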
diff --git a/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala index cf2cf3560..b058d0015 100644 --- a/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala +++ b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala @@ -21,6 +21,7 @@ import java.io.ByteArrayOutputStream import scala.tools.nsc.interpreter.Results +import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.json4s.DefaultFormats import org.json4s.Extraction @@ -28,6 +29,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.apache.livy.Logging +import org.apache.livy.rsc.driver.SparkEntries object AbstractSparkInterpreter { private[repl] val KEEP_NEWLINE_REGEX = """(?<=\n)""".r @@ -41,6 +43,10 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging { protected val outputStream = new ByteArrayOutputStream() + protected var entries: SparkEntries = _ + + def sparkEntries(): SparkEntries = entries + final def kind: String = "spark" protected def isStarted(): Boolean @@ -49,6 +55,52 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging { protected def valueOfTerm(name: String): Option[Any] + protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit + + protected def conf: SparkConf + + protected def postStart(): Unit = { + entries = new SparkEntries(conf) + + if (isSparkSessionPresent()) { + bind("spark", + sparkEntries.sparkSession().getClass.getCanonicalName, + sparkEntries.sparkSession(), + List("""@transient""")) + bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient""")) + + execute("import org.apache.spark.SparkContext._") + execute("import spark.implicits._") + execute("import spark.sql") + execute("import org.apache.spark.sql.functions._") + } else { + bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient""")) + val sqlContext = Option(sparkEntries.hivectx()).getOrElse(sparkEntries.sqlctx()) + bind("sqlContext", sqlContext.getClass.getCanonicalName, sqlContext, List("""@transient""")) + + execute("import org.apache.spark.SparkContext._") + execute("import sqlContext.implicits._") + execute("import sqlContext.sql") + execute("import org.apache.spark.sql.functions._") + } + } + + override def close(): Unit = { + if (entries != null) { + entries.stop() + entries = null + } + } + + private def isSparkSessionPresent(): Boolean = { + try { + Class.forName("org.apache.spark.sql.SparkSession") + true + } catch { + case _: ClassNotFoundException | _: NoClassDefFoundError => false + } + } + override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = restoreContextClassLoader { require(isStarted()) diff --git a/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala b/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala index 6a9bdd192..7e79c412a 100644 --- a/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala +++ b/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala @@ -19,16 +19,13 @@ package org.apache.livy.repl import java.nio.charset.StandardCharsets import org.apache.livy.{Job, JobContext} +import org.apache.livy.sessions._ class BypassPySparkJob( serializedJob: Array[Byte], - replDriver: ReplDriver) extends Job[Array[Byte]] { + pi: PythonInterpreter) extends Job[Array[Byte]] { override def call(jc: JobContext): Array[Byte] = { - val interpreter = 
replDriver.interpreter - require(interpreter != null && interpreter.isInstanceOf[PythonInterpreter]) - val pi = interpreter.asInstanceOf[PythonInterpreter] - val resultByteArray = pi.pysparkJobProcessor.processBypassJob(serializedJob) val resultString = new String(resultByteArray, StandardCharsets.UTF_8) if (resultString.startsWith("Client job error:")) { diff --git a/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala index 058b20bd3..860513ede 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala @@ -17,7 +17,6 @@ package org.apache.livy.repl -import org.apache.spark.SparkContext import org.json4s.JObject object Interpreter { @@ -38,10 +37,8 @@ trait Interpreter { /** * Start the Interpreter. - * - * @return A SparkContext */ - def start(): SparkContext + def start(): Unit /** * Execute the code and return the result, it may diff --git a/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala index 7995ba07b..cb3e0ea68 100644 --- a/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala +++ b/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala @@ -23,11 +23,9 @@ import java.util.concurrent.locks.ReentrantLock import scala.concurrent.Promise import scala.io.Source -import org.apache.spark.SparkContext import org.json4s.JValue import org.apache.livy.{Logging, Utils} -import org.apache.livy.client.common.ClientConf private sealed trait Request private case class ExecuteRequest(code: String, promise: Promise[JValue]) extends Request @@ -45,14 +43,8 @@ abstract class ProcessInterpreter(process: Process) protected[this] val stdin = new PrintWriter(process.getOutputStream) protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1) - override def start(): SparkContext = { + override def start(): Unit = { waitUntilReady() - - if (ClientConf.TEST_MODE) { - null.asInstanceOf[SparkContext] - } else { - SparkContext.getOrCreate() - } } override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = { diff --git a/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala index bfd5d76ff..f4e16f85f 100644 --- a/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala +++ b/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala @@ -37,21 +37,18 @@ import py4j.reflection.PythonProxyHandler import org.apache.livy.Logging import org.apache.livy.client.common.ClientConf -import org.apache.livy.rsc.BaseProtocol -import org.apache.livy.rsc.driver.BypassJobWrapper +import org.apache.livy.rsc.driver.SparkEntries import org.apache.livy.sessions._ // scalastyle:off println object PythonInterpreter extends Logging { - def apply(conf: SparkConf, kind: Kind): Interpreter = { - val pythonExec = kind match { - case PySpark() => sys.env.getOrElse("PYSPARK_PYTHON", "python") - case PySpark3() => sys.env.getOrElse("PYSPARK3_PYTHON", "python3") - case _ => throw new IllegalArgumentException(s"Unknown kind: $kind") - } + def apply(conf: SparkConf, sparkEntries: SparkEntries): Interpreter = { + val pythonExec = sys.env.get("PYSPARK_PYTHON") + .orElse(sys.props.get("pyspark.python")) // This java property is only used for internal UT. 
+ .getOrElse("python") - val gatewayServer = new GatewayServer(null, 0) + val gatewayServer = new GatewayServer(sparkEntries, 0) gatewayServer.start() val builder = new ProcessBuilder(Seq(pythonExec, createFakeShell().toString).asJava) @@ -71,7 +68,7 @@ object PythonInterpreter extends Logging { env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1")) builder.redirectError(Redirect.PIPE) val process = builder.start() - new PythonInterpreter(process, gatewayServer, kind.toString) + new PythonInterpreter(process, gatewayServer) } private def findPySparkArchives(): Seq[String] = { @@ -188,14 +185,12 @@ object PythonInterpreter extends Logging { private class PythonInterpreter( process: Process, - gatewayServer: GatewayServer, - pyKind: String) + gatewayServer: GatewayServer) extends ProcessInterpreter(process) - with Logging -{ + with Logging { implicit val formats = DefaultFormats - override def kind: String = pyKind + override def kind: String = "pyspark" private[repl] val pysparkJobProcessor = PythonInterpreter.initiatePy4jCallbackGateway(gatewayServer) diff --git a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala index 75966be1c..7f359829d 100644 --- a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala @@ -22,9 +22,9 @@ import scala.concurrent.duration.Duration import io.netty.channel.ChannelHandlerContext import org.apache.spark.SparkConf -import org.apache.spark.api.java.JavaSparkContext import org.apache.livy.Logging +import org.apache.livy.client.common.ClientConf import org.apache.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf} import org.apache.livy.rsc.BaseProtocol.ReplState import org.apache.livy.rsc.driver._ @@ -37,23 +37,11 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) private[repl] var session: Session = _ - private val kind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) - - private[repl] var interpreter: Interpreter = _ - - override protected def initializeContext(): JavaSparkContext = { - interpreter = kind match { - case PySpark() => PythonInterpreter(conf, PySpark()) - case PySpark3() => - PythonInterpreter(conf, PySpark3()) - case Spark() => new SparkInterpreter(conf) - case SparkR() => SparkRInterpreter(conf) - } - session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) }) - - Option(Await.result(session.start(), Duration.Inf)) - .map(new JavaSparkContext(_)) - .orNull + override protected def initializeSparkEntries(): SparkEntries = { + session = new Session(livyConf = livyConf, + sparkConf = conf, + stateChangedCallback = { s => broadcast(new ReplState(s.toString)) }) + Await.result(session.start(), Duration.Inf) } override protected def shutdownContext(): Unit = { @@ -67,7 +55,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) } def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Int = { - session.execute(msg.code) + session.execute(msg.code, msg.codeType) } def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.CancelReplJobRequest): Unit = { @@ -100,27 +88,27 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) } override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = { - kind match { - case PySpark() | PySpark3() => new BypassJobWrapper(this, msg.id, - new BypassPySparkJob(msg.serializedJob, this)) + Kind(msg.jobType) match { + case PySpark() => + new BypassJobWrapper(this, msg.id, + 
new BypassPySparkJob(msg.serializedJob, + session.interpreter(PySpark()).asInstanceOf[PythonInterpreter])) case _ => super.createWrapper(msg) } } override protected def addFile(path: String): Unit = { - require(interpreter != null) - interpreter match { - case pi: PythonInterpreter => pi.addFile(path) - case _ => super.addFile(path) + if (!ClientConf.TEST_MODE) { + session.interpreter(PySpark()).asInstanceOf[PythonInterpreter].addFile(path) } + super.addFile(path) } override protected def addJarOrPyFile(path: String): Unit = { - require(interpreter != null) - interpreter match { - case pi: PythonInterpreter => pi.addPyFile(this, conf, path) - case _ => super.addJarOrPyFile(path) + if (!ClientConf.TEST_MODE) { + session.interpreter(PySpark()).asInstanceOf[PythonInterpreter].addPyFile(this, conf, path) } + super.addJarOrPyFile(path) } override protected def onClientAuthenticated(client: Rpc): Unit = { diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 40176ea0c..f2451932d 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -23,17 +23,18 @@ import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.json4s.jackson.JsonMethods.{compact, render} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.apache.livy.Logging import org.apache.livy.rsc.RSCConf -import org.apache.livy.rsc.driver.{Statement, StatementState} +import org.apache.livy.rsc.driver.{SparkEntries, Statement, StatementState} import org.apache.livy.sessions._ object Session { @@ -49,7 +50,8 @@ object Session { class Session( livyConf: RSCConf, - interpreter: Interpreter, + sparkConf: SparkConf, + mockSparkInterpreter: Option[SparkInterpreter] = None, stateChangedCallback: SessionState => Unit = { _ => }) extends Logging { import Session._ @@ -62,8 +64,6 @@ class Session( private implicit val formats = DefaultFormats - @volatile private[repl] var _sc: Option[SparkContext] = None - private var _state: SessionState = SessionState.NotStarted() // Number of statements kept in driver's memory @@ -77,40 +77,88 @@ class Session( private val newStatementId = new AtomicInteger(0) + private val defaultInterpKind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) + + private val interpGroup = new mutable.HashMap[Kind, Interpreter]() + + @volatile private var entries: SparkEntries = _ + stateChangedCallback(_state) - def start(): Future[SparkContext] = { + private def sc: SparkContext = { + require(entries != null) + entries.sc().sc + } + + private[repl] def interpreter(kind: Kind): Interpreter = interpGroup.synchronized { + if (interpGroup.contains(kind)) { + interpGroup(kind) + } else { + require(entries != null, + "SparkEntries should not be null when lazily initialize other interpreters.") + + val interp = kind match { + case Spark() => + // This should never be touched here. 
+ throw new IllegalStateException("SparkInterpreter should not be lazily created.") + case PySpark() => PythonInterpreter(sparkConf, entries) + case SparkR() => SparkRInterpreter(sparkConf, entries) + } + interp.start() + interpGroup(kind) = interp + + interp + } + } + + def start(): Future[SparkEntries] = { val future = Future { changeState(SessionState.Starting()) - val sc = interpreter.start() - _sc = Option(sc) + + // Always start SparkInterpreter after beginning, because we rely on SparkInterpreter to + // initialize SparkContext and create SparkEntries. + val sparkInterp = mockSparkInterpreter.getOrElse(new SparkInterpreter(sparkConf)) + sparkInterp.start() + + entries = sparkInterp.sparkEntries() + require(entries != null, "SparkEntries object should not be null in Spark Interpreter.") + interpGroup.synchronized { + interpGroup.put(Spark(), sparkInterp) + } + changeState(SessionState.Idle()) - sc + entries }(interpreterExecutor) future.onFailure { case _ => changeState(SessionState.Error()) }(interpreterExecutor) future } - def kind: String = interpreter.kind - def state: SessionState = _state def statements: collection.Map[Int, Statement] = _statements.synchronized { _statements.toMap } - def execute(code: String): Int = { + def execute(code: String, codeType: String = null): Int = { + val tpe = if (codeType != null) { + Kind(codeType) + } else if (defaultInterpKind != Shared()) { + defaultInterpKind + } else { + throw new IllegalArgumentException(s"Code type should be specified if session kind is shared") + } + val statementId = newStatementId.getAndIncrement() val statement = new Statement(statementId, code, StatementState.Waiting, null) _statements.synchronized { _statements(statementId) = statement } Future { - setJobGroup(statementId) + setJobGroup(tpe, statementId) statement.compareAndTransit(StatementState.Waiting, StatementState.Running) if (statement.state.get() == StatementState.Running) { - statement.output = executeCode(statementId, code) + statement.output = executeCode(interpreter(tpe), statementId, code) } statement.compareAndTransit(StatementState.Running, StatementState.Available) @@ -150,7 +198,7 @@ class Session( info(s"Failed to cancel statement $statementId.") statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) } else { - _sc.foreach(_.cancelJobGroup(statementId.toString)) + sc.cancelJobGroup(statementId.toString) if (statement.state.get() == StatementState.Cancelling) { Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) } @@ -166,7 +214,7 @@ class Session( def close(): Unit = { interpreterExecutor.shutdown() cancelExecutor.shutdown() - interpreter.close() + interpGroup.values.foreach(_.close()) } /** @@ -175,21 +223,19 @@ class Session( def progressOfStatement(stmtId: Int): Double = { val jobGroup = statementIdToJobGroup(stmtId) - _sc.map { sc => - val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup) - val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) } - val stages = jobs.flatMap { job => - job.stageIds().flatMap(sc.statusTracker.getStageInfo) - } + val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup) + val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) } + val stages = jobs.flatMap { job => + job.stageIds().flatMap(sc.statusTracker.getStageInfo) + } - val taskCount = stages.map(_.numTasks).sum - val completedTaskCount = stages.map(_.numCompletedTasks).sum - if (taskCount == 0) { - 0.0 - } else { - completedTaskCount.toDouble / taskCount - } - }.getOrElse(0.0) + 
val taskCount = stages.map(_.numTasks).sum + val completedTaskCount = stages.map(_.numCompletedTasks).sum + if (taskCount == 0) { + 0.0 + } else { + completedTaskCount.toDouble / taskCount + } } private def changeState(newState: SessionState): Unit = { @@ -199,7 +245,7 @@ class Session( stateChangedCallback(newState) } - private def executeCode(executionCount: Int, code: String): String = { + private def executeCode(interp: Interpreter, executionCount: Int, code: String): String = { changeState(SessionState.Busy()) def transitToIdle() = { @@ -210,7 +256,7 @@ class Session( } val resultInJson = try { - interpreter.execute(code) match { + interp.execute(code) match { case Interpreter.ExecuteSuccess(data) => transitToIdle() @@ -261,17 +307,17 @@ class Session( compact(render(resultInJson)) } - private def setJobGroup(statementId: Int): String = { + private def setJobGroup(codeType: Kind, statementId: Int): String = { val jobGroup = statementIdToJobGroup(statementId) - val cmd = Kind(interpreter.kind) match { + val cmd = codeType match { case Spark() => // A dummy value to avoid automatic value binding in scala REPL. s"""val _livyJobGroup$jobGroup = sc.setJobGroup("$jobGroup",""" + s""""Job group for statement $jobGroup")""" - case PySpark() | PySpark3() => + case PySpark() => s"""sc.setJobGroup("$jobGroup", "Job group for statement $jobGroup")""" case SparkR() => - interpreter.asInstanceOf[SparkRInterpreter].sparkMajorVersion match { + sc.getConf.get("spark.livy.spark_major_version", "1") match { case "1" => s"""setJobGroup(sc, "$jobGroup", "Job group for statement $jobGroup", """ + "FALSE)" @@ -280,7 +326,7 @@ class Session( } } // Set the job group - executeCode(statementId, cmd) + executeCode(interpreter(codeType), statementId, cmd) } private def statementIdToJobGroup(statementId: Int): String = { diff --git a/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala b/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala deleted file mode 100644 index 4d478244e..000000000 --- a/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.livy.repl - -import org.apache.spark.{SparkConf, SparkContext} - -import org.apache.livy.Logging - -/** - * A mixin trait for Spark entry point creation. This trait exists two different code path - * separately for Spark1 and Spark2, depends on whether SparkSession exists or not. 
- */ -trait SparkContextInitializer extends Logging { - self: SparkInterpreter => - - def createSparkContext(conf: SparkConf): Unit = { - if (isSparkSessionPresent()) { - spark2CreateContext(conf) - } else { - spark1CreateContext(conf) - } - } - - private def spark1CreateContext(conf: SparkConf): Unit = { - sparkContext = SparkContext.getOrCreate(conf) - var sqlContext: Object = null - - if (conf.getBoolean("spark.repl.enableHiveContext", false)) { - try { - val loader = Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader) - if (loader.getResource("hive-site.xml") == null) { - warn("livy.repl.enable-hive-context is true but no hive-site.xml found on classpath.") - } - - sqlContext = Class.forName("org.apache.spark.sql.hive.HiveContext") - .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] - info("Created sql context (with Hive support).") - } catch { - case _: NoClassDefFoundError => - sqlContext = Class.forName("org.apache.spark.sql.SQLContext") - .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] - info("Created sql context.") - } - } else { - sqlContext = Class.forName("org.apache.spark.sql.SQLContext") - .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] - info("Created sql context.") - } - - bind("sc", "org.apache.spark.SparkContext", sparkContext, List("""@transient""")) - bind("sqlContext", sqlContext.getClass.getCanonicalName, sqlContext, List("""@transient""")) - - execute("import org.apache.spark.SparkContext._") - execute("import sqlContext.implicits._") - execute("import sqlContext.sql") - execute("import org.apache.spark.sql.functions._") - } - - private def spark2CreateContext(conf: SparkConf): Unit = { - val sparkClz = Class.forName("org.apache.spark.sql.SparkSession$") - val sparkObj = sparkClz.getField("MODULE$").get(null) - - val builderMethod = sparkClz.getMethod("builder") - val builder = builderMethod.invoke(sparkObj) - builder.getClass.getMethod("config", classOf[SparkConf]).invoke(builder, conf) - - var spark: Object = null - if (conf.get("spark.sql.catalogImplementation", "in-memory").toLowerCase == "hive") { - if (sparkClz.getMethod("hiveClassesArePresent").invoke(sparkObj).asInstanceOf[Boolean]) { - val loader = Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader) - if (loader.getResource("hive-site.xml") == null) { - warn("livy.repl.enable-hive-context is true but no hive-site.xml found on classpath.") - } - - builder.getClass.getMethod("enableHiveSupport").invoke(builder) - spark = builder.getClass.getMethod("getOrCreate").invoke(builder) - info("Created Spark session (with Hive support).") - } else { - builder.getClass.getMethod("config", classOf[String], classOf[String]) - .invoke(builder, "spark.sql.catalogImplementation", "in-memory") - spark = builder.getClass.getMethod("getOrCreate").invoke(builder) - info("Created Spark session.") - } - } else { - spark = builder.getClass.getMethod("getOrCreate").invoke(builder) - info("Created Spark session.") - } - - sparkContext = spark.getClass.getMethod("sparkContext").invoke(spark) - .asInstanceOf[SparkContext] - - bind("spark", spark.getClass.getCanonicalName, spark, List("""@transient""")) - bind("sc", "org.apache.spark.SparkContext", sparkContext, List("""@transient""")) - - execute("import org.apache.spark.SparkContext._") - execute("import spark.implicits._") - execute("import spark.sql") - execute("import 
org.apache.spark.sql.functions._") - } - - private def isSparkSessionPresent(): Boolean = { - try { - Class.forName("org.apache.spark.sql.SparkSession") - true - } catch { - case _: ClassNotFoundException | _: NoClassDefFoundError => false - } - } -} diff --git a/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala index b74586163..9330248d9 100644 --- a/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala +++ b/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala @@ -17,8 +17,7 @@ package org.apache.livy.repl -import java.io.{File, FileOutputStream} -import java.lang.ProcessBuilder.Redirect +import java.io.File import java.nio.file.Files import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} @@ -28,13 +27,14 @@ import scala.reflect.runtime.universe import org.apache.commons.codec.binary.Base64 import org.apache.commons.lang.StringEscapeUtils -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.SparkConf +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.SQLContext import org.json4s._ import org.json4s.JsonDSL._ import org.apache.livy.client.common.ClientConf -import org.apache.livy.rsc.RSCConf +import org.apache.livy.rsc.driver.SparkEntries private case class RequestResponse(content: String, error: Boolean) @@ -44,6 +44,7 @@ object SparkRInterpreter { private val LIVY_ERROR_MARKER = "----LIVY_END_OF_ERROR----" private val PRINT_MARKER = f"""print("$LIVY_END_MARKER")""" private val EXPECTED_OUTPUT = f"""[1] "$LIVY_END_MARKER"""" + private var sparkEntries: SparkEntries = null private val PLOT_REGEX = ( "(" + @@ -67,7 +68,8 @@ object SparkRInterpreter { ")" ).r.unanchored - def apply(conf: SparkConf): SparkRInterpreter = { + def apply(conf: SparkConf, entries: SparkEntries): SparkRInterpreter = { + sparkEntries = entries val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt val mirror = universe.runtimeMirror(getClass.getClassLoader) val sparkRBackendClass = mirror.classLoader.loadClass("org.apache.spark.api.r.RBackend") @@ -119,8 +121,7 @@ object SparkRInterpreter { builder.redirectErrorStream(true) val process = builder.start() new SparkRInterpreter(process, backendInstance, backendThread, - conf.get("spark.livy.spark_major_version", "1"), - conf.getBoolean("spark.repl.enableHiveContext", false)) + conf.getInt("spark.livy.spark_major_version", 1)) } catch { case e: Exception => if (backendThread != null) { @@ -129,13 +130,27 @@ object SparkRInterpreter { throw e } } + + def getSparkContext(): JavaSparkContext = { + require(sparkEntries != null) + sparkEntries.sc() + } + + def getSparkSession(): Object = { + require(sparkEntries != null) + sparkEntries.sparkSession() + } + + def getSQLContext(): SQLContext = { + require(sparkEntries != null) + if (sparkEntries.hivectx() != null) sparkEntries.hivectx() else sparkEntries.sqlctx() + } } class SparkRInterpreter(process: Process, backendInstance: Any, backendThread: Thread, - val sparkMajorVersion: String, - hiveEnabled: Boolean) + val sparkMajorVersion: Int) extends ProcessInterpreter(process) { import SparkRInterpreter._ @@ -149,24 +164,23 @@ class SparkRInterpreter(process: Process, // Set the option to catch and ignore errors instead of halting. 
sendRequest("options(error = dump.frames)") if (!ClientConf.TEST_MODE) { + // scalastyle:off line.size.limit sendRequest("library(SparkR)") - if (sparkMajorVersion >= "2") { - if (hiveEnabled) { - sendRequest("spark <- SparkR::sparkR.session()") - } else { - sendRequest("spark <- SparkR::sparkR.session(enableHiveSupport=FALSE)") - } - sendRequest( - """sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getJavaSparkContext", spark)""") - } else { - sendRequest("sc <- sparkR.init()") - if (hiveEnabled) { - sendRequest("sqlContext <- sparkRHive.init(sc)") - } else { - sendRequest("sqlContext <- sparkRSQL.init(sc)") - } + sendRequest("""port <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")""") + sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""") + sendRequest("""assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)""") + + sendRequest("""assign(".sc", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSparkContext"), envir = SparkR:::.sparkREnv)""") + sendRequest("""assign("sc", get(".sc", envir = SparkR:::.sparkREnv), envir=.GlobalEnv)""") + + if (sparkMajorVersion >= 2) { + sendRequest("""assign(".sparkRsession", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSparkSession"), envir = SparkR:::.sparkREnv)""") + sendRequest("""assign("spark", get(".sparkRsession", envir = SparkR:::.sparkREnv), envir=.GlobalEnv)""") } + + sendRequest("""assign(".sqlc", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSQLContext"), envir = SparkR:::.sparkREnv)""") + sendRequest("""assign("sqlContext", get(".sqlc", envir = SparkR:::.sparkREnv), envir = .GlobalEnv)""") + // scalastyle:on line.size.limit } isStarted.countDown() diff --git a/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala index a13f924e1..bcea0df29 100644 --- a/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/BaseSessionSpec.scala @@ -24,6 +24,7 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import org.apache.spark.SparkConf import org.json4s._ import org.scalatest.{FlatSpec, Matchers} import org.scalatest.concurrent.Eventually._ @@ -31,13 +32,16 @@ import org.scalatest.concurrent.Eventually._ import org.apache.livy.LivyBaseUnitTestSuite import org.apache.livy.rsc.RSCConf import org.apache.livy.rsc.driver.{Statement, StatementState} -import org.apache.livy.sessions.SessionState +import org.apache.livy.sessions._ -abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite { +abstract class BaseSessionSpec(kind: Kind) + extends FlatSpec with Matchers with LivyBaseUnitTestSuite { implicit val formats = DefaultFormats - private val rscConf = new RSCConf(new Properties()) + private val rscConf = new RSCConf(new Properties()).set(RSCConf.Entry.SESSION_KIND, kind.toString) + + private val sparkConf = new SparkConf() protected def execute(session: Session)(code: String): Statement = { val id = session.execute(code) @@ -51,7 +55,7 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitT protected def withSession(testCode: Session => Any): Unit = { val stateChangedCalled = new AtomicInteger() val session = - new Session(rscConf, createInterpreter(), { _ => stateChangedCalled.incrementAndGet() }) + new Session(rscConf, sparkConf, None, { _ => stateChangedCalled.incrementAndGet() }) try { // Session's 
constructor should fire an initial state change event. stateChangedCalled.intValue() shouldBe 1 @@ -65,16 +69,12 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitT } } - protected def createInterpreter(): Interpreter - it should "start in the starting or idle state" in { - val session = new Session(rscConf, createInterpreter()) + val session = new Session(rscConf, sparkConf) val future = session.start() try { - eventually(timeout(30 seconds), interval(100 millis)) { - session.state should (equal (SessionState.Starting()) or equal (SessionState.Idle())) - } Await.ready(future, 60 seconds) + session.state should (equal (SessionState.Starting()) or equal (SessionState.Idle())) } finally { session.close() } diff --git a/repl/src/test/scala/org/apache/livy/repl/PythonInterpreterSpec.scala b/repl/src/test/scala/org/apache/livy/repl/PythonInterpreterSpec.scala index 140476576..fb0704e4e 100644 --- a/repl/src/test/scala/org/apache/livy/repl/PythonInterpreterSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/PythonInterpreterSpec.scala @@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JNull, JValue} import org.json4s.JsonDSL._ import org.scalatest._ -import org.apache.livy.rsc.RSCConf +import org.apache.livy.rsc.driver.SparkEntries import org.apache.livy.sessions._ abstract class PythonBaseInterpreterSpec extends BaseInterpreterSpec { @@ -244,7 +244,10 @@ class Python2InterpreterSpec extends PythonBaseInterpreterSpec { implicit val formats = DefaultFormats - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) + override def createInterpreter(): Interpreter = { + val sparkConf = new SparkConf() + PythonInterpreter(sparkConf, new SparkEntries(sparkConf)) + } // Scalastyle is treating unicode escape as non ascii characters. Turn off the check. 
// scalastyle:off non.ascii.character.disallowed @@ -262,7 +265,7 @@ class Python2InterpreterSpec extends PythonBaseInterpreterSpec { // scalastyle:on non.ascii.character.disallowed } -class Python3InterpreterSpec extends PythonBaseInterpreterSpec { +class Python3InterpreterSpec extends PythonBaseInterpreterSpec with BeforeAndAfterAll { implicit val formats = DefaultFormats @@ -271,7 +274,20 @@ class Python3InterpreterSpec extends PythonBaseInterpreterSpec { test() } - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) + override def beforeAll(): Unit = { + super.beforeAll() + sys.props.put("pyspark.python", "python3") + } + + override def afterAll(): Unit = { + sys.props.remove("pyspark.python") + super.afterAll() + } + + override def createInterpreter(): Interpreter = { + val sparkConf = new SparkConf() + PythonInterpreter(sparkConf, new SparkEntries(sparkConf)) + } it should "check python version is 3.x" in withInterpreter { interpreter => val response = interpreter.execute("""import sys diff --git a/repl/src/test/scala/org/apache/livy/repl/PythonSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/PythonSessionSpec.scala index 0883f278d..69d41312c 100644 --- a/repl/src/test/scala/org/apache/livy/repl/PythonSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/PythonSessionSpec.scala @@ -17,15 +17,13 @@ package org.apache.livy.repl -import org.apache.spark.SparkConf import org.json4s.Extraction import org.json4s.jackson.JsonMethods.parse import org.scalatest._ -import org.apache.livy.rsc.RSCConf import org.apache.livy.sessions._ -abstract class PythonSessionSpec extends BaseSessionSpec { +abstract class PythonSessionSpec extends BaseSessionSpec(PySpark()) { it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") @@ -172,18 +170,25 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } } -class Python2SessionSpec extends PythonSessionSpec { - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) -} +class Python2SessionSpec extends PythonSessionSpec -class Python3SessionSpec extends PythonSessionSpec { +class Python3SessionSpec extends PythonSessionSpec with BeforeAndAfterAll { override protected def withFixture(test: NoArgTest): Outcome = { assume(!sys.props.getOrElse("skipPySpark3Tests", "false").toBoolean, "Skipping PySpark3 tests.") test() } - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) + override def beforeAll(): Unit = { + super.beforeAll() + sys.props.put("pyspark.python", "python3") + } + + override def afterAll(): Unit = { + sys.props.remove("pyspark.python") + super.afterAll() + } + it should "check python version is 3.x" in withSession { session => val statement = execute(session)( diff --git a/repl/src/test/scala/org/apache/livy/repl/ReplDriverSuite.scala b/repl/src/test/scala/org/apache/livy/repl/ReplDriverSuite.scala index 6537f0cba..6d7094d14 100644 --- a/repl/src/test/scala/org/apache/livy/repl/ReplDriverSuite.scala +++ b/repl/src/test/scala/org/apache/livy/repl/ReplDriverSuite.scala @@ -53,7 +53,7 @@ class ReplDriverSuite extends FunSuite with LivyBaseUnitTestSuite { // This is sort of what InteractiveSession.scala does to detect an idle session. 
client.submit(new PingJob()).get(60, TimeUnit.SECONDS) - val statementId = client.submitReplCode("1 + 1").get + val statementId = client.submitReplCode("1 + 1", "spark").get eventually(timeout(30 seconds), interval(100 millis)) { val rawResult = client.getReplJobResults(statementId, 1).get(10, TimeUnit.SECONDS).statements(0) diff --git a/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala index 090b7cb87..554d59ecb 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SessionSpec.scala @@ -20,41 +20,41 @@ package org.apache.livy.repl import java.util.Properties import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.FunSpec +import org.apache.spark.SparkConf +import org.scalatest.{BeforeAndAfter, FunSpec} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually -import org.scalatest.mock.MockitoSugar.mock import org.scalatest.time._ import org.apache.livy.LivyBaseUnitTestSuite import org.apache.livy.repl.Interpreter.ExecuteResponse import org.apache.livy.rsc.RSCConf +import org.apache.livy.sessions._ -class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite { +class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite with BeforeAndAfter { override implicit val patienceConfig = - PatienceConfig(timeout = scaled(Span(10, Seconds)), interval = scaled(Span(100, Millis))) + PatienceConfig(timeout = scaled(Span(30, Seconds)), interval = scaled(Span(100, Millis))) - private val rscConf = new RSCConf(new Properties()) + private val rscConf = new RSCConf(new Properties()).set(RSCConf.Entry.SESSION_KIND, "spark") describe("Session") { + var session: Session = null + + after { + if (session != null) { + session.close() + session = null + } + } + it("should call state changed callbacks in happy path") { val expectedStateTransitions = Array("not_started", "starting", "idle", "busy", "idle", "busy", "idle") val actualStateTransitions = new ConcurrentLinkedQueue[String]() - val interpreter = mock[Interpreter] - when(interpreter.kind).thenAnswer(new Answer[String] { - override def answer(invocationOnMock: InvocationOnMock): String = "spark" - }) - - val session = - new Session(rscConf, interpreter, { s => actualStateTransitions.add(s.toString) }) - + session = new Session(rscConf, new SparkConf(), None, + { s => actualStateTransitions.add(s.toString) }) session.start() - session.execute("") eventually { @@ -64,23 +64,19 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite { it("should not transit to idle if there're any pending statements.") { val expectedStateTransitions = - Array("not_started", "busy", "busy", "busy", "idle", "busy", "idle") + Array("not_started", "starting", "idle", "busy", "busy", "busy", "idle", "busy", "idle") val actualStateTransitions = new ConcurrentLinkedQueue[String]() - val interpreter = mock[Interpreter] - when(interpreter.kind).thenAnswer(new Answer[String] { - override def answer(invocationOnMock: InvocationOnMock): String = "spark" - }) - val blockFirstExecuteCall = new CountDownLatch(1) - when(interpreter.execute("")).thenAnswer(new Answer[Interpreter.ExecuteResponse] { - override def answer(invocation: InvocationOnMock): ExecuteResponse = { + val interpreter = new SparkInterpreter(new SparkConf()) { 
+ override def execute(code: String): ExecuteResponse = { blockFirstExecuteCall.await(10, TimeUnit.SECONDS) - null + super.execute(code) } - }) - val session = - new Session(rscConf, interpreter, { s => actualStateTransitions.add(s.toString) }) + } + session = new Session(rscConf, new SparkConf(), Some(interpreter), + { s => actualStateTransitions.add(s.toString) }) + session.start() for (_ <- 1 to 2) { session.execute("") @@ -93,13 +89,8 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite { } it("should remove old statements when reaching threshold") { - val interpreter = mock[Interpreter] - when(interpreter.kind).thenAnswer(new Answer[String] { - override def answer(invocationOnMock: InvocationOnMock): String = "spark" - }) - rscConf.set(RSCConf.Entry.RETAINED_STATEMENT_NUMBER, 2) - val session = new Session(rscConf, interpreter) + session = new Session(rscConf, new SparkConf()) session.start() session.statements.size should be (0) diff --git a/repl/src/test/scala/org/apache/livy/repl/SharedSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SharedSessionSpec.scala new file mode 100644 index 000000000..d63e22fe0 --- /dev/null +++ b/repl/src/test/scala/org/apache/livy/repl/SharedSessionSpec.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.json4s.Extraction +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods.parse +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + +import org.apache.livy.rsc.driver.{Statement, StatementState} +import org.apache.livy.sessions._ + +class SharedSessionSpec extends BaseSessionSpec(Shared()) { + + private def execute(session: Session, code: String, codeType: String): Statement = { + val id = session.execute(code, codeType) + eventually(timeout(30 seconds), interval(100 millis)) { + val s = session.statements(id) + s.state.get() shouldBe StatementState.Available + s + } + } + + it should "execute `1 + 2` == 3" in withSession { session => + val statement = execute(session, "1 + 2", "spark") + statement.id should equal (0) + + val result = parse(statement.output) + val expectedResult = Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 0, + "data" -> Map( + "text/plain" -> "res0: Int = 3" + ) + )) + + result should equal (expectedResult) + } + + it should "access the spark context" in withSession { session => + val statement = execute(session, """sc""", "spark") + statement.id should equal (0) + + val result = parse(statement.output) + val resultMap = result.extract[Map[String, JValue]] + + // Manually extract the values since the line numbers in the exception could change. 
+ resultMap("status").extract[String] should equal ("ok") + resultMap("execution_count").extract[Int] should equal (0) + + val data = resultMap("data").extract[Map[String, JValue]] + data("text/plain").extract[String] should include ( + "res0: org.apache.spark.SparkContext = org.apache.spark.SparkContext") + } + + it should "execute spark commands" in withSession { session => + val statement = execute(session, + """sc.parallelize(0 to 1).map{i => i+1}.collect""".stripMargin, "spark") + statement.id should equal (0) + + val result = parse(statement.output) + + val expectedResult = Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 0, + "data" -> Map( + "text/plain" -> "res0: Array[Int] = Array(1, 2)" + ) + )) + + result should equal (expectedResult) + } + + it should "throw exception if code type is not specified in shared session" in withSession { + session => + intercept[IllegalArgumentException](session.execute("1 + 2")) + } + + it should "execute `1 + 2 = 3` in Python" in withSession { session => + val statement = execute(session, "1 + 2", "pyspark") + statement.id should equal (0) + + val result = parse(statement.output) + + val expectedResult = Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 0, + "data" -> Map("text/plain" -> "3") + )) + + result should equal (expectedResult) + } + + it should "execute `1 + 2 = 3` in R" in withSession { session => + val statement = execute(session, "1 + 2", "sparkr") + statement.id should be (0) + + val result = parse(statement.output) + + val expectedResult = Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 0, + "data" -> Map("text/plain" -> "[1] 3") + )) + + result should be (expectedResult) + } +} diff --git a/repl/src/test/scala/org/apache/livy/repl/SparkRInterpreterSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SparkRInterpreterSpec.scala index 374afbc41..e032e15d6 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SparkRInterpreterSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SparkRInterpreterSpec.scala @@ -18,11 +18,11 @@ package org.apache.livy.repl import org.apache.spark.SparkConf -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.scalatest._ -import org.apache.livy.rsc.RSCConf +import org.apache.livy.rsc.driver.SparkEntries class SparkRInterpreterSpec extends BaseInterpreterSpec { @@ -33,7 +33,11 @@ class SparkRInterpreterSpec extends BaseInterpreterSpec { super.withFixture(test) } - override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) + + override def createInterpreter(): Interpreter = { + val sparkConf = new SparkConf() + SparkRInterpreter(sparkConf, new SparkEntries(sparkConf)) + } it should "execute `1 + 2` == 3" in withInterpreter { interpreter => val response = interpreter.execute("1 + 2") diff --git a/repl/src/test/scala/org/apache/livy/repl/SparkRSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SparkRSessionSpec.scala index 42ed60a31..dc1887dd4 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SparkRSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SparkRSessionSpec.scala @@ -17,21 +17,18 @@ package org.apache.livy.repl -import org.apache.spark.SparkConf import org.json4s.Extraction import org.json4s.jackson.JsonMethods.parse -import org.apache.livy.rsc.RSCConf +import org.apache.livy.sessions._ -class SparkRSessionSpec extends BaseSessionSpec { +class SparkRSessionSpec extends BaseSessionSpec(SparkR()) { override protected def 
withFixture(test: NoArgTest) = { assume(!sys.props.getOrElse("skipRTests", "false").toBoolean, "Skipping R tests.") super.withFixture(test) } - override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) - it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") statement.id should equal(0) diff --git a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala index 93695199f..7e5108025 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala @@ -20,18 +20,15 @@ package org.apache.livy.repl import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.spark.SparkConf import org.json4s.Extraction import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods.parse import org.scalatest.concurrent.Eventually._ -import org.apache.livy.rsc.RSCConf import org.apache.livy.rsc.driver.StatementState +import org.apache.livy.sessions._ -class SparkSessionSpec extends BaseSessionSpec { - - override def createInterpreter(): Interpreter = new SparkInterpreter(new SparkConf()) +class SparkSessionSpec extends BaseSessionSpec(Spark()) { it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") diff --git a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java index c25e98f4f..823e71ba5 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java +++ b/rsc/src/main/java/org/apache/livy/rsc/BaseProtocol.java @@ -61,17 +61,19 @@ public Error() { public static class BypassJobRequest { public final String id; + public final String jobType; public final byte[] serializedJob; public final boolean synchronous; - public BypassJobRequest(String id, byte[] serializedJob, boolean synchronous) { + public BypassJobRequest(String id, String jobType, byte[] serializedJob, boolean synchronous) { this.id = id; + this.jobType = jobType; this.serializedJob = serializedJob; this.synchronous = synchronous; } public BypassJobRequest() { - this(null, null, false); + this(null, null, null, false); } } @@ -171,13 +173,15 @@ public RemoteDriverAddress() { public static class ReplJobRequest { public final String code; + public final String codeType; - public ReplJobRequest(String code) { + public ReplJobRequest(String code, String codeType) { this.code = code; + this.codeType = codeType; } public ReplJobRequest() { - this(null); + this(null, null); } } diff --git a/rsc/src/main/java/org/apache/livy/rsc/ContextLauncher.java b/rsc/src/main/java/org/apache/livy/rsc/ContextLauncher.java index 8f46c1e0f..ed42e48dd 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/ContextLauncher.java +++ b/rsc/src/main/java/org/apache/livy/rsc/ContextLauncher.java @@ -173,12 +173,8 @@ private static ChildProcess startDriver(final RSCConf conf, Promise promise) } merge(conf, SPARK_JARS_KEY, livyJars, ","); - String kind = conf.get(SESSION_KIND); - if ("sparkr".equals(kind)) { - merge(conf, SPARK_ARCHIVES_KEY, conf.get(RSCConf.Entry.SPARKR_PACKAGE), ","); - } else if ("pyspark".equals(kind)) { - merge(conf, "spark.submit.pyFiles", conf.get(RSCConf.Entry.PYSPARK_ARCHIVES), ","); - } + merge(conf, SPARK_ARCHIVES_KEY, conf.get(RSCConf.Entry.SPARKR_PACKAGE), ","); + merge(conf, "spark.submit.pyFiles", conf.get(RSCConf.Entry.PYSPARK_ARCHIVES), ","); // Disable multiple attempts since the RPC 
server doesn't yet support multiple // connections for the same registered app. diff --git a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java index 3161187d9..3fc334874 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java +++ b/rsc/src/main/java/org/apache/livy/rsc/RSCClient.java @@ -270,8 +270,8 @@ public Future addFile(URI uri) { return submit(new AddFileJob(uri.toString())); } - public String bypass(ByteBuffer serializedJob, boolean sync) { - return protocol.bypass(serializedJob, sync); + public String bypass(ByteBuffer serializedJob, String jobType, boolean sync) { + return protocol.bypass(serializedJob, jobType, sync); } public Future getBypassJobStatus(String id) { @@ -286,8 +286,8 @@ ContextInfo getContextInfo() { return contextInfo; } - public Future submitReplCode(String code) throws Exception { - return deferredCall(new BaseProtocol.ReplJobRequest(code), Integer.class); + public Future submitReplCode(String code, String codeType) throws Exception { + return deferredCall(new BaseProtocol.ReplJobRequest(code, codeType), Integer.class); } public void cancelReplCode(int statementId) throws Exception { @@ -354,9 +354,10 @@ Future run(Job job) { return (Future) deferredCall(new SyncJobRequest(job), Object.class); } - String bypass(ByteBuffer serializedJob, boolean sync) { + String bypass(ByteBuffer serializedJob, String jobType, boolean sync) { String jobId = UUID.randomUUID().toString(); - Object msg = new BypassJobRequest(jobId, BufferUtils.toByteArray(serializedJob), sync); + Object msg = + new BypassJobRequest(jobId, jobType, BufferUtils.toByteArray(serializedJob), sync); deferredCall(msg, Void.class); return jobId; } diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/JobContextImpl.java b/rsc/src/main/java/org/apache/livy/rsc/driver/JobContextImpl.java index ddb571353..9faae0e12 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/JobContextImpl.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/JobContextImpl.java @@ -18,90 +18,47 @@ package org.apache.livy.rsc.driver; import java.io.File; -import java.lang.reflect.Method; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaFutureAction; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.apache.livy.JobContext; import org.apache.livy.rsc.Utils; class JobContextImpl implements JobContext { - private static final Logger LOG = LoggerFactory.getLogger(JobContextImpl.class); - - private final JavaSparkContext sc; private final File localTmpDir; - private volatile SQLContext sqlctx; - private volatile HiveContext hivectx; private volatile JavaStreamingContext streamingctx; private final RSCDriver driver; - private volatile Object sparksession; + private final SparkEntries sparkEntries; - public JobContextImpl(JavaSparkContext sc, File localTmpDir, RSCDriver driver) { - this.sc = sc; + public JobContextImpl(SparkEntries sparkEntries, File localTmpDir, RSCDriver driver) { + this.sparkEntries = sparkEntries; this.localTmpDir = localTmpDir; this.driver = driver; } @Override public JavaSparkContext sc() { - return sc; + return sparkEntries.sc(); } @Override public Object sparkSession() throws Exception { - if (sparksession == null) { - synchronized (this) { - if 
(sparksession == null) { - try { - Class clz = Class.forName("org.apache.spark.sql.SparkSession$"); - Object spark = clz.getField("MODULE$").get(null); - Method m = clz.getMethod("builder"); - Object builder = m.invoke(spark); - builder.getClass().getMethod("sparkContext", SparkContext.class) - .invoke(builder, sc.sc()); - sparksession = builder.getClass().getMethod("getOrCreate").invoke(builder); - } catch (Exception e) { - LOG.warn("SparkSession is not supported", e); - throw e; - } - } - } - } - - return sparksession; + return sparkEntries.sparkSession(); } @Override public SQLContext sqlctx() { - if (sqlctx == null) { - synchronized (this) { - if (sqlctx == null) { - sqlctx = new SQLContext(sc); - } - } - } - return sqlctx; + return sparkEntries.sqlctx(); } @Override public HiveContext hivectx() { - if (hivectx == null) { - synchronized (this) { - if (hivectx == null) { - hivectx = new HiveContext(sc.sc()); - } - } - } - return hivectx; + return sparkEntries.hivectx(); } @Override @@ -113,13 +70,13 @@ public synchronized JavaStreamingContext streamingctx(){ @Override public synchronized void createStreamingContext(long batchDuration) { Utils.checkState(streamingctx == null, "Streaming context is not null."); - streamingctx = new JavaStreamingContext(sc, new Duration(batchDuration)); + streamingctx = new JavaStreamingContext(sparkEntries.sc(), new Duration(batchDuration)); } @Override public synchronized void stopStreamingCtx() { Utils.checkState(streamingctx != null, "Streaming Context is null"); - streamingctx.stop(); + streamingctx.stop(false); streamingctx = null; } @@ -132,9 +89,7 @@ public synchronized void stop() { if (streamingctx != null) { stopStreamingCtx(); } - if (sc != null) { - sc.stop(); - } + sparkEntries.stop(); } public void addFile(String path) { diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java index 20be563d6..0c95c317b 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java @@ -48,7 +48,6 @@ import org.apache.hadoop.fs.Path; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -275,13 +274,11 @@ protected void broadcast(Object msg) { * and returning a null context is allowed. In that case, the context exposed by * JobContext will be null. */ - protected JavaSparkContext initializeContext() throws Exception { - long t1 = System.nanoTime(); - LOG.info("Starting Spark context..."); - JavaSparkContext sc = new JavaSparkContext(conf); - LOG.info("Spark context finished initialization in {}ms", - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1)); - return sc; + protected SparkEntries initializeSparkEntries() throws Exception { + SparkEntries entries = new SparkEntries(conf); + // Explicitly call sc() to initialize SparkContext. 
+ entries.sc(); + return entries; } protected void onClientAuthenticated(final Rpc client) { @@ -325,9 +322,9 @@ void run() throws Exception { try { initializeServer(); - JavaSparkContext sc = initializeContext(); + SparkEntries entries = initializeSparkEntries(); synchronized (jcLock) { - jc = new JobContextImpl(sc, localTmpDir, this); + jc = new JobContextImpl(entries, localTmpDir, this); jcLock.notifyAll(); } diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/SparkEntries.java b/rsc/src/main/java/org/apache/livy/rsc/driver/SparkEntries.java new file mode 100644 index 000000000..c28bac912 --- /dev/null +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/SparkEntries.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.rsc.driver; + +import java.lang.reflect.Method; +import java.util.concurrent.TimeUnit; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.hive.HiveContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkEntries { + + private static final Logger LOG = LoggerFactory.getLogger(SparkEntries.class); + + private volatile JavaSparkContext sc; + private final SparkConf conf; + private volatile SQLContext sqlctx; + private volatile HiveContext hivectx; + private volatile Object sparksession; + + public SparkEntries(SparkConf conf) { + this.conf = conf; + } + + public JavaSparkContext sc() { + if (sc == null) { + synchronized (this) { + if (sc == null) { + long t1 = System.nanoTime(); + LOG.info("Starting Spark context..."); + SparkContext scalaSc = SparkContext.getOrCreate(conf); + sc = new JavaSparkContext(scalaSc); + LOG.info("Spark context finished initialization in {}ms", + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1)); + } + } + } + return sc; + } + + @SuppressWarnings("unchecked") + public Object sparkSession() throws Exception { + if (sparksession == null) { + synchronized (this) { + if (sparksession == null) { + try { + Class clz = Class.forName("org.apache.spark.sql.SparkSession$"); + Object spark = clz.getField("MODULE$").get(null); + Method m = clz.getMethod("builder"); + Object builder = m.invoke(spark); + builder.getClass().getMethod("sparkContext", SparkContext.class) + .invoke(builder, sc().sc()); + + SparkConf conf = sc().getConf(); + if (conf.get("spark.sql.catalogImplementation", "in-memory").toLowerCase() + .equals("hive")) { + if ((boolean) clz.getMethod("hiveClassesArePresent").invoke(spark)) { + ClassLoader loader = Thread.currentThread().getContextClassLoader() != null ? 
+ Thread.currentThread().getContextClassLoader() : getClass().getClassLoader(); + if (loader.getResource("hive-site.xml") == null) { + LOG.warn("livy.repl.enable-hive-context is true but no hive-site.xml found on " + + "classpath"); + } + + builder.getClass().getMethod("enableHiveSupport").invoke(builder); + sparksession = builder.getClass().getMethod("getOrCreate").invoke(builder); + LOG.info("Created Spark session (with Hive support)."); + } else { + builder.getClass().getMethod("config", String.class, String.class) + .invoke(builder, "spark.sql.catalogImplementation", "in-memory"); + sparksession = builder.getClass().getMethod("getOrCreate").invoke(builder); + LOG.info("Created Spark session."); + } + } else { + sparksession = builder.getClass().getMethod("getOrCreate").invoke(builder); + LOG.info("Created Spark session."); + } + } catch (Exception e) { + LOG.warn("SparkSession is not supported", e); + throw e; + } + } + } + } + + return sparksession; + } + + public SQLContext sqlctx() { + if (sqlctx == null) { + synchronized (this) { + if (sqlctx == null) { + sqlctx = new SQLContext(sc()); + LOG.info("Created SQLContext."); + } + } + } + return sqlctx; + } + + public HiveContext hivectx() { + if (hivectx == null) { + synchronized (this) { + if (hivectx == null) { + SparkConf conf = sc.getConf(); + if (conf.getBoolean("spark.repl.enableHiveContext", false) || + conf.get("spark.sql.catalogImplementation", "in-memory").toLowerCase() + .equals("hive")) { + ClassLoader loader = Thread.currentThread().getContextClassLoader() != null ? + Thread.currentThread().getContextClassLoader() : getClass().getClassLoader(); + if (loader.getResource("hive-site.xml") == null) { + LOG.warn("livy.repl.enable-hive-context is true but no hive-site.xml found on " + + "classpath."); + } + hivectx = new HiveContext(sc().sc()); + LOG.info("Created HiveContext."); + } + } + } + } + return hivectx; + } + + public synchronized void stop() { + if (sc != null) { + sc.stop(); + } + } +} diff --git a/rsc/src/test/java/org/apache/livy/rsc/TestSparkClient.java b/rsc/src/test/java/org/apache/livy/rsc/TestSparkClient.java index 06638227f..cacf94311 100644 --- a/rsc/src/test/java/org/apache/livy/rsc/TestSparkClient.java +++ b/rsc/src/test/java/org/apache/livy/rsc/TestSparkClient.java @@ -28,6 +28,7 @@ import java.util.jar.JarOutputStream; import java.util.zip.ZipEntry; +import org.apache.commons.io.FileUtils; import org.apache.spark.launcher.SparkLauncher; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.Test; @@ -76,6 +77,8 @@ private Properties createConf(boolean local) { } conf.put(LIVY_JARS.key(), ""); + conf.put("spark.repl.enableHiveContext", "true"); + conf.put("spark.sql.catalogImplementation", "hive"); return conf; } @@ -406,7 +409,7 @@ public void call(LivyClient client) throws Exception { Serializer s = new Serializer(); RSCClient lclient = (RSCClient) client; ByteBuffer job = s.serialize(new Echo<>("hello")); - String jobId = lclient.bypass(job, sync); + String jobId = lclient.bypass(job, "spark", sync); // Try to fetch the result, trying several times until the timeout runs out, and // backing off as attempts fail. 
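The SparkEntries class introduced above lazily builds the JavaSparkContext, SQLContext, HiveContext and (via reflection) the SparkSession behind double-checked locking, so every interpreter and bypass job in a session shares one set of Spark entry points. A minimal usage sketch, not part of the patch: only the SparkEntries constructor and methods come from the diff, while the master, app name and value names are illustrative placeholders.

    import org.apache.spark.SparkConf
    import org.apache.livy.rsc.driver.SparkEntries

    // Placeholder configuration for a local sketch only.
    val conf = new SparkConf().setMaster("local[*]").setAppName("shared-entries-sketch")
    val entries = new SparkEntries(conf)

    val jsc = entries.sc()       // first call lazily starts the single JavaSparkContext
    val sql = entries.sqlctx()   // built on the same context; HiveContext/SparkSession are created the same lazy way
    entries.stop()               // stops the shared context once, for every interpreter
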
@@ -485,6 +488,16 @@ private void runTest(boolean local, TestFunction test) throws Exception { if (client != null) { client.stop(true); } + + File derbyLog = new File("derby.log"); + if (derbyLog.exists()) { + derbyLog.delete(); + } + + File metaStore = new File("metastore_db"); + if (metaStore.exists()) { + FileUtils.deleteDirectory(metaStore); + } } } diff --git a/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala b/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala index bbb7abdd4..792c59f7a 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala @@ -17,10 +17,10 @@ package org.apache.livy.server.interactive -import org.apache.livy.sessions.{Kind, Spark} +import org.apache.livy.sessions.{Kind, Shared} class CreateInteractiveRequest { - var kind: Kind = Spark() + var kind: Kind = Shared() var proxyUser: Option[String] = None var jars: List[String] = List() var pyFiles: List[String] = List() @@ -37,8 +37,7 @@ class CreateInteractiveRequest { var heartbeatTimeoutInSecond: Int = 0 override def toString: String = { - s"[kind: $kind, " + - s"proxyUser: $proxyUser, " + + s"[kind: $kind, proxyUser: $proxyUser, " + (if (jars.nonEmpty) s"jars: ${jars.mkString(",")}, " else "") + (if (pyFiles.nonEmpty) s"pyFiles: ${pyFiles.mkString(",")}, " else "") + (if (files.nonEmpty) s"files: ${files.mkString(",")}, " else "") + diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala index 207cc0464..0462e80fc 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSession.scala @@ -296,18 +296,19 @@ object InteractiveSession extends Logging { } } - kind match { - case PySpark() | PySpark3() => - val pySparkFiles = if (!LivyConf.TEST_MODE) findPySparkArchives() else Nil - mergeConfList(pySparkFiles, LivyConf.SPARK_PY_FILES) - builderProperties.put(SPARK_YARN_IS_PYTHON, "true") - case SparkR() => - val sparkRArchive = if (!LivyConf.TEST_MODE) findSparkRArchive() else None - sparkRArchive.foreach { archive => - builderProperties.put(RSCConf.Entry.SPARKR_PACKAGE.key(), archive + "#sparkr") - } - case _ => + val pySparkFiles = if (!LivyConf.TEST_MODE) { + builderProperties.put(SPARK_YARN_IS_PYTHON, "true") + findPySparkArchives() + } else { + Nil } + mergeConfList(pySparkFiles, LivyConf.SPARK_PY_FILES) + + val sparkRArchive = if (!LivyConf.TEST_MODE) findSparkRArchive() else None + sparkRArchive.foreach { archive => + builderProperties.put(RSCConf.Entry.SPARKR_PACKAGE.key(), archive + "#sparkr") + } + builderProperties.put(RSCConf.Entry.SESSION_KIND.key, kind.toString) // Set Livy.rsc.jars from livy conf to rsc conf, RSC conf will take precedence if both are set. 
@@ -490,7 +491,7 @@ class InteractiveSession( ensureRunning() recordActivity() - val id = client.get.submitReplCode(content.code).get + val id = client.get.submitReplCode(content.code, content.kind.orNull).get client.get.getReplJobResults(id, 1).get().statements(0) } @@ -500,12 +501,12 @@ class InteractiveSession( client.get.cancelReplCode(statementId) } - def runJob(job: Array[Byte]): Long = { - performOperation(job, true) + def runJob(job: Array[Byte], jobType: String): Long = { + performOperation(job, jobType, true) } - def submitJob(job: Array[Byte]): Long = { - performOperation(job, false) + def submitJob(job: Array[Byte], jobType: String): Long = { + performOperation(job, jobType, false) } def addFile(fileStream: InputStream, fileName: String): Unit = { @@ -569,10 +570,10 @@ class InteractiveSession( } } - private def performOperation(job: Array[Byte], sync: Boolean): Long = { + private def performOperation(job: Array[Byte], jobType: String, sync: Boolean): Long = { ensureActive() recordActivity() - val future = client.get.bypass(ByteBuffer.wrap(job), sync) + val future = client.get.bypass(ByteBuffer.wrap(job), jobType, sync) val opId = operationCounter.incrementAndGet() operations(opId) = future opId diff --git a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala index 900c826fb..38008568c 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/InteractiveSessionServlet.scala @@ -153,7 +153,7 @@ class InteractiveSessionServlet( withModifyAccessSession { session => try { require(req.job != null && req.job.length > 0, "no job provided.") - val jobId = session.submitJob(req.job) + val jobId = session.submitJob(req.job, req.jobType) Created(new JobStatus(jobId, JobHandle.State.SENT, null, null)) } catch { case e: Throwable => @@ -165,7 +165,7 @@ class InteractiveSessionServlet( jpost[SerializedJob]("/:id/run-job") { req => withModifyAccessSession { session => require(req.job != null && req.job.length > 0, "no job provided.") - val jobId = session.runJob(req.job) + val jobId = session.runJob(req.job, req.jobType) Created(new JobStatus(jobId, JobHandle.State.SENT, null, null)) } } @@ -211,10 +211,7 @@ class InteractiveSessionServlet( jpost[AddResource]("/:id/add-pyfile") { req => withModifyAccessSession { lsession => - lsession.kind match { - case PySpark() | PySpark3() => addJarOrPyFile(req, lsession) - case _ => BadRequest("Only supported for pyspark sessions.") - } + addJarOrPyFile(req, lsession) } } diff --git a/server/src/test/resources/log4j.properties b/server/src/test/resources/log4j.properties index 9c8586e33..9195bd934 100644 --- a/server/src/test/resources/log4j.properties +++ b/server/src/test/resources/log4j.properties @@ -17,7 +17,7 @@ # Set everything to be logged to the file target/unit-tests.log test.appender=file -log4j.rootCategory=DEBUG, ${test.appender} +log4j.rootCategory=INFO, ${test.appender} log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala index 1d7206281..e1de22e3b 100644 --- 
a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala @@ -118,7 +118,7 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { batch should be (defined) } - jpost[Map[String, Any]]("/0/statements", ExecuteRequest("foo")) { data => + jpost[Map[String, Any]]("/0/statements", ExecuteRequest("foo", Some("spark"))) { data => data("id") should be (0) data("code") shouldBe "1+1" data("progress") should be (0.0) diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala index d2ae9ae18..9943c00cd 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala @@ -71,8 +71,8 @@ class InteractiveSessionSpec extends FunSpec InteractiveSession.create(0, null, None, livyConf, req, sessionStore, mockApp) } - private def executeStatement(code: String): JValue = { - val id = session.executeStatement(ExecuteRequest(code)).id + private def executeStatement(code: String, codeType: Option[String] = None): JValue = { + val id = session.executeStatement(ExecuteRequest(code, codeType)).id eventually(timeout(30 seconds), interval(100 millis)) { val s = session.getStatement(id).get s.state.get() shouldBe StatementState.Available @@ -154,15 +154,10 @@ class InteractiveSessionSpec extends FunSpec properties1(RSCConf.Entry.LIVY_JARS.key()).split(",").toSet === rscJars1 } - it("should start in the idle state") { - session = createSession() - session.state should (be(a[SessionState.Starting]) or be(a[SessionState.Idle])) - } - it("should update appId and appInfo and session store") { val mockApp = mock[SparkApp] val sessionStore = mock[SessionStore] - val session = createSession(sessionStore, Some(mockApp)) + session = createSession(sessionStore, Some(mockApp)) val expectedAppId = "APPID" session.appIdKnown(expectedAppId) @@ -174,26 +169,38 @@ class InteractiveSessionSpec extends FunSpec verify(sessionStore, atLeastOnce()).save( MockitoMatchers.eq(InteractiveSession.RECOVERY_SESSION_TYPE), anyObject()) + + session.state should (be(a[SessionState.Starting]) or be(a[SessionState.Idle])) } withSession("should execute `1 + 2` == 3") { session => - val result = executeStatement("1 + 2") - val expectedResult = Extraction.decompose(Map( + val pyResult = executeStatement("1 + 2", Some("pyspark")) + pyResult should equal (Extraction.decompose(Map( "status" -> "ok", "execution_count" -> 0, - "data" -> Map( - "text/plain" -> "3" - ) - )) + "data" -> Map("text/plain" -> "3"))) + ) - result should equal (expectedResult) + val scalaResult = executeStatement("1 + 2", Some("spark")) + scalaResult should equal (Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 1, + "data" -> Map("text/plain" -> "res0: Int = 3"))) + ) + + val rResult = executeStatement("1 + 2", Some("sparkr")) + rResult should equal (Extraction.decompose(Map( + "status" -> "ok", + "execution_count" -> 2, + "data" -> Map("text/plain" -> "[1] 3"))) + ) } withSession("should report an error if accessing an unknown variable") { session => val result = executeStatement("x") val expectedResult = Extraction.decompose(Map( "status" -> "error", - "execution_count" -> 1, + "execution_count" -> 3, "ename" -> "NameError", "evalue" -> "name 'x' is not 
defined", "traceback" -> List( @@ -214,7 +221,7 @@ class InteractiveSessionSpec extends FunSpec |from time import sleep |sleep(3) """.stripMargin - val statement = session.executeStatement(ExecuteRequest(code)) + val statement = session.executeStatement(ExecuteRequest(code, None)) statement.progress should be (0.0) eventually(timeout(10 seconds), interval(100 millis)) { @@ -225,7 +232,7 @@ class InteractiveSessionSpec extends FunSpec } withSession("should error out the session if the interpreter dies") { session => - session.executeStatement(ExecuteRequest("import os; os._exit(666)")) + session.executeStatement(ExecuteRequest("import os; os._exit(666)", None)) eventually(timeout(30 seconds), interval(100 millis)) { session.state shouldBe a[SessionState.Error] } diff --git a/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala index 6c0ff403d..d575b5459 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala @@ -75,7 +75,7 @@ class JobApiSpec extends BaseInteractiveServletSpec { withSessionId("should handle synchronous jobs") { testJobSubmission(_, true) } - // Test that the file does get copied over to the live home dir on HDFS - does not test end + // Test that the file does get copied over to the livy home dir on HDFS - does not test end // to end that the RSCClient class copies it over to the app. withSessionId("should support file uploads") { id => testResourceUpload("file", id) @@ -89,7 +89,7 @@ class JobApiSpec extends BaseInteractiveServletSpec { val ser = new Serializer() val job = BufferUtils.toByteArray(ser.serialize(new Echo("hello"))) var jobId: Long = -1L - jpost[JobStatus](s"/$sid/submit-job", new SerializedJob(job)) { status => + jpost[JobStatus](s"/$sid/submit-job", new SerializedJob(job, "spark")) { status => jobId = status.id } @@ -211,7 +211,7 @@ class JobApiSpec extends BaseInteractiveServletSpec { val jobData = BufferUtils.toByteArray(ser.serialize(job)) val route = if (sync) s"/$sid/submit-job" else s"/$sid/run-job" var jobId: Long = -1L - jpost[JobStatus](route, new SerializedJob(jobData), headers = headers) { data => + jpost[JobStatus](route, new SerializedJob(jobData, "spark"), headers = headers) { data => jobId = data.id }