Skip to content

Commit

Permalink
[SPARK-2313] Use socket to communicate GatewayServer port back to Pyt…
Browse files Browse the repository at this point in the history
…hon driver

This patch changes PySpark so that the GatewayServer's port is communicated back to the Python process that launches it over a local socket instead of a pipe.  The old pipe-based approach was brittle and could fail if `spark-submit` printed unexpected to stdout.

To accomplish this, I wrote a custom `PythonGatewayServer.main()` function to use in place of Py4J's `GatewayServer.main()`.

Closes apache#3424.

Author: Josh Rosen <[email protected]>

Closes apache#4603 from JoshRosen/SPARK-2313 and squashes the following commits:

6a7740b [Josh Rosen] Remove EchoOutputThread since it's no longer needed
0db501f [Josh Rosen] Use select() so that we don't block if GatewayServer dies.
9bdb4b6 [Josh Rosen] Handle case where getListeningPort returns -1
3fb7ed1 [Josh Rosen] Remove stdout=PIPE
2458934 [Josh Rosen] Use underscore to mark env var. as private
d12c95d [Josh Rosen] Use Logging and Utils.tryOrExit()
e5f9730 [Josh Rosen] Wrap everything in a giant try-block
2f70689 [Josh Rosen] Use stdin PIPE to share fate with driver
8bf956e [Josh Rosen] Initial cut at passing Py4J gateway port back to driver via socket
  • Loading branch information
JoshRosen committed Feb 16, 2015
1 parent c01c4eb commit 0cfda84
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket

import py4j.GatewayServer

import org.apache.spark.Logging
import org.apache.spark.util.Utils

/**
* Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
* back to its caller via a callback port specified by the caller.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
def main(args: Array[String]): Unit = Utils.tryOrExit {
// Start a GatewayServer on an ephemeral port
val gatewayServer: GatewayServer = new GatewayServer(null, 0)
gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
logError("GatewayServer failed to bind; exiting")
System.exit(1)
} else {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}

// Communicate the bound port back to the caller via the caller-specified callback port
val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
val callbackSocket = new Socket(callbackHost, callbackPort)
val dos = new DataOutputStream(callbackSocket.getOutputStream)
dos.writeInt(boundPort)
dos.close()
callbackSocket.close()

// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
// Do nothing
}
logDebug("Exiting due to broken pipe from Python driver")
System.exit(0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}

import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
import org.apache.spark.executor._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}

/**
Expand Down Expand Up @@ -284,8 +283,7 @@ object SparkSubmit {
// If we're running a python app, set the main class to our specific python runner
if (args.isPython && deployMode == CLIENT) {
if (args.primaryResource == PYSPARK_SHELL) {
args.mainClass = "py4j.GatewayServer"
args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
args.mainClass = "org.apache.spark.api.python.PythonGatewayServer"
} else {
// If a python file is provided, add it to the child arguments and list of files to deploy.
// Usage: PythonAppRunner <main python file> <extra python files> [app arguments]
Expand Down
72 changes: 32 additions & 40 deletions python/pyspark/java_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@

import atexit
import os
import sys
import select
import signal
import shlex
import socket
import platform
from subprocess import Popen, PIPE
from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient

from pyspark.serializers import read_int


def launch_gateway():
SPARK_HOME = os.environ["SPARK_HOME"]

gateway_port = -1
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
else:
Expand All @@ -41,36 +42,42 @@ def launch_gateway():
submit_args = submit_args if submit_args is not None else ""
submit_args = shlex.split(submit_args)
command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"]

# Start a socket that will be used by PythonGatewayServer to communicate its port to us
callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
callback_socket.bind(('127.0.0.1', 0))
callback_socket.listen(1)
callback_host, callback_port = callback_socket.getsockname()
env = dict(os.environ)
env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

# Launch the Java gateway.
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
if not on_windows:
# Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN)
env = dict(os.environ)
env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits
proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env)
proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
else:
# preexec_fn not supported on Windows
proc = Popen(command, stdout=PIPE, stdin=PIPE)
proc = Popen(command, stdin=PIPE, env=env)

try:
# Determine which ephemeral port the server started on:
gateway_port = proc.stdout.readline()
gateway_port = int(gateway_port)
except ValueError:
# Grab the remaining lines of stdout
(stdout, _) = proc.communicate()
exit_code = proc.poll()
error_msg = "Launching GatewayServer failed"
error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n"
error_msg += "Warning: Expected GatewayServer to output a port, but found "
if gateway_port == "" and stdout == "":
error_msg += "no output.\n"
else:
error_msg += "the following:\n\n"
error_msg += "--------------------------------------------------------------\n"
error_msg += gateway_port + stdout
error_msg += "--------------------------------------------------------------\n"
raise Exception(error_msg)
gateway_port = None
# We use select() here in order to avoid blocking indefinitely if the subprocess dies
# before connecting
while gateway_port is None and proc.poll() is None:
timeout = 1 # (seconds)
readable, _, _ = select.select([callback_socket], [], [], timeout)
if callback_socket in readable:
gateway_connection = callback_socket.accept()[0]
# Determine which ephemeral port the server started on:
gateway_port = read_int(gateway_connection.makefile())
gateway_connection.close()
callback_socket.close()
if gateway_port is None:
raise Exception("Java gateway process exited before sending the driver its port number")

# In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
Expand All @@ -88,21 +95,6 @@ def killChild():
Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
atexit.register(killChild)

# Create a thread to echo output from the GatewayServer, which is required
# for Java log output to show up:
class EchoOutputThread(Thread):

def __init__(self, stream):
Thread.__init__(self)
self.daemon = True
self.stream = stream

def run(self):
while True:
line = self.stream.readline()
sys.stderr.write(line)
EchoOutputThread(proc.stdout).start()

# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)

Expand Down

0 comments on commit 0cfda84

Please sign in to comment.