LIVY-233. SparkSession support for Job API. (apache#217)
Added sparkSession() to JobContext to expose the Spark 2.0 SparkSession. If SparkSession is not supported (for example, when running against Spark 1.x), the call throws an exception.

Example:
JobHandle<String> handler = client.submit(new Job<String>() {
  @Override
  public String call(JobContext jc) throws Exception {
    SparkSession session = jc.sparkSession();
    return session.version();
  }
});
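
Because the call throws when SparkSession is unavailable, a job that must run on both Spark 1.x and 2.x can guard it. A minimal sketch, illustrative only and not part of this commit (same client and imports as the example above; the datasets are placeholders):

JobHandle<Long> countHandle = client.submit(new Job<Long>() {
  @Override
  public Long call(JobContext jc) throws Exception {
    try {
      SparkSession session = jc.sparkSession();   // throws on clusters without SparkSession support
      return session.range(100).count();
    } catch (Exception e) {
      // SparkSession not supported; fall back to the shared JavaSparkContext.
      return jc.sc().parallelize(java.util.Arrays.asList(1, 2, 3)).count();
    }
  }
});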
jerryshao authored and alex-the-man committed Dec 12, 2016
1 parent 1ee2ed8 commit 1b72ce9
Showing 11 changed files with 332 additions and 6 deletions.
8 changes: 6 additions & 2 deletions api/src/main/java/com/cloudera/livy/JobContext.java
@@ -38,10 +38,10 @@ public interface JobContext {
/** The shared SparkContext instance. */
JavaSparkContext sc();

- /** The shared SQLContext inststance. */
+ /** The shared SQLContext instance. */
SQLContext sqlctx();

- /** The shared HiveContext inststance. */
+ /** The shared HiveContext instance. */
HiveContext hivectx();

/** Returns the JavaStreamingContext which has already been created. */
@@ -63,4 +63,8 @@
*/
File getLocalTmpDir();

/**
 * Returns the SparkSession if it exists, otherwise throws an Exception.
*/
<E> E sparkSession() throws Exception;
}
29 changes: 29 additions & 0 deletions integration-test/pom.xml
@@ -209,6 +209,35 @@

<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>parse-spark-version</id>
<phase>process-test-sources</phase>
<goals>
<goal>parse-version</goal>
</goals>
<configuration>
<propertyPrefix>spark</propertyPrefix>
<versionString>${spark.version}</versionString>
</configuration>
</execution>
<execution>
<id>add-spark-version-specific-test</id>
<phase>process-test-sources</phase>
<goals>
<goal>add-test-source</goal>
</goals>
<configuration>
<sources>
<source>${project.basedir}/src/test/spark${spark.majorVersion}/scala</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
@@ -41,7 +41,7 @@ import com.cloudera.livy.utils.LivySparkUtils

// Proper type representing the return value of "GET /sessions". At some point we should make
// SessionServlet use something like this.
- private class SessionList {
+ class SessionList {
val from: Int = -1
val total: Int = -1
val sessions: List[SessionInfo] = Nil
@@ -130,13 +130,13 @@ class JobApiIT extends BaseIntegrationTestSuite with BeforeAndAfterAll with Logging {

test("run spark job") {
assume(client != null, "Client not active.")
-    val result = waitFor(client.submit(new SmallCount(100)));
+    val result = waitFor(client.submit(new SmallCount(100)))
assert(result === 100)
}

test("run spark sql job") {
assume(client != null, "Client not active.")
-    val result = waitFor(client.submit(new SQLGetTweets(false)));
+    val result = waitFor(client.submit(new SQLGetTweets(false)))
assert(result.size() > 0)
}

107 changes: 107 additions & 0 deletions integration-test/src/test/spark2/scala/Spark2JobApiIT.scala
@@ -0,0 +1,107 @@
/*
* Licensed to Cloudera, Inc. under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Cloudera, Inc. licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.cloudera.livy.test

import java.io.File
import java.net.URI
import java.util.concurrent.{TimeUnit, Future => JFuture}
import javax.servlet.http.HttpServletResponse

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.scalatest.BeforeAndAfterAll

import com.cloudera.livy._
import com.cloudera.livy.client.common.HttpMessages._
import com.cloudera.livy.sessions.SessionKindModule
import com.cloudera.livy.test.framework.BaseIntegrationTestSuite
import com.cloudera.livy.test.jobs.spark2._

class Spark2JobApiIT extends BaseIntegrationTestSuite with BeforeAndAfterAll with Logging {

private var client: LivyClient = _
private var sessionId: Int = _
private val mapper = new ObjectMapper()
.registerModule(DefaultScalaModule)
.registerModule(new SessionKindModule())

override def afterAll(): Unit = {
super.afterAll()

if (client != null) {
client.stop(true)
}

livyClient.connectSession(sessionId).stop()
}

test("create a new session and upload test jar") {
val tempClient = createClient(livyEndpoint)

try {
// Figure out the session ID by poking at the REST endpoint. We should probably expose this
// in the Java API.
val list = sessionList()
assert(list.total === 1)
val tempSessionId = list.sessions(0).id

livyClient.connectSession(tempSessionId).verifySessionIdle()
waitFor(tempClient.uploadJar(new File(testLib)))

client = tempClient
sessionId = tempSessionId
} finally {
if (client == null) {
try {
if (tempClient != null) {
tempClient.stop(true)
}
} catch {
case e: Exception => warn("Error stopping client.", e)
}
}
}
}

test("run spark2 job") {
assume(client != null, "Client not active.")
val result = waitFor(client.submit(new SparkSessionTest()))
assert(result === 3)
}

test("run spark2 dataset job") {
assume(client != null, "Client not active.")
val result = waitFor(client.submit(new DatasetTest()))
assert(result === 2)
}

private def waitFor[T](future: JFuture[T]): T = {
future.get(30, TimeUnit.SECONDS)
}

private def sessionList(): SessionList = {
val response = httpClient.prepareGet(s"$livyEndpoint/sessions/").execute().get()
assert(response.getStatusCode === HttpServletResponse.SC_OK)
mapper.readValue(response.getResponseBodyAsStream, classOf[SessionList])
}

private def createClient(uri: String): LivyClient = {
new LivyClientBuilder().setURI(new URI(uri)).build()
}
}
20 changes: 20 additions & 0 deletions python-api/src/main/python/livy/job_context.py
@@ -152,3 +152,23 @@ def local_tmp_dir_path(self):
Returns a local tmp dir path specific to the context
"""
pass

@abstractproperty
def spark_session(self):
"""
The shared SparkSession instance.
Returns
-------
session : pyspark.sql.SparkSession
A SparkSession instance
Examples
-------
>>> def simple_spark_job(context):
>>> session = context.spark_session
>>> df1 = session.read.json('/sample.json')
>>> return df1.dtypes
"""
pass
4 changes: 4 additions & 0 deletions repl/src/main/resources/fake_shell.py
@@ -97,6 +97,7 @@ def __init__(self):
self.hive_ctx = None
self.streaming_ctx = None
self.local_tmp_dir_path = local_tmp_dir_path
self.spark_session = global_dict['spark']

def sc(self):
return self.sc
@@ -143,6 +144,9 @@ def stop(self):
if self.sc is not None:
self.sc.stop()

def spark_session(self):
return self.spark_session


class PySparkJobProcessorImpl(object):
def processBypassJob(self, serialized_job):
27 changes: 27 additions & 0 deletions rsc/src/main/java/com/cloudera/livy/rsc/driver/JobContextImpl.java
@@ -18,7 +18,9 @@
package com.cloudera.livy.rsc.driver;

import java.io.File;
import java.lang.reflect.Method;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
@@ -41,6 +43,7 @@ class JobContextImpl implements JobContext {
private volatile HiveContext hivectx;
private volatile JavaStreamingContext streamingctx;
private final RSCDriver driver;
private volatile Object sparksession;

public JobContextImpl(JavaSparkContext sc, File localTmpDir, RSCDriver driver) {
this.sc = sc;
@@ -53,6 +56,30 @@ public JavaSparkContext sc() {
return sc;
}

@Override
public Object sparkSession() throws Exception {
if (sparksession == null) {
synchronized (this) {
if (sparksession == null) {
try {
Class<?> clz = Class.forName("org.apache.spark.sql.SparkSession$");
Object spark = clz.getField("MODULE$").get(null);
Method m = clz.getMethod("builder");
Object builder = m.invoke(spark);
builder.getClass().getMethod("sparkContext", SparkContext.class)
.invoke(builder, sc.sc());
sparksession = builder.getClass().getMethod("getOrCreate").invoke(builder);
} catch (Exception e) {
LOG.warn("SparkSession is not supported", e);
throw e;
}
}
}
}

return sparksession;
}

@Override
public SQLContext sqlctx() {
if (sqlctx == null) {
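The reflection above is what keeps the driver free of a compile-time dependency on Spark 2: on Spark 1.x the Class.forName call fails, the warning is logged, and the exception propagates to the caller. When Spark 2 classes are present, the builder/sparkContext/getOrCreate chain resolves to roughly the direct call sketched below. This is an illustration only, assuming Spark 2 on the compile classpath (the class name is hypothetical); the reflective version additionally pins the driver's existing SparkContext on the builder before calling getOrCreate.

import org.apache.spark.sql.SparkSession;

class DirectSparkSessionExample {
  // Sketch of what the reflective lookup resolves to when Spark 2 classes are present.
  // The patch deliberately avoids this direct dependency so the driver still loads on Spark 1.x.
  static SparkSession createSession() {
    // getOrCreate() reuses the already-running SparkContext and returns the shared session.
    return SparkSession.builder().getOrCreate();
  }
}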
@@ -46,6 +46,8 @@ class ScalaJobContext private[livy] (context: JobContext) {
/** Returns the StreamingContext which has already been created. */
def streamingctx: StreamingContext = context.streamingctx().ssc

def sparkSession[E]: E = context.sparkSession()

/**
* Creates the SparkStreaming context.
*
32 changes: 31 additions & 1 deletion test-lib/pom.xml
@@ -85,7 +85,37 @@
<skip>true</skip>
</configuration>
</plugin>

<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>parse-spark-version</id>
<phase>process-sources</phase>
<goals>
<goal>parse-version</goal>
</goals>
<configuration>
<propertyPrefix>spark</propertyPrefix>
<versionString>${spark.version}</versionString>
</configuration>
</execution>
<execution>
<id>add-spark2-source-code</id>
<phase>process-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${project.basedir}/src/main/spark${spark.majorVersion}/scala</source>
<source>${project.basedir}/src/main/spark${spark.majorVersion}/java</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>

</project>
@@ -0,0 +1,64 @@
/*
* Licensed to Cloudera, Inc. under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Cloudera, Inc. licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.cloudera.livy.test.jobs.spark2;

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import com.cloudera.livy.Job;
import com.cloudera.livy.JobContext;

public class DatasetTest implements Job<Long> {

@Override
public Long call(JobContext jc) throws Exception {
SparkSession spark = jc.sparkSession();

JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
JavaRDD<Row> rdd = sc.parallelize(Arrays.asList(1, 2, 3)).map(
new Function<Integer, Row>() {
public Row call(Integer integer) throws Exception {
return RowFactory.create(integer);
}
});
StructType schema = DataTypes.createStructType(new StructField[] {
DataTypes.createStructField("value", DataTypes.IntegerType, false)
});

Dataset<Row> ds = spark.createDataFrame(rdd, schema);

return ds.filter(new FilterFunction<Row>() {
@Override
public boolean call(Row row) throws Exception {
return row.getInt(0) >= 2;
}
}).count();
}
}