forked from witgo/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-2313] Use socket to communicate GatewayServer port back to Pyt…
…hon driver This patch changes PySpark so that the GatewayServer's port is communicated back to the Python process that launches it over a local socket instead of a pipe. The old pipe-based approach was brittle and could fail if `spark-submit` printed unexpected to stdout. To accomplish this, I wrote a custom `PythonGatewayServer.main()` function to use in place of Py4J's `GatewayServer.main()`. Closes apache#3424. Author: Josh Rosen <[email protected]> Closes apache#4603 from JoshRosen/SPARK-2313 and squashes the following commits: 6a7740b [Josh Rosen] Remove EchoOutputThread since it's no longer needed 0db501f [Josh Rosen] Use select() so that we don't block if GatewayServer dies. 9bdb4b6 [Josh Rosen] Handle case where getListeningPort returns -1 3fb7ed1 [Josh Rosen] Remove stdout=PIPE 2458934 [Josh Rosen] Use underscore to mark env var. as private d12c95d [Josh Rosen] Use Logging and Utils.tryOrExit() e5f9730 [Josh Rosen] Wrap everything in a giant try-block 2f70689 [Josh Rosen] Use stdin PIPE to share fate with driver 8bf956e [Josh Rosen] Initial cut at passing Py4J gateway port back to driver via socket
- Loading branch information
Showing
3 changed files
with
97 additions
and
43 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.api.python | ||
|
||
import java.io.DataOutputStream | ||
import java.net.Socket | ||
|
||
import py4j.GatewayServer | ||
|
||
import org.apache.spark.Logging | ||
import org.apache.spark.util.Utils | ||
|
||
/** | ||
* Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port | ||
* back to its caller via a callback port specified by the caller. | ||
* | ||
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). | ||
*/ | ||
private[spark] object PythonGatewayServer extends Logging { | ||
def main(args: Array[String]): Unit = Utils.tryOrExit { | ||
// Start a GatewayServer on an ephemeral port | ||
val gatewayServer: GatewayServer = new GatewayServer(null, 0) | ||
gatewayServer.start() | ||
val boundPort: Int = gatewayServer.getListeningPort | ||
if (boundPort == -1) { | ||
logError("GatewayServer failed to bind; exiting") | ||
System.exit(1) | ||
} else { | ||
logDebug(s"Started PythonGatewayServer on port $boundPort") | ||
} | ||
|
||
// Communicate the bound port back to the caller via the caller-specified callback port | ||
val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") | ||
val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt | ||
logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") | ||
val callbackSocket = new Socket(callbackHost, callbackPort) | ||
val dos = new DataOutputStream(callbackSocket.getOutputStream) | ||
dos.writeInt(boundPort) | ||
dos.close() | ||
callbackSocket.close() | ||
|
||
// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: | ||
while (System.in.read() != -1) { | ||
// Do nothing | ||
} | ||
logDebug("Exiting due to broken pipe from Python driver") | ||
System.exit(0) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters