Skip to content

Commit

Permalink
[spark] KUDU-1884 Add custom SASL protocol name
Browse files Browse the repository at this point in the history
The Java client already supports setting a custom SASL protocol name for a
KuduClient or AsyncKuduClient instance, which is needed when using a
non-default service principal name. This patch exposes this setting in
KuduContext and DefaultSource.

Change-Id: Ifd0dba4f829f369c363cc89bb58650249035f356
Reviewed-on: http://gerrit.cloudera.org:8080/17328
Tested-by: Attila Bukor <[email protected]>
Reviewed-by: Alexey Serbin <[email protected]>
Reviewed-by: Grant Henke <[email protected]>
  • Loading branch information
attilabukor committed Apr 23, 2021
1 parent 1e5150b commit dc5b5bd
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DefaultSource
val HANDLE_SCHEMA_DRIFT = "kudu.handleSchemaDrift"
val USE_DRIVER_METADATA = "kudu.useDriverMetadata"
val SNAPSHOT_TIMESTAMP_MS = "kudu.snapshotTimestampMs"
val SASL_PROTOCOL_NAME = "kudu.saslProtocolName"

/**
* A nice alias for the data source so that when specifying the format
Expand Down Expand Up @@ -109,13 +110,15 @@ class DefaultSource
val tableName = getTableName(parameters)
val kuduMaster = getMasterAddrs(parameters)
val operationType = getOperationType(parameters)
val saslProtocolName = getSaslProtocolName(parameters)
val schemaOption = Option(schema)
val readOptions = getReadOptions(parameters)
val writeOptions = getWriteOptions(parameters)

new KuduRelation(
tableName,
kuduMaster,
saslProtocolName,
operationType,
schemaOption,
readOptions,
Expand Down Expand Up @@ -157,12 +160,14 @@ class DefaultSource
val tableName = getTableName(parameters)
val masterAddrs = getMasterAddrs(parameters)
val operationType = getOperationType(parameters)
val saslProtocolName = getSaslProtocolName(parameters)
val readOptions = getReadOptions(parameters)
val writeOptions = getWriteOptions(parameters)

new KuduSink(
tableName,
masterAddrs,
saslProtocolName,
operationType,
readOptions,
writeOptions
Expand Down Expand Up @@ -227,6 +232,10 @@ class DefaultSource
parameters.getOrElse(KUDU_MASTER, InetAddress.getLocalHost.getCanonicalHostName)
}

/**
 * Resolves the SASL protocol name to use when authenticating to the cluster.
 *
 * @param parameters the data source options supplied by the user
 * @return the value of the `kudu.saslProtocolName` option, or the Kudu
 *         default ("kudu") when the option is absent
 */
private def getSaslProtocolName(parameters: Map[String, String]): String =
  parameters.get(SASL_PROTOCOL_NAME).getOrElse("kudu")

