Skip to content

Commit

Permalink
[SPARK-1549] Add Python support to spark-submit
Browse files Browse the repository at this point in the history
This PR updates spark-submit to allow submitting Python scripts (currently only with deploy-mode=client, but that's all that was supported before) and updates the PySpark code to properly find various paths, etc. One significant change is that we assume we can always find the Python files either from the Spark assembly JAR (which will happen with the Maven assembly build in make-distribution.sh) or from SPARK_HOME (which will exist in local mode even if you use sbt assembly, and should be enough for testing). This means we no longer need a weird hack to modify the environment for YARN.

This patch also updates the Python worker manager to run python with -u, which means unbuffered output (send it to our logs right away instead of waiting a while after stuff was written); this should simplify debugging.

In addition, it fixes https://issues.apache.org/jira/browse/SPARK-1709, setting the main class from a JAR's Main-Class attribute if not specified by the user, and fixes a few help strings and style issues in spark-submit.

In the future we may want to make the `pyspark` shell use spark-submit as well, but it seems unnecessary for 1.0.

Author: Matei Zaharia <[email protected]>

Closes apache#664 from mateiz/py-submit and squashes the following commits:

15e9669 [Matei Zaharia] Fix some uses of path.separator property
051278c [Matei Zaharia] Small style fixes
0afe886 [Matei Zaharia] Add license headers
4650412 [Matei Zaharia] Add pyFiles to PYTHONPATH in executors, remove old YARN stuff, add tests
15f8e1e [Matei Zaharia] Set PYTHONPATH in PythonWorkerFactory in case it wasn't set from outside
47c0655 [Matei Zaharia] More work to make spark-submit work with Python:
d4375bd [Matei Zaharia] Clean up description of spark-submit args a bit and add Python ones
  • Loading branch information
mateiz committed May 6, 2014
1 parent ec09acd commit 951a5d9
Show file tree
Hide file tree
Showing 16 changed files with 505 additions and 194 deletions.
13 changes: 0 additions & 13 deletions assembly/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,6 @@
<deb.user>root</deb.user>
</properties>

<repositories>
<!-- A repository in the local filesystem for the Py4J JAR, which is not in Maven central -->
<repository>
<id>lib</id>
<url>file://${project.basedir}/lib</url>
</repository>
</repositories>

<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
Expand Down Expand Up @@ -84,11 +76,6 @@
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.8.1</version>
</dependency>
</dependencies>

<build>
Expand Down
5 changes: 5 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@
<artifactId>pyrolite</artifactId>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.8.1</version>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark

import java.io.File

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.concurrent.Await
Expand Down Expand Up @@ -304,7 +306,7 @@ object SparkEnv extends Logging {
k == "java.class.path"
}.getOrElse(("", ""))
val classPathEntries = classPathProperty._2
.split(conf.get("path.separator", ":"))
.split(File.pathSeparator)
.filterNot(e => e.isEmpty)
.map(e => (e, "System Classpath"))
val addedJarsAndFiles = (addedJars ++ addedFiles).map((_, "Added By User"))
Expand Down
42 changes: 42 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext

private[spark] object PythonUtils {
/** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
def sparkPythonPath: String = {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
}

/** Merge PYTHONPATHS with the appropriate separator. Ignores blank strings. */
def mergePythonPaths(paths: String*): String = {
paths.filter(_ != "").mkString(File.pathSeparator)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0

val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", ""))

def create(): Socket = {
if (useDaemon) {
createThroughDaemon()
Expand Down Expand Up @@ -78,9 +81,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))

// Create and start the worker
val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
val worker = pb.start()

// Redirect the worker's stderr to ours
Expand Down Expand Up @@ -151,9 +155,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

try {
// Create and start the daemon
val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
daemon = pb.start()

// Redirect the stderr to ours
Expand Down
84 changes: 84 additions & 0 deletions core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.deploy

import java.io.{IOException, File, InputStream, OutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._

import org.apache.spark.SparkContext
import org.apache.spark.api.python.PythonUtils

/**
* A main class used by spark-submit to launch Python applications. It executes python as a
* subprocess and then has it connect back to the JVM to access system properties, etc.
*/
object PythonRunner {
def main(args: Array[String]) {
val primaryResource = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)

val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf

// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
val gatewayServer = new py4j.GatewayServer(null, 0)
gatewayServer.start()

// Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the
// python directories in SPARK_HOME (if set), and any files in the pyFiles argument
val pathElements = new ArrayBuffer[String]
pathElements ++= pyFiles.split(",")
pathElements += PythonUtils.sparkPythonPath
pathElements += sys.env.getOrElse("PYTHONPATH", "")
val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)

// Launch Python process
val builder = new ProcessBuilder(Seq(pythonExec, "-u", primaryResource) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()

new RedirectThread(process.getInputStream, System.out, "redirect output").start()

System.exit(process.waitFor())
}

/**
* A utility class to redirect the child process's stdout or stderr
*/
class RedirectThread(in: InputStream, out: OutputStream, name: String) extends Thread(name) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
out.write(buf, 0, len)
out.flush()
len = in.read(buf)
}
}
}
}
}
Loading

0 comments on commit 951a5d9

Please sign in to comment.