[SPARK-8962] Add Scalastyle rule to ban direct use of Class.forName; fix existing uses

This pull request adds a Scalastyle regex rule which fails the style check if `Class.forName` is used directly. `Class.forName` always loads classes from the default / system classloader, but in the majority of cases we should use Spark's own `Utils.classForName` instead, which tries to load classes from the current thread's context classloader and falls back to the classloader that loaded Spark when no context classloader is defined.
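For reference, here is a minimal sketch of the lookup strategy described above. `Utils.classForName` and `getContextOrSparkClassLoader` are the real helpers in `org.apache.spark.util.Utils` referenced throughout the diff below, but the object name and method bodies here are an approximation, not Spark's exact implementation:

```scala
// A sketch, not Spark's exact code: the behavior Utils.classForName provides,
// per the description above. ClassForNameSketch is a hypothetical name.
object ClassForNameSketch {
  // The current thread's context classloader when one is set, otherwise the
  // classloader that loaded this class (standing in for Spark's own loader).
  def getContextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)

  // Replacement for a bare Class.forName(name): resolving through the context
  // classloader keeps classes shipped in user jars visible.
  def classForName(className: String): Class[_] =
    Class.forName(className, true, getContextOrSparkClassLoader)
}
```

The handful of call sites where a direct `Class.forName` with an explicit classloader is intentional (Java deserialization, the IBM JVM memory probe, the Kryo user registrator) opt out of the new rule with `// scalastyle:off classforname` / `// scalastyle:on classforname` markers, as the hunks below show.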


Author: Josh Rosen <[email protected]>

Closes apache#7350 from JoshRosen/ban-Class.forName and squashes the following commits:

e3e96f7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into ban-Class.forName
c0b7885 [Josh Rosen] Hopefully fix the last two cases
d707ba7 [Josh Rosen] Fix uses of Class.forName that I missed in my first cleanup pass
046470d [Josh Rosen] Merge remote-tracking branch 'origin/master' into ban-Class.forName
62882ee [Josh Rosen] Fix uses of Class.forName or add exclusion.
d9abade [Josh Rosen] Add stylechecker rule to ban uses of Class.forName
JoshRosen authored and rxin committed Jul 14, 2015
1 parent 740b034 commit 11e5c37
Showing 49 changed files with 117 additions and 84 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/Logging.scala
@@ -159,7 +159,7 @@ private object Logging {
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge order to route their logs to JUL.
- val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
+ val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler")
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
if (!installed) {
11 changes: 5 additions & 6 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1968,7 +1968,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
for (className <- listenerClassNames) {
// Use reflection to find the right constructor
val constructors = {
- val listenerClass = Class.forName(className)
+ val listenerClass = Utils.classForName(className)
listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
}
val constructorTakingSparkConf = constructors.find { c =>
@@ -2503,7 +2503,7 @@ object SparkContext extends Logging {
"\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
}
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
@@ -2515,7 +2515,7 @@
}
val backend = try {
val clazz =
Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
@@ -2528,8 +2528,7 @@

case "yarn-client" =>
val scheduler = try {
- val clazz =
-   Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
+ val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]

@@ -2541,7 +2540,7 @@

val backend = try {
val clazz =
Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -261,7 +261,7 @@ object SparkEnv extends Logging {

// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
- val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
+ val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
18 changes: 2 additions & 16 deletions core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -26,6 +26,7 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

import org.apache.spark.Logging
import org.apache.spark.api.r.SerDe._
+ import org.apache.spark.util.Utils

/**
* Handler for RBackend
@@ -88,21 +89,6 @@ private[r] class RBackendHandler(server: RBackend)
ctx.close()
}

- // Looks up a class given a class name. This function first checks the
- // current class loader and if a class is not found, it looks up the class
- // in the context class loader. Address [SPARK-5185]
- def getStaticClass(objId: String): Class[_] = {
-   try {
-     val clsCurrent = Class.forName(objId)
-     clsCurrent
-   } catch {
-     // Use contextLoader if we can't find the JAR in the system class loader
-     case e: ClassNotFoundException =>
-       val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader)
-       clsContext
-   }
- }

def handleMethodCall(
isStatic: Boolean,
objId: String,
@@ -113,7 +99,7 @@
var obj: Object = null
try {
val cls = if (isStatic) {
- getStaticClass(objId)
+ Utils.classForName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag

import org.apache.spark._
+ import org.apache.spark.util.Utils

private[spark] class BroadcastManager(
val isDriver: Boolean,
@@ -42,7 +43,7 @@ private[spark] class BroadcastManager(
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

broadcastFactory =
-   Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+   Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
@@ -178,7 +178,7 @@ class SparkHadoopUtil extends Logging {

private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
val statisticsDataClass =
Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
statisticsDataClass.getDeclaredMethod(methodName)
}

@@ -356,7 +356,7 @@ object SparkHadoopUtil {
System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
.newInstance()
.asInstanceOf[SparkHadoopUtil]
} catch {
@@ -624,7 +624,7 @@ object SparkSubmit {
var mainClass: Class[_] = null

try {
- mainClass = Class.forName(childMainClass, true, loader)
+ mainClass = Utils.classForName(childMainClass)
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
@@ -576,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String]
System.setSecurityManager(sm)

try {
- Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ Utils.classForName(mainClass).getMethod("main", classOf[Array[String]])
.invoke(null, Array(HELP))
} catch {
case e: InvocationTargetException =>
@@ -228,7 +228,7 @@ object HistoryServer extends Logging {

val providerName = conf.getOption("spark.history.provider")
.getOrElse(classOf[FsHistoryProvider].getName())
- val provider = Class.forName(providerName)
+ val provider = Utils.classForName(providerName)
.getConstructor(classOf[SparkConf])
.newInstance(conf)
.asInstanceOf[ApplicationHistoryProvider]
@@ -172,7 +172,7 @@ private[master] class Master(
new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
- val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
+ val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
.newInstance(conf, SerializationExtension(actorSystem))
.asInstanceOf[StandaloneRecoveryModeFactory]
@@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage {
*/
def fromJson(json: String): SubmitRestProtocolMessage = {
val className = parseAction(json)
- val clazz = Class.forName(packagePrefix + "." + className)
+ val clazz = Utils.classForName(packagePrefix + "." + className)
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
fromJson(json, clazz)
}
@@ -53,7 +53,7 @@ object DriverWrapper {
Thread.currentThread.setContextClassLoader(loader)

// Delegate to supplied main class
- val clazz = Class.forName(mainClass, true, loader)
+ val clazz = Utils.classForName(mainClass)
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])

@@ -149,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val ibmVendor = System.getProperty("java.vendor").contains("IBM")
var totalMb = 0
try {
+ // scalastyle:off classforname
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
@@ -159,6 +160,7 @@
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
+ // scalastyle:on classforname
} catch {
case e: Exception => {
totalMb = 2*1024
@@ -356,7 +356,7 @@ private[spark] class Executor(
logInfo("Using REPL class URI: " + classUri)
try {
val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
- val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
+ val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
classOf[ClassLoader], classOf[Boolean])
@@ -63,8 +63,7 @@ private[spark] object CompressionCodec {
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codec = try {
- val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
-   .getConstructor(classOf[SparkConf])
+ val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
} catch {
case e: ClassNotFoundException => None
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}

import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.{Logging, SparkEnv, TaskContext}
+ import org.apache.spark.util.{Utils => SparkUtils}

private[spark]
trait SparkHadoopMapRedUtil {
@@ -64,10 +65,10 @@

private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ SparkUtils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ SparkUtils.classForName(second)
}
}
}
@@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
+ import org.apache.spark.util.Utils

private[spark]
trait SparkHadoopMapReduceUtil {
@@ -46,7 +47,7 @@
isMap: Boolean,
taskId: Int,
attemptId: Int): TaskAttemptID = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
+ val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
@@ -57,7 +58,7 @@
} catch {
case exc: NoSuchMethodException => {
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if (isMap) "MAP" else "REDUCE")
@@ -71,10 +72,10 @@

private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ Utils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ Utils.classForName(second)
}
}
}
@@ -20,6 +20,8 @@ package org.apache.spark.metrics
import java.util.Properties
import java.util.concurrent.TimeUnit

+ import org.apache.spark.util.Utils
+
import scala.collection.mutable

import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
@@ -166,7 +168,7 @@ private[spark] class MetricsSystem private (
sourceConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
try {
- val source = Class.forName(classPath).newInstance()
+ val source = Utils.classForName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
@@ -182,7 +184,7 @@
val classPath = kv._2.getProperty("class")
if (null != classPath) {
try {
- val sink = Class.forName(classPath)
+ val sink = Utils.classForName(classPath)
.getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
.newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging {

private[spark] class SplitInfoReflections {
val inputSplitWithLocationInfo =
Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
- val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit")
val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
- val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo")
val isInMemory = splitLocationInfo.getMethod("isInMemory")
val getLocation = splitLocationInfo.getMethod("getLocation")
}
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -39,8 +39,7 @@ private[spark] object RpcEnv {
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
- Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
-   newInstance().asInstanceOf[RpcEnvFactory]
+ Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}

def create(
@@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {

private val objIn = new ObjectInputStream(in) {
- override def resolveClass(desc: ObjectStreamClass): Class[_] =
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+   // scalastyle:off classforname
    Class.forName(desc.getName, false, loader)
+   // scalastyle:on classforname
+ }
}

def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
@@ -102,6 +102,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())

try {
+ // scalastyle:off classforname
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
// Register classes given through spark.kryo.classesToRegister.
@@ -111,6 +112,7 @@
userRegistrator
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
.foreach { reg => reg.registerClasses(kryo) }
+ // scalastyle:on classforname
} catch {
case e: Exception =>
throw new SparkException(s"Failed to register classes with Kryo", e)
@@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging {

/** ObjectStreamClass$ClassDataSlot.desc field */
val DescField: Field = {
+ // scalastyle:off classforname
val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
+ // scalastyle:on classforname
f.setAccessible(true)
f
}
@@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: String)
.getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME)

try {
- val instance = Class.forName(clsName)
+ val instance = Utils.classForName(clsName)
.newInstance()
.asInstanceOf[ExternalBlockManager]
instance.init(blockManager, executorId)
@@ -448,10 +448,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
&& argTypes(0).getInternalName == myName) {
+ // scalastyle:off classforname
output += Class.forName(
owner.replace('/', '.'),
false,
Thread.currentThread.getContextClassLoader)
+ // scalastyle:on classforname
}
}
}
(diff truncated: 24 of the 49 changed files are not shown)
