[SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR)

ExecutorClassLoader does not ensure proper cleanup of the network connections that it opens. If it fails to load a class, it may leak partially-consumed InputStreams that are connected to the REPL's HTTP class server; these leaked streams can exhaust the server's thread pool and hang the entire job. See [SPARK-6209](https://issues.apache.org/jira/browse/SPARK-6209) for more details, including a bug reproduction.

This patch fixes this issue by ensuring proper cleanup of these resources.  It also adds logging for unexpected error cases.

This PR is an extended version of apache#4935 and adds a regression test.
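
The shape of the fix is the classic stream-cleanup idiom, sketched below. This is an illustrative sketch rather than code from the patch; the helper names (`fetchClassBytes`, `openClassStream`, `readAllBytes`) are hypothetical.

```scala
import java.io.InputStream

// Minimal sketch of the cleanup idiom the patch applies (helper names are
// hypothetical): close the stream in `finally` so a failed class load cannot
// leak a partially-consumed connection to the class server.
def fetchClassBytes(
    openClassStream: () => InputStream,
    readAllBytes: InputStream => Array[Byte]): Array[Byte] = {
  var in: InputStream = null
  try {
    in = openClassStream()   // may be an HTTP stream to the REPL class server
    readAllBytes(in)         // may throw before the stream is fully consumed
  } finally {
    if (in != null) {
      try {
        in.close()           // without this, each failed load pins a server thread
      } catch {
        case _: Exception => () // best-effort close; the real patch logs here
      }
    }
  }
}
```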

Author: Josh Rosen <[email protected]>

Closes apache#4944 from JoshRosen/executorclassloader-leak-master-branch and squashes the following commits:

e0e3c25 [Josh Rosen] Wrap try block around getResponseCode; re-enable keep-alive by closing error stream
961c284 [Josh Rosen] Roll back changes that were added to get the regression test to fail
7ee2261 [Josh Rosen] Add a failing regression test
e2d70a3 [Josh Rosen] Properly clean up after errors in ExecutorClassLoader
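
For context on the e0e3c25 entry above: with `java.net.HttpURLConnection`, the body of a non-200 response is delivered on `getErrorStream`, and closing that stream is what frees the underlying socket for keep-alive re-use. Below is a minimal sketch of that idiom, under the assumption of a plain HTTP fetch; the method name and URL handling are hypothetical, not code from the patch.

```scala
import java.io.{IOException, InputStream}
import java.net.{HttpURLConnection, URL}

// Sketch of the keep-alive idiom referenced in commit e0e3c25.
def openClassStream(urlString: String): InputStream = {
  val connection = new URL(urlString).openConnection().asInstanceOf[HttpURLConnection]
  connection.connect()
  if (connection.getResponseCode != HttpURLConnection.HTTP_OK) {
    // A non-2xx body arrives on getErrorStream; closing it returns the socket
    // to the keep-alive pool instead of forcing the connection to be torn down.
    val errorStream = connection.getErrorStream
    try {
      if (errorStream != null) errorStream.close()
    } catch {
      case _: IOException => () // best-effort close; the real patch logs here
    }
    throw new ClassNotFoundException(s"Class file not found at URL $urlString")
  } else {
    connection.getInputStream
  }
}
```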
JoshRosen authored and Andrew Or committed Mar 24, 2015
1 parent a8f51b8 commit 7215aa7
Showing 3 changed files with 140 additions and 20 deletions.
5 changes: 5 additions & 0 deletions repl/pom.xml
@@ -84,6 +84,11 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
 
     <!-- Explicit listing of transitive deps that are shaded. Otherwise, odd compiler crashes. -->
     <dependency>
85 changes: 67 additions & 18 deletions repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.repl
 
-import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.{FileSystem, Path}
 
@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
 
   val parentLoader = new ParentClassLoader(parent)
 
+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
   // Hadoop FileSystem object for our URI, if it isn't using HTTP
   var fileSystem: FileSystem = {
     if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,37 +75,82 @@
     }
   }
 
+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      newuri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
+      SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    if (fileSystem.exists(path)) {
+      fileSystem.open(path)
+    } else {
+      throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
   def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
     try {
-      val pathInDirectory = name.replace('.', '/') + ".class"
-      val inputStream = {
+      inputStream = {
         if (fileSystem != null) {
-          fileSystem.open(new Path(directory, pathInDirectory))
+          getClassFileInputStreamFromFileSystem(pathInDirectory)
         } else {
-          val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
-            val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
-            val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
-            newuri.toURL
-          } else {
-            new URL(classUri + "/" + urlEncode(pathInDirectory))
-          }
-
-          Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
-            .getInputStream
+          getClassFileInputStreamFromHttpServer(pathInDirectory)
         }
       }
       val bytes = readAndTransformClass(name, inputStream)
-      inputStream.close()
       Some(defineClass(name, bytes, 0, bytes.length))
     } catch {
-      case e: FileNotFoundException =>
+      case e: ClassNotFoundException =>
        // We did not find the class
         logDebug(s"Did not load class $name from REPL class server at $uri", e)
         None
       case e: Exception =>
         // Something bad happened while checking if the class exists
         logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
         None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
     }
   }
 

70 changes: 68 additions & 2 deletions repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,20 +20,33 @@ package org.apache.spark.repl
 import java.io.File
 import java.net.{URL, URLClassLoader}
 
+import scala.concurrent.duration._
 import scala.language.implicitConversions
+import scala.language.postfixOps
 
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._
 
-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
 import org.apache.spark.util.Utils
 
-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+  extends FunSuite
+  with BeforeAndAfterAll
+  with MockitoSugar
+  with Logging {
 
   val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
   val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
   var tempDir1: File = _
   var tempDir2: File = _
   var url1: String = _
   var urls2: Array[URL] = _
+  var classServer: HttpServer = _
 
   override def beforeAll() {
     super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
 
   override def afterAll() {
     super.afterAll()
+    if (classServer != null) {
+      classServer.stop()
+    }
     Utils.deleteRecursively(tempDir1)
     Utils.deleteRecursively(tempDir2)
+    SparkEnv.set(null)
   }
 
   test("child first") {
@@ -83,4 +100,53 @@
     }
   }
 
+  test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
+    // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
+    // from the driver's class server would leak a HTTP connection, causing the class server's
+    // thread / connection pool to be exhausted.
+    val conf = new SparkConf()
+    val securityManager = new SecurityManager(conf)
+    classServer = new HttpServer(conf, tempDir1, securityManager)
+    classServer.start()
+    // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
+    val mockEnv = mock[SparkEnv]
+    when(mockEnv.securityManager).thenReturn(securityManager)
+    SparkEnv.set(mockEnv)
+    // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
+    val parentLoader = new URLClassLoader(Array.empty, null)
+    val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+    classLoader.httpUrlConnectionTimeoutMillis = 500
+    // Check that this class loader can actually load classes that exist
+    val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+    val fakeClassVersion = fakeClass.toString
+    assert(fakeClassVersion === "1")
+    // Try to perform a full GC now, since GC during the test might mask resource leaks
+    System.gc()
+    // When the original bug occurs, the test thread becomes blocked in a classloading call
+    // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
+    // shut down the HTTP server when the test times out
+    val interruptor: Interruptor = new Interruptor {
+      override def apply(thread: Thread): Unit = {
+        classServer.stop()
+        classServer = null
+        thread.interrupt()
+      }
+    }
+    def tryAndFailToLoadABunchOfClasses(): Unit = {
+      // The number of trials here should be much larger than Jetty's thread / connection limit
+      // in order to expose thread or connection leaks
+      for (i <- 1 to 1000) {
+        if (Thread.currentThread().isInterrupted) {
+          throw new InterruptedException()
+        }
+        // Incorporate the iteration number into the class name in order to avoid any response
+        // caching that might be added in the future
+        intercept[ClassNotFoundException] {
+          classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
+        }
+      }
+    }
+    failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
+  }
+
 }