private def getScanLocalityType(opParam: String): ReplicaSelection = {
opParam.toLowerCase(Locale.ENGLISH) match {
case "leader_only" => ReplicaSelection.LEADER_ONLY
Expand Down Expand Up @@ -274,6 +283,7 @@ class DefaultSource
class KuduRelation(
val tableName: String,
val masterAddrs: String,
val saslProtocolName: String,
val operationType: OperationType,
val userSchema: Option[StructType],
val readOptions: KuduReadOptions = new KuduReadOptions,
Expand All @@ -282,7 +292,7 @@ class KuduRelation(
val log: Logger = LoggerFactory.getLogger(getClass)

private val context: KuduContext =
new KuduContext(masterAddrs, sqlContext.sparkContext)
new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))

private val table: KuduTable = context.syncClient.openTable(tableName)

Expand Down Expand Up @@ -498,13 +508,14 @@ private[spark] object KuduRelation {
class KuduSink(
val tableName: String,
val masterAddrs: String,
val saslProtocolName: String,
val operationType: OperationType,
val readOptions: KuduReadOptions = new KuduReadOptions,
val writeOptions: KuduWriteOptions)(val sqlContext: SQLContext)
extends Sink {

private val context: KuduContext =
new KuduContext(masterAddrs, sqlContext.sparkContext)
new KuduContext(masterAddrs, sqlContext.sparkContext, None, Some(saslProtocolName))

override def addBatch(batchId: Long, data: DataFrame): Unit = {
context.writeRows(data, tableName, operationType, writeOptions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ import org.apache.kudu.Type
@InterfaceAudience.Public
@InterfaceStability.Evolving
@SerialVersionUID(1L)
class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeoutMs: Option[Long])
class KuduContext(
val kuduMaster: String,
sc: SparkContext,
val socketReadTimeoutMs: Option[Long],
val saslProtocolName: Option[String] = None)
extends Serializable {
val log: Logger = LoggerFactory.getLogger(getClass)

Expand Down Expand Up @@ -149,7 +153,7 @@ class KuduContext(val kuduMaster: String, sc: SparkContext, val socketReadTimeou
@transient lazy val syncClient: KuduClient = asyncClient.syncClient()

@transient lazy val asyncClient: AsyncKuduClient = {
val c = KuduClientCache.getAsyncClient(kuduMaster)
val c = KuduClientCache.getAsyncClient(kuduMaster, saslProtocolName)
if (authnCredentials != null) {
c.importAuthenticationCredentials(authnCredentials)
}
Expand Down Expand Up @@ -607,10 +611,14 @@ private object KuduClientCache {
clientCache.clear()
}

def getAsyncClient(kuduMaster: String): AsyncKuduClient = {
def getAsyncClient(kuduMaster: String, saslProtocolName: Option[String]): AsyncKuduClient = {
clientCache.synchronized {
if (!clientCache.contains(kuduMaster)) {
val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
val builder = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster)
if (saslProtocolName.nonEmpty) {
builder.saslProtocolName(saslProtocolName.get)
}
val asyncClient = builder.build()
val hookHandle = new Runnable {
override def run(): Unit = asyncClient.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.kudu.spark.kudu

import java.nio.charset.StandardCharsets
import java.util

import scala.collection.JavaConverters._
import scala.collection.immutable.IndexedSeq
import org.apache.spark.SparkException
Expand All @@ -35,6 +34,7 @@ import org.apache.kudu.client.CreateTableOptions
import org.apache.kudu.test.KuduTestHarness
import org.apache.kudu.test.RandomUtils
import org.apache.kudu.spark.kudu.SparkListenerUtil.withJobTaskCounter
import org.apache.kudu.test.KuduTestHarness.EnableKerberos
import org.apache.kudu.test.KuduTestHarness.MasterServerConfig
import org.junit.Before
import org.junit.Test
Expand Down Expand Up @@ -876,4 +876,25 @@ class DefaultSourceTest extends KuduTestSuite with Matchers {
val kuduRelation = kuduRelationFromDataFrame(dataFrame)
assert(kuduRelation.sizeInBytes == 1024)
}

@Test
@EnableKerberos(principal = "oryx")
def testNonDefaultPrincipal(): Unit = {
  // With the default SASL protocol name ("kudu"), connecting to a cluster
  // whose service principal is "oryx" must fail authentication.
  KuduClientCache.clearCacheForTests()
  val exception = intercept[Exception] {
    sqlContext.read.options(kuduOptions).format("kudu").load.count()
  }
  assertTrue(exception.getCause.getMessage.contains("this client is not authenticated"))

  // Clear the cached client and retry with the matching SASL protocol name;
  // the read should now succeed and see every row.
  KuduClientCache.clearCacheForTests()
  kuduOptions = Map(
    "kudu.table" -> tableName,
    "kudu.master" -> harness.getMasterAddressesAsString,
    "kudu.saslProtocolName" -> "oryx"
  )

  val frame = sqlContext.read.options(kuduOptions).format("kudu").load
  assertEquals(rowCount, frame.count())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ trait KuduTestSuite {
@Before
def setUpBase(): Unit = {
ss = SparkSession.builder().config(conf).getOrCreate()
kuduContext = new KuduContext(harness.getMasterAddressesAsString, ss.sparkContext)
kuduContext = new KuduContext(
harness.getMasterAddressesAsString,
ss.sparkContext,
None,
Some(harness.getPrincipal()))

// Spark tests should use the client from the kuduContext.
kuduClient = kuduContext.syncClient
Expand Down

0 comments on commit dc5b5bd

Please sign in to comment.