[LIVY-194][REPL] Add shared language support for Livy interactive session

This is ongoing work. Putting it here to leverage Travis to test.

In the current Livy, when we create a new session we need to specify the session kind, which represents the interpreter kind bound to that session. The user can choose `spark`, `pyspark`, or `sparkr`, and Livy internally creates the corresponding interpreter for the session. Also, each Livy session represents a complete Spark application / SparkContext.
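
For context, a minimal sketch of this current flow against the REST API, assuming a Livy server at `http://localhost:8998` (the host, port, and polling loop are illustrative, not part of this change):

```python
import time

import requests

LIVY_URL = "http://localhost:8998"  # assumed local Livy server; adjust to your deployment

# Create a session bound to a single interpreter kind, here pyspark.
session = requests.post(LIVY_URL + "/sessions", json={"kind": "pyspark"}).json()
session_id = session["id"]

# Poll until the session (and its SparkContext) is ready.
while requests.get("%s/sessions/%s" % (LIVY_URL, session_id)).json()["state"] != "idle":
    time.sleep(1)

# Every statement submitted to this session runs in the pyspark interpreter only.
requests.post("%s/sessions/%s/statements" % (LIVY_URL, session_id),
              json={"code": "sc.parallelize(range(100)).count()"})
```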

This brings a limitation: the user can only choose one language, tied to its own Spark application. Furthermore, data generated by the `pyspark` interpreter cannot be shared with the `spark` interpreter. To improve the usability of Livy, we propose adding support for multiple interpreters in one session.
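
With this change, the idea is that a session can be created with the new `shared` kind and each statement can declare its own interpreter, mirroring the `kind: Option[String]` field added to `ExecuteRequest` in the diff below. A hedged sketch of what that usage could look like, assuming a Spark 2.x build (so a `spark` session object is bound in both interpreters) and a per-statement `kind` field on the statements endpoint; the session-readiness polling from the previous sketch is omitted for brevity:

```python
import requests

LIVY_URL = "http://localhost:8998"  # assumed local Livy server; adjust to your deployment

# One session, one Spark application, multiple interpreters.
session_id = requests.post(LIVY_URL + "/sessions", json={"kind": "shared"}).json()["id"]
statements_url = "%s/sessions/%s/statements" % (LIVY_URL, session_id)

# A Scala statement registers a temp view in the shared SparkSession...
requests.post(statements_url, json={
    "kind": "spark",
    "code": 'spark.range(100).createOrReplaceTempView("nums")',
})

# ...and a Python statement in the same session reads it back, because both
# interpreters share the same SparkContext/SparkSession.
requests.post(statements_url, json={
    "kind": "pyspark",
    "code": 'spark.sql("SELECT count(*) FROM nums").show()',
})
```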

Author: jerryshao <[email protected]>

Closes apache#28 from jerryshao/LIVY-194.

Change-Id: I743871c80ccb5c16101236e052d5f31662382667
jerryshao committed Aug 31, 2017
1 parent 317290d commit c1aafeb
Showing 42 changed files with 738 additions and 500 deletions.
@@ -82,13 +82,15 @@ private SessionInfo() {
public static class SerializedJob implements ClientMessage {

public final byte[] job;
public final String jobType;

public SerializedJob(byte[] job) {
public SerializedJob(byte[] job, String jobType) {
this.job = job;
this.jobType = jobType;
}

private SerializedJob() {
this(null);
this(null, null);
}

}
@@ -138,7 +138,7 @@ void start(final String command, final ByteBuffer serializedJob) {
@Override
public void run() {
try {
ClientMessage msg = new SerializedJob(BufferUtils.toByteArray(serializedJob));
ClientMessage msg = new SerializedJob(BufferUtils.toByteArray(serializedJob), "spark");
JobStatus status = conn.post(msg, JobStatus.class, "/%d/%s", sessionId, command);

if (isCancelPending) {
@@ -217,7 +217,7 @@ class HttpClientSpec extends FunSpecLike with BeforeAndAfterAll with LivyBaseUni

private def runJob(sync: Boolean, genStatusFn: Long => Seq[JobStatus]): (Long, JFuture[Int]) = {
val jobId = java.lang.Long.valueOf(ID_GENERATOR.incrementAndGet())
when(session.submitJob(any(classOf[Array[Byte]]))).thenReturn(jobId)
when(session.submitJob(any(classOf[Array[Byte]]), anyString())).thenReturn(jobId)

val statuses = genStatusFn(jobId)
val first = statuses.head
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/livy/msgs.scala
@@ -28,7 +28,7 @@ case class Msg[T <: Content](msg_type: MsgType, content: T)

sealed trait Content

case class ExecuteRequest(code: String) extends Content {
case class ExecuteRequest(code: String, kind: Option[String]) extends Content {
val msg_type = MsgType.execute_request
}

10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/livy/sessions/Kind.scala
@@ -30,21 +30,21 @@ case class PySpark() extends Kind {
override def toString: String = "pyspark"
}

case class PySpark3() extends Kind {
override def toString: String = "pyspark3"
}

case class SparkR() extends Kind {
override def toString: String = "sparkr"
}

case class Shared() extends Kind {
override def toString: String = "shared"
}

object Kind {

def apply(kind: String): Kind = kind match {
case "spark" | "scala" => Spark()
case "pyspark" | "python" => PySpark()
case "pyspark3" | "python3" => PySpark3()
case "sparkr" | "r" => SparkR()
case "shared" => Shared()
case other => throw new IllegalArgumentException(s"Invalid kind: $other")
}

2 changes: 1 addition & 1 deletion integration-test/pom.xml
@@ -286,7 +286,7 @@
<configuration>
<environmentVariables>
<LIVY_HOME>${execution.root}</LIVY_HOME>
<LIVY_TEST>false</LIVY_TEST>
<LIVY_TEST>true</LIVY_TEST>
<LIVY_INTEGRATION_TEST>true</LIVY_INTEGRATION_TEST>
</environmentVariables>
<systemProperties>
4 changes: 3 additions & 1 deletion python-api/src/main/python/livy/client.py
@@ -70,6 +70,7 @@ def __init__(self, url, load_defaults=True, conf_dict=None):
uri = urlparse(url)
self._config = ConfigParser()
self._load_config(load_defaults, conf_dict)
self._job_type = 'pyspark'
match = re.match(r'(.*)/sessions/([0-9]+)', uri.path)
if match:
base = ParseResult(scheme=uri.scheme, netloc=uri.netloc,
@@ -395,7 +396,8 @@ def _reconnect_to_existing_session(self):
def _send_job(self, command, job):
pickled_job = cloudpickle.dumps(job)
base64_pickled_job = base64.b64encode(pickled_job).decode('utf-8')
base64_pickled_job_data = {'job': base64_pickled_job}
base64_pickled_job_data = \
{'job': base64_pickled_job, 'jobType': self._job_type}
handle = JobHandle(self._conn, self._session_id,
self._executor)
handle._start(command, base64_pickled_job_data)
@@ -26,20 +26,20 @@ import scala.tools.nsc.interpreter.JPrintWriter
import scala.tools.nsc.interpreter.Results.Result
import scala.util.{Failure, Success, Try}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkConf
import org.apache.spark.repl.SparkIMain

import org.apache.livy.rsc.driver.SparkEntries

/**
* This represents a Spark interpreter. It is not thread safe.
*/
class SparkInterpreter(conf: SparkConf)
extends AbstractSparkInterpreter with SparkContextInitializer {
class SparkInterpreter(protected override val conf: SparkConf) extends AbstractSparkInterpreter {

private var sparkIMain: SparkIMain = _
protected var sparkContext: SparkContext = _

override def start(): SparkContext = {
require(sparkIMain == null && sparkContext == null)
override def start(): Unit = {
require(sparkIMain == null)

val settings = new Settings()
settings.embeddedDefaults(Thread.currentThread().getContextClassLoader())
@@ -103,23 +103,21 @@ class SparkInterpreter(conf: SparkConf)
}
}

createSparkContext(conf)
postStart()
}

sparkContext
}

protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit = {
override protected def bind(name: String,
tpe: String,
value: Object,
modifier: List[String]): Unit = {
sparkIMain.beQuietDuring {
sparkIMain.bind(name, tpe, value, modifier)
}
}

override def close(): Unit = synchronized {
if (sparkContext != null) {
sparkContext.stop()
sparkContext = null
}
super.close()

if (sparkIMain != null) {
sparkIMain.close()
@@ -128,7 +126,7 @@ class SparkInterpreter(conf: SparkConf)
}

override protected def isStarted(): Boolean = {
sparkContext != null && sparkIMain != null
sparkIMain != null
}

override protected def interpret(code: String): Result = {
@@ -26,20 +26,20 @@ import scala.tools.nsc.interpreter.JPrintWriter
import scala.tools.nsc.interpreter.Results.Result
import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkConf
import org.apache.spark.repl.SparkILoop

import org.apache.livy.rsc.driver.SparkEntries

/**
* Scala 2.11 version of SparkInterpreter
*/
class SparkInterpreter(conf: SparkConf)
extends AbstractSparkInterpreter with SparkContextInitializer {
class SparkInterpreter(protected override val conf: SparkConf) extends AbstractSparkInterpreter {

protected var sparkContext: SparkContext = _
private var sparkILoop: SparkILoop = _
private var sparkHttpServer: Object = _

override def start(): SparkContext = {
override def start(): Unit = {
require(sparkILoop == null)

val rootDir = conf.get("spark.repl.classdir", System.getProperty("java.io.tmpdir"))
@@ -89,17 +89,12 @@ class SparkInterpreter(conf: SparkConf)
}
}

createSparkContext(conf)
postStart()
}

sparkContext
}

override def close(): Unit = synchronized {
if (sparkContext != null) {
sparkContext.stop()
sparkContext = null
}
super.close()

if (sparkILoop != null) {
sparkILoop.closeInterpreter()
@@ -115,7 +110,7 @@ class SparkInterpreter(conf: SparkConf)
}

override protected def isStarted(): Boolean = {
sparkContext != null && sparkILoop != null
sparkILoop != null
}

override protected def interpret(code: String): Result = {
@@ -127,7 +122,10 @@ class SparkInterpreter(conf: SparkConf)
Option(sparkILoop.lastRequest.lineRep.call("$result"))
}

protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit = {
override protected def bind(name: String,
tpe: String,
value: Object,
modifier: List[String]): Unit = {
sparkILoop.beQuietDuring {
sparkILoop.bind(name, tpe, value, modifier)
}
34 changes: 31 additions & 3 deletions repl/src/main/resources/fake_shell.py
@@ -531,14 +531,42 @@ def main():
listening_port = 0
if os.environ.get("LIVY_TEST") != "true":
#Load spark into the context
exec('from pyspark.shell import sc', global_dict)
exec('from pyspark.shell import sqlContext', global_dict)
exec('from pyspark.sql import HiveContext', global_dict)
exec('from pyspark.streaming import StreamingContext', global_dict)
exec('import pyspark.cloudpickle as cloudpickle', global_dict)

from py4j.java_gateway import java_import, JavaGateway, GatewayClient
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext, HiveContext, Row
# Connect to the gateway
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

jsc = gateway.entry_point.sc()
jconf = gateway.entry_point.sc().getConf()
jsqlc = gateway.entry_point.hivectx() if gateway.entry_point.hivectx() is not None \
else gateway.entry_point.sqlctx()

conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
global_dict['sc'] = sc
sqlc = SQLContext(sc, jsqlc)
global_dict['sqlContext'] = sqlc

if spark_major_version >= "2":
exec('from pyspark.shell import spark', global_dict)
from pyspark.sql import SparkSession
spark_session = SparkSession(sc, gateway.entry_point.sparkSession())
global_dict['spark'] = spark_session
else:
# LIVY-294, need to check whether HiveContext can work properly,
# fallback to SQLContext if HiveContext can not be initialized successfully.
@@ -21,13 +21,15 @@ import java.io.ByteArrayOutputStream

import scala.tools.nsc.interpreter.Results

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.json4s.DefaultFormats
import org.json4s.Extraction
import org.json4s.JsonAST._
import org.json4s.JsonDSL._

import org.apache.livy.Logging
import org.apache.livy.rsc.driver.SparkEntries

object AbstractSparkInterpreter {
private[repl] val KEEP_NEWLINE_REGEX = """(?<=\n)""".r
@@ -41,6 +43,10 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging {

protected val outputStream = new ByteArrayOutputStream()

protected var entries: SparkEntries = _

def sparkEntries(): SparkEntries = entries

final def kind: String = "spark"

protected def isStarted(): Boolean
@@ -49,6 +55,52 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging {

protected def valueOfTerm(name: String): Option[Any]

protected def bind(name: String, tpe: String, value: Object, modifier: List[String]): Unit

protected def conf: SparkConf

protected def postStart(): Unit = {
entries = new SparkEntries(conf)

if (isSparkSessionPresent()) {
bind("spark",
sparkEntries.sparkSession().getClass.getCanonicalName,
sparkEntries.sparkSession(),
List("""@transient"""))
bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient"""))

execute("import org.apache.spark.SparkContext._")
execute("import spark.implicits._")
execute("import spark.sql")
execute("import org.apache.spark.sql.functions._")
} else {
bind("sc", "org.apache.spark.SparkContext", sparkEntries.sc().sc, List("""@transient"""))
val sqlContext = Option(sparkEntries.hivectx()).getOrElse(sparkEntries.sqlctx())
bind("sqlContext", sqlContext.getClass.getCanonicalName, sqlContext, List("""@transient"""))

execute("import org.apache.spark.SparkContext._")
execute("import sqlContext.implicits._")
execute("import sqlContext.sql")
execute("import org.apache.spark.sql.functions._")
}
}

override def close(): Unit = {
if (entries != null) {
entries.stop()
entries = null
}
}

private def isSparkSessionPresent(): Boolean = {
try {
Class.forName("org.apache.spark.sql.SparkSession")
true
} catch {
case _: ClassNotFoundException | _: NoClassDefFoundError => false
}
}

override protected[repl] def execute(code: String): Interpreter.ExecuteResponse =
restoreContextClassLoader {
require(isStarted())
@@ -19,16 +19,13 @@ package org.apache.livy.repl
import java.nio.charset.StandardCharsets

import org.apache.livy.{Job, JobContext}
import org.apache.livy.sessions._

class BypassPySparkJob(
serializedJob: Array[Byte],
replDriver: ReplDriver) extends Job[Array[Byte]] {
pi: PythonInterpreter) extends Job[Array[Byte]] {

override def call(jc: JobContext): Array[Byte] = {
val interpreter = replDriver.interpreter
require(interpreter != null && interpreter.isInstanceOf[PythonInterpreter])
val pi = interpreter.asInstanceOf[PythonInterpreter]

val resultByteArray = pi.pysparkJobProcessor.processBypassJob(serializedJob)
val resultString = new String(resultByteArray, StandardCharsets.UTF_8)
if (resultString.startsWith("Client job error:")) {
5 changes: 1 addition & 4 deletions repl/src/main/scala/org/apache/livy/repl/Interpreter.scala
@@ -17,7 +17,6 @@

package org.apache.livy.repl

import org.apache.spark.SparkContext
import org.json4s.JObject

object Interpreter {
@@ -38,10 +37,8 @@ trait Interpreter {

/**
* Start the Interpreter.
*
* @return A SparkContext
*/
def start(): SparkContext
def start(): Unit

/**
* Execute the code and return the result, it may
(Diff truncated: the remaining changed files are not shown in this view.)
