Skip to content

Commit

Permalink
[SPARK-32047][SQL] Add JDBC connection provider disable possibility
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
At the moment there is no way to turn off JDBC authentication providers which exist on the classpath. This can be problematic because service providers are loaded with the service loader. In this PR I've added the `spark.sql.sources.disabledJdbcConnProviderList` configuration option (default: empty).

### Why are the changes needed?
No possibility to turn off JDBC authentication providers.

### Does this PR introduce _any_ user-facing change?
Yes, it introduces a new configuration option.

### How was this patch tested?
* Existing + newly added unit tests.
* Existing integration tests.

Closes apache#29964 from gaborgsomogyi/SPARK-32047.

Authored-by: Gabor Somogyi <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
  • Loading branch information
gaborgsomogyi authored and HyukjinKwon committed Oct 12, 2020
1 parent 50b2a49 commit 4af1ac9
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2775,6 +2775,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Comma-separated list of JDBC connection provider names (JdbcConnectionProvider.name)
// that must not be loaded even if present on the classpath. The default empty string
// means every provider discovered via the service loader stays enabled.
val DISABLED_JDBC_CONN_PROVIDER_LIST =
buildConf("spark.sql.sources.disabledJdbcConnProviderList")
.internal()
.doc("Configures a list of JDBC connection providers, which are disabled. " +
"The list contains the name of the JDBC connection providers separated by comma.")
.version("3.1.0")
.stringConf
.createWithDefault("")

/**
* Holds information about keys that have been deprecated.
*
Expand Down Expand Up @@ -3399,6 +3408,8 @@ class SQLConf extends Serializable with Logging {

def truncateTrashEnabled: Boolean = getConf(SQLConf.TRUNCATE_TRASH_ENABLED)

// Raw comma-separated value of DISABLED_JDBC_CONN_PROVIDER_LIST; split by the
// consumer (ConnectionProvider.loadProviders) via Utils.stringToSeq.
def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ private[jdbc] class BasicConnectionProvider extends JdbcConnectionProvider with
*/
def getAdditionalProperties(options: JDBCOptions): Properties = new Properties()

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "basic"

override def canHandle(driver: Driver, options: Map[String, String]): Boolean = {
val jdbcOptions = new JDBCOptions(options)
jdbcOptions.keytab == null || jdbcOptions.principal == null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.security.SecurityConfigurationLock
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcConnectionProvider
import org.apache.spark.util.Utils

Expand All @@ -47,8 +48,10 @@ private[jdbc] object ConnectionProvider extends Logging {
logInfo("Loading of the provider failed with the exception:", t)
}
}
// Seems duplicate but it's needed for Scala 2.13
providers.toSeq

val disabledProviders = Utils.stringToSeq(SQLConf.get.disabledJdbcConnectionProviders)
// toSeq seems duplicate but it's needed for Scala 2.13
providers.filterNot(p => disabledProviders.contains(p.name)).toSeq
}

def create(driver: Driver, options: Map[String, String]): Connection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[sql] class DB2ConnectionProvider extends SecureConnectionProvider {
override val driverClass = "com.ibm.db2.jcc.DB2Driver"

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "db2"

override def appEntry(driver: Driver, options: JDBCOptions): String = "JaasClient"

override def getConnection(driver: Driver, options: Map[String, String]): Connection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ private[sql] class MSSQLConnectionProvider extends SecureConnectionProvider {
override val driverClass = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
val parserMethod: String = "parseAndMergeProperties"

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "mssql"

override def appEntry(driver: Driver, options: JDBCOptions): String = {
val configName = "jaasConfigurationName"
val appEntryDefault = "SQLJDBCDriver"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[jdbc] class MariaDBConnectionProvider extends SecureConnectionProvider {
override val driverClass = "org.mariadb.jdbc.Driver"

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "mariadb"

override def appEntry(driver: Driver, options: JDBCOptions): String =
"Krb5ConnectorContext"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[sql] class OracleConnectionProvider extends SecureConnectionProvider {
override val driverClass = "oracle.jdbc.OracleDriver"

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "oracle"

override def appEntry(driver: Driver, options: JDBCOptions): String = "kprb5module"

override def getConnection(driver: Driver, options: Map[String, String]): Connection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
private[jdbc] class PostgresConnectionProvider extends SecureConnectionProvider {
override val driverClass = "org.postgresql.Driver"

// Unique provider name; matched against spark.sql.sources.disabledJdbcConnProviderList.
override val name: String = "postgres"

override def appEntry(driver: Driver, options: JDBCOptions): String = {
val parseURL = driver.getClass.getMethod("parseURL", classOf[String], classOf[Properties])
val properties = parseURL.invoke(driver, options.url, null).asInstanceOf[Properties]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ import org.apache.spark.annotation.{DeveloperApi, Unstable}
@DeveloperApi
@Unstable
abstract class JdbcConnectionProvider {
/**
* Name of the service that provides JDBC connections. This name must be unique, because
* Spark uses it internally to differentiate JDBC connection providers — e.g. when
* matching against `spark.sql.sources.disabledJdbcConnProviderList` to disable a provider.
*/
val name: String

/**
* Checks if this connection provider instance can handle the connection initiated by the driver.
* There must be exactly one active connection provider which can handle the connection for a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection

import javax.security.auth.login.Configuration

class ConnectionProviderSuite extends ConnectionProviderSuiteBase {
test("All built-in provides must be loaded") {
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSparkSession {
test("All built-in providers must be loaded") {
IntentionallyFaultyConnectionProvider.constructed = false
val providers = ConnectionProvider.loadProviders()
assert(providers.exists(_.isInstanceOf[BasicConnectionProvider]))
Expand All @@ -34,6 +37,14 @@ class ConnectionProviderSuite extends ConnectionProviderSuiteBase {
assert(providers.size === 6)
}

test("Disabled provider must not be loaded") {
// Disabling "db2" must drop exactly the DB2 provider from the loaded set,
// leaving the remaining 5 built-in providers intact.
withSQLConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST.key -> "db2") {
val loadedProviders = ConnectionProvider.loadProviders()
assert(!loadedProviders.exists(_.isInstanceOf[DB2ConnectionProvider]))
assert(loadedProviders.size === 5)
}
}

test("Multiple security configs must be reachable") {
Configuration.setConfiguration(null)
val postgresProvider = new PostgresConnectionProvider()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ abstract class ConnectionProviderSuiteBase extends SparkFunSuite with BeforeAndA
JDBCOptions.JDBC_PRINCIPAL -> "principal"
))

override def afterEach(): Unit = {
protected override def afterEach(): Unit = {
try {
Configuration.setConfiguration(null)
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.jdbc.JdbcConnectionProvider
// Test-only provider whose constructor always throws, to verify that
// ConnectionProvider.loadProviders survives a faulty service-loader entry.
private class IntentionallyFaultyConnectionProvider extends JdbcConnectionProvider {
// Record that instantiation was attempted before failing, so tests can assert it.
IntentionallyFaultyConnectionProvider.constructed = true
// Deliberate failure: the members below are never initialized at runtime; they
// exist only to satisfy the abstract JdbcConnectionProvider contract at compile time.
throw new IllegalArgumentException("Intentional Exception")
override val name: String = "IntentionallyFaultyConnectionProvider"
override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
override def getConnection(driver: Driver, options: Map[String, String]): Connection = null
}
Expand Down

0 comments on commit 4af1ac9

Please sign in to comment.