[SPARK-3030] [PySpark] Reuse Python worker
Reuse Python workers to avoid the overhead of fork()-ing a Python process for each task. It also tracks the broadcasts already sent to each worker, avoiding repeated broadcast transfers.

This can reduce the time for a dummy task from 22ms to 13ms (-40%), which helps lower latency for Spark Streaming.
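For context, the 22ms/13ms numbers refer to timing a trivial job. A hedged micro-benchmark sketch (not the script actually used for the measurement; the master URL and app name are illustrative) would look like this:
```
    import time
    from pyspark import SparkContext

    sc = SparkContext("local[4]", "dummy-task-benchmark")  # hypothetical setup

    # Time a few trivial ("dummy") jobs; with worker reuse the per-task
    # fork()/startup cost disappears after the first run.
    for i in range(5):
        start = time.time()
        sc.parallelize(range(1), 1).count()
        print "run %d: %.1f ms" % (i, (time.time() - start) * 1000)
```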

For a job with a broadcast variable (43 MB after compression):
```
    b = sc.broadcast(set(range(30000000)))
    print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
```
It finishes in 281s without worker reuse and in 65s with worker reuse (4 CPUs). Reusing workers saves about 9 seconds per task that would otherwise be spent transferring and deserializing the broadcast.

Worker reuse is enabled by default and can be disabled by setting `spark.python.worker.reuse = false`.
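As an illustrative snippet (not part of this commit), the setting can be turned off programmatically in PySpark:
```
    from pyspark import SparkConf, SparkContext

    # Illustrative only: fall back to forking a fresh Python worker per task.
    conf = SparkConf().set("spark.python.worker.reuse", "false")
    sc = SparkContext(conf=conf)
```
The same property can also be passed on the command line, e.g. `spark-submit --conf spark.python.worker.reuse=false`.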

Author: Davies Liu <[email protected]>

Closes apache#2259 from davies/reuse-worker and squashes the following commits:

f11f617 [Davies Liu] Merge branch 'master' into reuse-worker
3939f20 [Davies Liu] fix bug in serializer in mllib
cf1c55e [Davies Liu] address comments
3133a60 [Davies Liu] fix accumulator with reused worker
760ab1f [Davies Liu] do not reuse worker if there are any exceptions
7abb224 [Davies Liu] refactor: sychronized with itself
ac3206e [Davies Liu] renaming
8911f44 [Davies Liu] synchronized getWorkerBroadcasts()
6325fc1 [Davies Liu] bugfix: bid >= 0
e0131a2 [Davies Liu] fix name of config
583716e [Davies Liu] only reuse completed and not interrupted worker
ace2917 [Davies Liu] kill python worker after timeout
6123d0f [Davies Liu] track broadcasts for each worker
8d2f08c [Davies Liu] reuse python worker
davies authored and JoshRosen committed Sep 13, 2014
1 parent 0f8c4ed commit 2aea0da
Showing 9 changed files with 208 additions and 51 deletions.
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}

private[spark]
def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
}
}
}

object SparkEnv extends Logging {
58 changes: 46 additions & 12 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {

val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

override def getPartitions = parent.partitions

@@ -63,19 +65,26 @@
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars += ("SPARK_REUSE_WORKER" -> "1")
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)

var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()

// Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception => logWarning("Failed to close worker socket", e)
if (reuse_worker && complete_cleanly) {
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
} else {
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}

@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
complete_cleanly = true
null
}
} catch {
@@ -195,29 +205,45 @@
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
dataOut.writeInt(cnt)
for (bid <- oldBids) {
if (!newBids.contains(bid)) {
// remove the broadcast from worker
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
}
for (broadcast <- broadcastVars) {
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
oldBids.add(broadcast.id)
}
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
worker.shutdownOutput()

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
} finally {
Try(worker.shutdownOutput()) // kill Python worker process
worker.shutdownOutput()
}
}
}
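To make the broadcast bookkeeping written above concrete, here is a hedged Python sketch of how the worker side could decode that part of the stream. The corresponding worker-side change is not shown in this excerpt, and the `struct` formats assume Java's big-endian `DataOutputStream` encoding:
```
    import struct

    def read_broadcast_updates(infile, registry):
        # Number of broadcast entries that changed since this worker's last task.
        num_changed = struct.unpack("!i", infile.read(4))[0]
        for _ in range(num_changed):
            bid = struct.unpack("!q", infile.read(8))[0]
            if bid < 0:
                # A negative id encodes removal of broadcast (-bid - 1).
                registry.pop(-bid - 1, None)
            else:
                # A non-negative id is followed by the serialized value of a
                # broadcast this worker has not seen yet.
                length = struct.unpack("!i", infile.read(4))[0]
                registry[bid] = infile.read(length)
```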
@@ -278,6 +304,14 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")

// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
private def getWorkerBroadcasts(worker: Socket) = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
}
}

/**
* Adapter for calling SparkContext#runJob from Python.
*
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
val idleWorkers = new mutable.Queue[Socket]()
var lastActivity = 0L
new MonitorThread().start()

var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

@@ -51,6 +54,11 @@

def create(): Socket = {
if (useDaemon) {
synchronized {
if (idleWorkers.size > 0) {
return idleWorkers.dequeue()
}
}
createThroughDaemon()
} else {
createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}

/**
* Monitor all the idle workers, kill them after timeout.
*/
private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {

setDaemon(true)

override def run() {
while (true) {
synchronized {
if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
cleanupIdleWorkers()
lastActivity = System.currentTimeMillis()
}
}
Thread.sleep(10000)
}
}
}

private def cleanupIdleWorkers() {
while (idleWorkers.length > 0) {
val worker = idleWorkers.dequeue()
try {
// the worker will exit after closing the socket
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}

private def stopDaemon() {
synchronized {
if (useDaemon) {
cleanupIdleWorkers()

// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}

def stopWorker(worker: Socket) {
if (useDaemon) {
if (daemon != null) {
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.flush()
daemon.getOutputStream.flush()
synchronized {
if (useDaemon) {
if (daemon != null) {
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.flush()
daemon.getOutputStream.flush()
}
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}

def releaseWorker(worker: Socket) {
if (useDaemon) {
synchronized {
lastActivity = System.currentTimeMillis()
idleWorkers.enqueue(worker)
}
} else {
// Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}
}

private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}
10 changes: 10 additions & 0 deletions docs/configuration.md
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
</td>
</tr>
<tr>
<td><code>spark.python.worker.reuse</code></td>
<td>true</td>
<td>
Whether to reuse Python workers. If yes, it will use a fixed number of Python workers
and does not need to fork() a Python process for every task. This is very useful when
there is a large broadcast, since the broadcast does not need to be transferred from
the JVM to the Python workers for every task.
</td>
</tr>
<tr>
<td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
<td>(none)</td>
38 changes: 19 additions & 19 deletions python/pyspark/daemon.py
@@ -23,6 +23,7 @@
import sys
import traceback
import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,35 +47,20 @@ def worker(sock):
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)

# Blocks until the socket is closed by draining the input stream
# until it raises an exception or returns EOF.
def waitSocketClose(sock):
try:
while True:
# Empty string is returned upon EOF (and only then).
if sock.recv(4096) == '':
return
except:
pass

# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
# Acknowledge that the fork was successful
write_int(os.getpid(), outfile)
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = exc.code
exit_code = compute_real_exit_code(exc.code)
finally:
outfile.flush()
# The Scala side will close the socket upon task completion.
waitSocketClose(sock)
os._exit(compute_real_exit_code(exit_code))
if exit_code:
os._exit(exit_code)


# Cleanup zombie children
@@ -111,6 +97,8 @@ def handle_sigterm(*args):
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP

reuse = os.environ.get("SPARK_REUSE_WORKER")

# Initialization complete
try:
while True:
@@ -163,7 +151,19 @@ def handle_sigterm(*args):
# in child process
listen_sock.close()
try:
worker(sock)
# Acknowledge that the fork was successful
outfile = sock.makefile("w")
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
while True:
worker(sock)
if not reuse:
# wait for closing
while sock.recv(1024):
pass
break
gc.collect()
except:
traceback.print_exc()
os._exit(1)
12 changes: 5 additions & 7 deletions python/pyspark/mllib/_common.py
@@ -21,7 +21,7 @@
from numpy import ndarray, float64, int64, int32, array_equal, array
from pyspark import SparkContext, RDD
from pyspark.mllib.linalg import SparseVector
from pyspark.serializers import Serializer
from pyspark.serializers import FramedSerializer


"""
@@ -451,18 +451,16 @@ def _serialize_rating(r):
return ba


class RatingDeserializer(Serializer):
class RatingDeserializer(FramedSerializer):

def loads(self, stream):
length = struct.unpack("!i", stream.read(4))[0]
ba = stream.read(length)
res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
def loads(self, string):
res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
return int(res[0]), int(res[1]), res[2]

def load_stream(self, stream):
while True:
try:
yield self.loads(stream)
yield self._read_with_length(stream)
except struct.error:
return
except EOFError:
(Diffs for the remaining changed files are not shown in this excerpt.)