Commit 2196302

[Security] Update to support pyspark and sparkr changes in Spark 2.3.1

jerryshao committed Jun 13, 2018
1 parent fe0283f commit 2196302

Showing 5 changed files with 91 additions and 21 deletions.
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/livy/Utils.scala
@@ -17,10 +17,11 @@

package org.apache.livy

-import java.io.{Closeable, File, FileInputStream, InputStreamReader}
+import java.io.{Closeable, File, InputStreamReader}
import java.net.URL
import java.nio.charset.StandardCharsets.UTF_8
-import java.util.Properties
+import java.security.SecureRandom
+import java.util.{Base64, Properties}

import scala.annotation.tailrec
import scala.collection.JavaConverters._
@@ -106,4 +107,10 @@ object Utils {
}
}

+  def createSecret(secretBitLength: Int): String = {
+    val rnd = new SecureRandom()
+    val secretBytes = new Array[Byte](secretBitLength / java.lang.Byte.SIZE)
+    rnd.nextBytes(secretBytes)
+    Base64.getEncoder.encodeToString(secretBytes)
+  }
}
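
The new helper generates the shared secret used for the py4j gateway: 256 bits from SecureRandom, Base64-encoded so the token survives a trip through an environment variable. A minimal standalone sketch of the same technique (the demo object and main method here are illustrative, not part of the commit):

    import java.security.SecureRandom
    import java.util.Base64

    object SecretDemo {
      def createSecret(secretBitLength: Int): String = {
        val rnd = new SecureRandom()
        val secretBytes = new Array[Byte](secretBitLength / java.lang.Byte.SIZE)
        rnd.nextBytes(secretBytes)
        Base64.getEncoder.encodeToString(secretBytes)
      }

      def main(args: Array[String]): Unit = {
        // 256 bits -> 32 random bytes -> a 44-character Base64 token
        val token = createSecret(256)
        println(s"$token (${token.length} chars)")
      }
    }
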
8 changes: 4 additions & 4 deletions pom.xml
@@ -1196,13 +1196,13 @@
</property>
</activation>
<properties>
-<spark.scala-2.11.version>2.3.0</spark.scala-2.11.version>
+<spark.scala-2.11.version>2.3.1</spark.scala-2.11.version>
<spark.scala-2.10.version>2.2.0</spark.scala-2.10.version>
<spark.version>${spark.scala-2.11.version}</spark.version>
<netty.spark-2.11.version>4.1.17.Final</netty.spark-2.11.version>
<netty.spark-2.10.version>4.0.37.Final</netty.spark-2.10.version>
<java.version>1.8</java.version>
-<py4j.version>0.10.4</py4j.version>
+<py4j.version>0.10.7</py4j.version>
<json4s.version>3.2.11</json4s.version>
</properties>
</profile>
@@ -1216,9 +1216,9 @@
</activation>
<properties>
<spark.bin.download.url>
-http://apache.mirrors.ionfish.org/spark/spark-2.3.0/spark-2.3.0-bin-hadoop2.7.tgz
+http://mirrors.advancedhosters.com/apache/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz
</spark.bin.download.url>
-<spark.bin.name>spark-2.3.0-bin-hadoop2.7</spark.bin.name>
+<spark.bin.name>spark-2.3.1-bin-hadoop2.7</spark.bin.name>
</properties>
</profile>

23 changes: 17 additions & 6 deletions repl/src/main/resources/fake_shell.py
@@ -569,7 +569,13 @@ def main():
from pyspark.sql import SQLContext, HiveContext, Row
# Connect to the gateway
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+    try:
+        from py4j.java_gateway import GatewayParameters
+        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
+        gateway = JavaGateway(gateway_parameters=GatewayParameters(
+            port=gateway_port, auth_token=gateway_secret, auto_convert=True))
+    except:
+        gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -613,12 +619,17 @@ def main():

#Start py4j callback server
from py4j.protocol import ENTRY_POINT_OBJECT_ID
-    from py4j.java_gateway import JavaGateway, GatewayClient, CallbackServerParameters
+    from py4j.java_gateway import CallbackServerParameters

+    try:
+        gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
+        gateway.start_callback_server(
+            callback_server_parameters=CallbackServerParameters(
+                port=0, auth_token=gateway_secret))
+    except:
+        gateway.start_callback_server(
+            callback_server_parameters=CallbackServerParameters(port=0))

-    gateway_client_port = int(os.environ.get("PYSPARK_GATEWAY_PORT"))
-    gateway = JavaGateway(GatewayClient(port=gateway_client_port))
-    gateway.start_callback_server(
-        callback_server_parameters=CallbackServerParameters(port=0))
socket_info = gateway._callback_server.server_socket.getsockname()
listening_port = socket_info[1]
pyspark_job_processor = PySparkJobProcessorImpl()
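
The try/except above is the client half of a shared-secret handshake: if the environment carries PYSPARK_GATEWAY_SECRET (and the installed py4j understands auth_token), fake_shell.py authenticates both its gateway connection and its callback server with it; otherwise it falls back to the old unauthenticated wiring. The parent JVM supplies both values when it spawns the interpreter. A rough sketch of that parent side, assuming a GatewayServer already built with the same secret (fakeShellPath is a hypothetical stand-in for the extracted script):

    import java.lang.ProcessBuilder.Redirect
    import py4j.GatewayServer
    import scala.collection.JavaConverters._

    object LaunchSketch {
      def launch(gatewayServer: GatewayServer, secret: String, fakeShellPath: String): Process = {
        val builder = new ProcessBuilder(Seq("python", fakeShellPath).asJava)
        val env = builder.environment()
        // fake_shell.py reads both of these at startup (see the try/except above)
        env.put("PYSPARK_GATEWAY_PORT", gatewayServer.getListeningPort.toString)
        env.put("PYSPARK_GATEWAY_SECRET", secret)
        builder.redirectError(Redirect.PIPE)
        builder.start()
      }
    }
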
30 changes: 28 additions & 2 deletions repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala
@@ -18,8 +18,10 @@
package org.apache.livy.repl

import java.io._
+import java.lang.{Integer => JInteger}
import java.lang.ProcessBuilder.Redirect
import java.lang.reflect.Proxy
+import java.net.InetAddress
import java.nio.file.{Files, Paths}

import scala.annotation.tailrec
@@ -35,7 +37,7 @@ import org.json4s.jackson.Serialization.write
import py4j._
import py4j.reflection.PythonProxyHandler

-import org.apache.livy.Logging
+import org.apache.livy.{Logging, Utils}
import org.apache.livy.client.common.ClientConf
import org.apache.livy.rsc.driver.SparkEntries
import org.apache.livy.sessions._
@@ -49,7 +51,8 @@ object PythonInterpreter extends Logging {
.orElse(sys.props.get("pyspark.python")) // This java property is only used for internal UT.
.getOrElse("python")

-    val gatewayServer = new GatewayServer(sparkEntries, 0)
+    val secretKey = Utils.createSecret(256)
+    val gatewayServer = createGatewayServer(sparkEntries, secretKey)
gatewayServer.start()

val builder = new ProcessBuilder(Seq(pythonExec, createFakeShell().toString).asJava)
@@ -65,6 +68,7 @@
env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator))
env.put("PYTHONUNBUFFERED", "YES")
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
env.put("PYSPARK_GATEWAY_SECRET", secretKey)
env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1"))
builder.redirectError(Redirect.PIPE)
@@ -131,6 +135,28 @@ object PythonInterpreter extends Logging {
file
}

+  private def createGatewayServer(sparkEntries: SparkEntries, secretKey: String): GatewayServer = {
+    try {
+      val clz = Class.forName("py4j.GatewayServer$GatewayServerBuilder", true,
+        Thread.currentThread().getContextClassLoader)
+      val builder = clz.getConstructor(classOf[Object])
+        .newInstance(sparkEntries)
+
+      val localhost = InetAddress.getLoopbackAddress()
+      builder.getClass.getMethod("authToken", classOf[String]).invoke(builder, secretKey)
+      builder.getClass.getMethod("javaPort", classOf[Int]).invoke(builder, 0: JInteger)
+      builder.getClass.getMethod("javaAddress", classOf[InetAddress]).invoke(builder, localhost)
+      builder.getClass
+        .getMethod("callbackClient", classOf[Int], classOf[InetAddress], classOf[String])
+        .invoke(builder, GatewayServer.DEFAULT_PYTHON_PORT: JInteger, localhost, secretKey)
+      builder.getClass.getMethod("build").invoke(builder).asInstanceOf[GatewayServer]
+    } catch {
+      case NonFatal(e) =>
+        warn("Fail to create GatewayServer with auth parameter, downgrade to old constructor", e)
+        new GatewayServer(sparkEntries, 0)
+    }
+  }

private def initiatePy4jCallbackGateway(server: GatewayServer): PySparkJobProcessor = {
val f = server.getClass.getDeclaredField("gateway")
f.setAccessible(true)
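
All the reflection in createGatewayServer exists only because Livy still has to compile and run against py4j 0.10.4, which has no GatewayServerBuilder; when the 0.10.7 builder is on the classpath it is driven through getMethod/invoke, and on any failure the code downgrades to the old, unauthenticated constructor. Written directly against py4j 0.10.7, the happy path collapses to plain builder calls. A sketch under that assumption (entryPoint stands in for Livy's SparkEntries instance):

    import java.net.InetAddress
    import py4j.GatewayServer

    object GatewaySketch {
      def secured(entryPoint: AnyRef, secret: String): GatewayServer = {
        val localhost = InetAddress.getLoopbackAddress()
        val builder = new GatewayServer.GatewayServerBuilder(entryPoint)
        builder.authToken(secret)        // reject clients that lack the shared secret
        builder.javaPort(0)              // bind an ephemeral port
        builder.javaAddress(localhost)   // listen on loopback only
        builder.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
        builder.build()
      }
    }
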
40 changes: 33 additions & 7 deletions repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala
@@ -24,6 +24,8 @@ import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe
+import scala.util.Try
+import scala.util.control.NonFatal

import org.apache.commons.codec.binary.Base64
import org.apache.commons.lang.StringEscapeUtils
@@ -33,13 +35,14 @@ import org.apache.spark.sql.SQLContext
import org.json4s._
import org.json4s.JsonDSL._

+import org.apache.livy.Logging
import org.apache.livy.client.common.ClientConf
import org.apache.livy.rsc.driver.SparkEntries

private case class RequestResponse(content: String, error: Boolean)

// scalastyle:off println
-object SparkRInterpreter {
+object SparkRInterpreter extends Logging {
private val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
private val LIVY_ERROR_MARKER = "----LIVY_END_OF_ERROR----"
private val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
@@ -76,12 +79,25 @@ object SparkRInterpreter {
val backendInstance = sparkRBackendClass.getDeclaredConstructor().newInstance()

var sparkRBackendPort = 0
+    var sparkRBackendSecret: String = null
val initialized = new Semaphore(0)
// Launch a SparkR backend server for the R process to connect to
val backendThread = new Thread("SparkR backend") {
override def run(): Unit = {
-        sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
-          .asInstanceOf[Int]
+        try {
+          sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
+            .asInstanceOf[Int]
+        } catch {
+          case NonFatal(e) =>
+            warn("Fail to init Spark RBackend, using different method signature", e)
+            val retTuple = sparkRBackendClass.getMethod("init").invoke(backendInstance)
+              .asInstanceOf[(Int, Object)]
+            sparkRBackendPort = retTuple._1
+            sparkRBackendSecret = Try {
+              val rAuthHelper = retTuple._2
+              rAuthHelper.getClass.getMethod("secret").invoke(rAuthHelper).asInstanceOf[String]
+            }.getOrElse(null)
+        }

initialized.release()
sparkRBackendClass.getMethod("run").invoke(backendInstance)
@@ -116,14 +132,17 @@
val env = builder.environment()
env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+      if (sparkRBackendSecret != null) {
+        env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
+      }
env.put("SPARKR_PACKAGE_DIR", packageDir)
env.put("R_PROFILE_USER",
Seq(packageDir, "SparkR", "profile", "general.R").mkString(File.separator))

builder.redirectErrorStream(true)
val process = builder.start()
new SparkRInterpreter(process, backendInstance, backendThread,
conf.getInt("spark.livy.spark_major_version", 1))
conf.getInt("spark.livy.spark_major_version", 1), sparkRBackendSecret != null)
} catch {
case e: Exception =>
if (backendThread != null) {
@@ -149,10 +168,12 @@
}
}

-class SparkRInterpreter(process: Process,
+class SparkRInterpreter(
+    process: Process,
     backendInstance: Any,
     backendThread: Thread,
-    val sparkMajorVersion: Int)
+    val sparkMajorVersion: Int,
+    authProvided: Boolean)
extends ProcessInterpreter(process) {
import SparkRInterpreter._

@@ -169,7 +190,12 @@
// scalastyle:off line.size.limit
sendRequest("library(SparkR)")
sendRequest("""port <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")""")
sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
if (authProvided) {
sendRequest("""authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET", "")""")
sendRequest("""SparkR:::connectBackend("localhost", port, 6000, authSecret)""")
} else {
sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
}
sendRequest("""assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)""")

sendRequest("""assign(".sc", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSparkContext"), envir = SparkR:::.sparkREnv)""")
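
The SparkR half mirrors the PySpark one, but the version split shows up in RBackend.init itself: through Spark 2.3.0 it returns the port as an Int, while 2.3.1 returns a (port, RAuthHelper) pair whose helper carries the secret that SparkR:::connectBackend now accepts as a fourth argument. The commit probes for this with try/catch; pattern matching on the reflective result is an equivalent way to express the same probe. A sketch, with backendClass and instance standing in for the reflectively loaded org.apache.spark.api.r.RBackend:

    import scala.util.Try

    object RBackendInitSketch {
      def portAndSecret(backendClass: Class[_], instance: AnyRef): (Int, Option[String]) =
        backendClass.getMethod("init").invoke(instance) match {
          case port: Integer =>
            // Spark <= 2.3.0: init() returns Int
            (port.intValue(), None)
          case (port: Integer, authHelper: AnyRef) =>
            // Spark 2.3.1+: init() returns (Int, RAuthHelper)
            val secret = Try {
              authHelper.getClass.getMethod("secret").invoke(authHelper).asInstanceOf[String]
            }.toOption
            (port.intValue(), secret)
          case other =>
            throw new IllegalStateException(s"Unexpected RBackend.init result: $other")
        }
    }
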
