From 4dacf72a6d662dd6b2bf6c60669a8baad0f1566d Mon Sep 17 00:00:00 2001 From: mingji Date: Tue, 19 Dec 2023 11:44:05 +0800 Subject: [PATCH] [CELEBORN-1150] Support IO encryption for Spark ### What changes were proposed in this pull request? 1. Support IO encryption for Spark (a minimal enablement sketch is appended after the patch). ### Why are the changes needed? So that shuffle data pushed to and fetched from Celeborn is encrypted when Spark's `spark.io.encryption.enabled` is set, matching vanilla Spark shuffle behavior. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? GA and manual testing on a cluster. Closes #2135 from FMX/B1150. Authored-by: mingji Signed-off-by: mingji --- client-spark/spark-2-shaded/pom.xml | 1 + .../shuffle/celeborn/SparkShuffleManager.java | 36 ++++- .../celeborn/CelebornShuffleHandle.scala | 24 +++- .../celeborn/CelebornShuffleReader.scala | 1 + .../CelebornColumnarShuffleReader.scala | 10 +- .../CelebornColumnarShuffleReaderSuite.scala | 11 +- client-spark/spark-3-shaded/pom.xml | 1 + .../shuffle/celeborn/SparkShuffleManager.java | 51 ++++++- .../spark/shuffle/celeborn/SparkUtils.java | 14 +- .../celeborn/CelebornShuffleHandle.scala | 24 +++- .../celeborn/CelebornShuffleReader.scala | 31 ++++- .../apache/celeborn/client/ShuffleClient.java | 27 ++++ .../celeborn/client/ShuffleClientImpl.java | 60 ++++++++ .../client/read/CelebornInputStream.java | 66 +++++++++ .../celeborn/client/security/CryptoUtils.java | 124 ++++++++++++++++++ .../apache/celeborn/common/CelebornConf.scala | 30 ++++ docs/configuration/client.md | 3 + .../celeborn/tests/spark/SparkTestBase.scala | 2 + 18 files changed, 494 insertions(+), 22 deletions(-) create mode 100644 client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java
diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml index 655e4b433c6..c21ac887139 100644 --- a/client-spark/spark-2-shaded/pom.xml +++ b/client-spark/spark-2-shaded/pom.xml @@ -73,6 +73,7 @@ io.netty:* org.apache.commons:commons-lang3 org.roaringbitmap:RoaringBitmap + org.apache.commons:commons-crypto
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 470d2e989ee..35939dd5c80 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -18,15 +18,20 @@ package org.apache.spark.shuffle.celeborn; import java.io.IOException; +import java.util.Optional; +import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; import scala.Int; +import scala.Option; import org.apache.spark.*; +import org.apache.spark.internal.config.package$; import org.apache.spark.launcher.SparkLauncher; import org.apache.spark.rdd.DeterministicLevel; +import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.shuffle.*; import org.apache.spark.shuffle.sort.SortShuffleManager; import org.apache.spark.util.Utils; @@ -35,6 +40,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoUtils; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.common.util.ThreadUtils; @@ -99,7 +105,29 @@ private SortShuffleManager sortShuffleManager() { return _sortShuffleManager; } - private void initializeLifecycleManager(String appId) { +
private Properties getIoCryptoConf() { + if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties(); + Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf); + cryptoConf.put( + CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION, + conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION())); + return cryptoConf; + } + + private Optional getIoCryptoKey() { + if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty(); + Option key = SparkEnv.get().securityManager().getIOEncryptionKey(); + return key.isEmpty() ? Optional.empty() : Optional.ofNullable(key.get()); + } + + private byte[] getIoCryptoInitializationVector() { + if (!celebornConf.sparkIoEncryptionEnabled()) return null; + return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false) + ? CryptoUtils.createIoCryptoInitializationVector() + : null; + } + + private void initializeLifecycleManager(String appId, byte[] ioCryptoInitializationVector) { // Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we // need to ensure that LifecycleManager will only be created once. Parallelism needs to be // considered in this place, because if there is one RDD that depends on multiple RDDs @@ -126,7 +154,8 @@ public ShuffleHandle registerShuffle( // is the same SparkContext among different shuffleIds. // This method may be called many times. appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context()); - initializeLifecycleManager(appUniqueId); + byte[] iv = getIoCryptoInitializationVector(); + initializeLifecycleManager(appUniqueId, iv); lifecycleManager.registerAppShuffleDeterminate( shuffleId, @@ -146,7 +175,8 @@ public ShuffleHandle registerShuffle( shuffleId, celebornConf.clientFetchThrowsFetchFailure(), numMaps, - dependency); + dependency, + iv); } } diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala index 4f67edaf325..dc9783a7cee 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala @@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C]( shuffleId: Int, val throwsFetchFailure: Boolean, numMappers: Int, - dependency: ShuffleDependency[K, V, C]) - extends BaseShuffleHandle(shuffleId, numMappers, dependency) + dependency: ShuffleDependency[K, V, C], + val ioCryptoInitializationVector: Array[Byte]) + extends BaseShuffleHandle(shuffleId, numMappers, dependency) { + def this( + appUniqueId: String, + lifecycleManagerHost: String, + lifecycleManagerPort: Int, + userIdentifier: UserIdentifier, + shuffleId: Int, + throwsFetchFailure: Boolean, + numMappers: Int, + dependency: ShuffleDependency[K, V, C]) = this( + appUniqueId, + lifecycleManagerHost, + lifecycleManagerPort, + userIdentifier, + shuffleId, + throwsFetchFailure, + numMappers, + dependency, + null) +} diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index dec30522562..56a36f0e0ab 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.shuffle.celeborn import java.io.IOException +import java.util.{Optional, Properties} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..363f5d7eddc 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.{Optional, Properties} + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter @@ -34,7 +36,9 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + ioCryptoKey: Optional[Array[Byte]], + ioCryptoConf: Properties) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,9 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + ioCryptoKey, + ioCryptoConf) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index 5df434f5432..231fdfbb550 100644 --- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -55,7 +57,9 @@ class CelebornColumnarShuffleReaderSuite { null, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty(), + null) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -78,6 +82,7 @@ class CelebornColumnarShuffleReaderSuite { 0, false, 10, + null, null), 0, 10, @@ -86,7 +91,9 @@ class CelebornColumnarShuffleReaderSuite { null, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty(), + null) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml index c8701776dc3..f4aab4be82f 100644 --- 
a/client-spark/spark-3-shaded/pom.xml +++ b/client-spark/spark-3-shaded/pom.xml @@ -73,6 +73,7 @@ io.netty:* org.apache.commons:commons-lang3 org.roaringbitmap:RoaringBitmap + org.apache.commons:commons-crypto diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index a1cb458cf1c..03a38b9332a 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -18,13 +18,17 @@ package org.apache.spark.shuffle.celeborn; import java.io.IOException; +import java.util.Optional; +import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; import org.apache.spark.*; +import org.apache.spark.internal.config.package$; import org.apache.spark.launcher.SparkLauncher; import org.apache.spark.rdd.DeterministicLevel; +import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.shuffle.*; import org.apache.spark.shuffle.sort.SortShuffleManager; import org.apache.spark.sql.internal.SQLConf; @@ -34,6 +38,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoUtils; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.common.util.ThreadUtils; @@ -130,7 +135,32 @@ private SortShuffleManager sortShuffleManager() { return _sortShuffleManager; } - private void initializeLifecycleManager() { + private Properties getIoCryptoConf() { + if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties(); + Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf); + cryptoConf.put( + CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION, + conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION())); + return cryptoConf; + } + + private Optional getIoCryptoKey() { + if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty(); + return SparkEnv.get() + .securityManager() + .getIOEncryptionKey() + .map(key -> Optional.ofNullable(key)) + .getOrElse(() -> Optional.empty()); + } + + private byte[] getIoCryptoInitializationVector() { + if (!celebornConf.sparkIoEncryptionEnabled()) return null; + return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false) + ? CryptoUtils.createIoCryptoInitializationVector() + : null; + } + + private void initializeLifecycleManager(byte[] ioCryptoInitializationVector) { // Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we // need to ensure that LifecycleManager will only be created once. Parallelism needs to be // considered in this place, because if there is one RDD that depends on multiple RDDs @@ -158,7 +188,8 @@ public ShuffleHandle registerShuffle( // is the same SparkContext among different shuffleIds. // This method may be called many times. 
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context()); - initializeLifecycleManager(); + byte[] iv = getIoCryptoInitializationVector(); + initializeLifecycleManager(iv); lifecycleManager.registerAppShuffleDeterminate( shuffleId, @@ -187,7 +218,8 @@ public ShuffleHandle registerShuffle( shuffleId, celebornConf.clientFetchThrowsFetchFailure(), dependency.rdd().getNumPartitions(), - dependency); + dependency, + iv); } } @@ -242,7 +274,10 @@ public ShuffleWriter getWriter( h.lifecycleManagerHost(), h.lifecycleManagerPort(), celebornConf, - h.userIdentifier()); + h.userIdentifier(), + getIoCryptoKey(), + getIoCryptoConf(), + h.ioCryptoInitializationVector()); int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true); shuffleIdTracker.track(h.shuffleId(), shuffleId); @@ -371,7 +406,9 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + getIoCryptoKey(), + getIoCryptoConf()); } else { return new CelebornShuffleReader<>( h, @@ -382,7 +419,9 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + getIoCryptoKey(), + getIoCryptoConf()); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index e7a6a5b8b6c..b11dcde2950 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn; +import java.util.Optional; +import java.util.Properties; import java.util.concurrent.atomic.LongAdder; import scala.Tuple2; @@ -219,7 +221,9 @@ public static HashBasedShuffleWriter createColumnarHashBasedS TaskContext.class, CelebornConf.class, ShuffleReadMetricsReporter.class, - ExecutorShuffleIdTracker.class); + ExecutorShuffleIdTracker.class, + Optional.class, + Properties.class); public static CelebornShuffleReader createColumnarShuffleReader( CelebornShuffleHandle handle, @@ -230,7 +234,9 @@ public static CelebornShuffleReader createColumnarShuffleReader( TaskContext context, CelebornConf conf, ShuffleReadMetricsReporter metrics, - ExecutorShuffleIdTracker shuffleIdTracker) { + ExecutorShuffleIdTracker shuffleIdTracker, + Optional ioCryptoKey, + Properties ioCryptoConf) { return COLUMNAR_SHUFFLE_READER_CONSTRUCTOR_BUILDER .build() .invoke( @@ -243,7 +249,9 @@ public static CelebornShuffleReader createColumnarShuffleReader( context, conf, metrics, - shuffleIdTracker); + shuffleIdTracker, + ioCryptoKey, + ioCryptoConf); } // Added in SPARK-32920, for Spark 3.2 and above diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala index 18a3053e006..2c12282e4bd 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala @@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C]( shuffleId: Int, val throwsFetchFailure: Boolean, val numMappers: Int, - dependency: ShuffleDependency[K, V, C]) - extends BaseShuffleHandle(shuffleId, dependency) + dependency: ShuffleDependency[K, V, C], + val ioCryptoInitializationVector: Array[Byte]) + extends 
BaseShuffleHandle(shuffleId, dependency) { + def this( + appUniqueId: String, + lifecycleManagerHost: String, + lifecycleManagerPort: Int, + userIdentifier: UserIdentifier, + shuffleId: Int, + throwsFetchFailure: Boolean, + numMappers: Int, + dependency: ShuffleDependency[K, V, C]) = this( + appUniqueId, + lifecycleManagerHost, + lifecycleManagerPort, + userIdentifier, + shuffleId, + throwsFetchFailure, + numMappers, + dependency, + null) +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index fe7af83091a..ceb3639b4ef 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException +import java.util.{Optional, Properties} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -44,16 +45,42 @@ class CelebornShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + ioCryptoKey: Optional[Array[Byte]], + ioCryptoConf: Properties) extends ShuffleReader[K, C] with Logging { + def this( + handle: CelebornShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + startMapIndex: Int, + endMapIndex: Int, + context: TaskContext, + conf: CelebornConf, + metrics: ShuffleReadMetricsReporter) = this( + handle, + startPartition, + endPartition, + startMapIndex, + endMapIndex, + context, + conf, + metrics, + null, + Optional.empty(), + null) + private val dep = handle.dependency private val shuffleClient = ShuffleClient.get( handle.appUniqueId, handle.lifecycleManagerHost, handle.lifecycleManagerPort, conf, - handle.userIdentifier) + handle.userIdentifier, + ioCryptoKey, + ioCryptoConf, + handle.ioCryptoInitializationVector) private val exceptionRef = new AtomicReference[IOException] diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 72230a536de..b5fc2ec6bb7 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -18,6 +18,8 @@ package org.apache.celeborn.client; import java.io.IOException; +import java.util.Optional; +import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -61,6 +63,26 @@ public static ShuffleClient get( int port, CelebornConf conf, UserIdentifier userIdentifier) { + return ShuffleClient.get( + appUniqueId, + driverHost, + port, + conf, + userIdentifier, + Optional.empty(), + new Properties(), + null); + } + + public static ShuffleClient get( + String appUniqueId, + String driverHost, + int port, + CelebornConf conf, + UserIdentifier userIdentifier, + Optional ioCryptoKey, + Properties ioCryptoConf, + byte[] ioCryptoInitializationVector) { if (null == _instance || !initialized) { synchronized (ShuffleClient.class) { if (null == _instance) { @@ -72,11 +94,13 @@ public static ShuffleClient get( // when communicating with LifecycleManager, it will cause a NullPointerException. 
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); + _instance.setupIoCrypto(ioCryptoKey, ioCryptoConf, ioCryptoInitializationVector); initialized = true; } else if (!initialized) { _instance.shutdown(); _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); + _instance.setupIoCrypto(ioCryptoKey, ioCryptoConf, ioCryptoInitializationVector); initialized = true; } } @@ -118,6 +142,9 @@ public static void printReadStats(Logger logger) { String.format("%.2f", (localReadCount * 1.0d / totalReadCount) * 100)); } + public void setupIoCrypto( + Optional ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector) {} + public abstract void setupLifecycleManagerRef(String host, int port); public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef);
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index c5f463b1969..d83e4283a0e 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -30,6 +30,7 @@ import com.google.common.annotations.VisibleForTesting; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; +import org.apache.commons.crypto.cipher.CryptoCipher; import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,6 +38,7 @@ import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoUtils; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.identity.UserIdentifier; @@ -157,6 +159,29 @@ public void update(ReduceFileGroups fileGroups) { protected final Map reduceFileGroupsMap = JavaUtils.newConcurrentHashMap(); + protected Optional ioCryptoKey = Optional.empty(); + + protected Properties ioCryptoConf; + + protected byte[] ioCryptoInitializationVector; + + private ThreadLocal encipherThreadLocal = + new ThreadLocal() { + @Override + protected CryptoCipher initialValue() { + CryptoCipher cryptoCipher = null; + if (ioCryptoKey.isPresent()) { + try { + cryptoCipher = + CryptoUtils.getEncipher(ioCryptoKey, ioCryptoConf, ioCryptoInitializationVector); + } catch (IOException e) { + logger.error("Failed to init crypto", e); + } + } + return cryptoCipher; + } + }; + public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) { super(); this.appUniqueId = appUniqueId; @@ -878,6 +903,23 @@ public int pushOrMergeData( // increment batchId final int nextBatchId = pushState.nextBatchId(); + if (ioCryptoKey.isPresent()) { + CryptoCipher encipher = encipherThreadLocal.get(); + byte[] encryptData = new byte[length + encipher.getBlockSize()]; + int encryptLength = CryptoUtils.encrypt(encipher, data, offset, length, encryptData); + logger.debug( + "Push data encryption encryptLength/beforeLength {}/{} for shuffle {} map {} attempt {} partition {}.", + encryptLength, + length, + shuffleId, + mapId, + attemptId, + partitionId); + length = encryptLength; + data = encryptData; + offset = 0; + } + if (shuffleCompressionEnabled) { // compress data final Compressor compressor = compressorThreadLocal.get(); @@ -1651,6 +1693,9 @@ public CelebornInputStream readPartition( startMapIndex, endMapIndex, fetchExcludedWorkers, + ioCryptoKey, + ioCryptoConf, + ioCryptoInitializationVector, metricsCallback); } } @@ -1754,4 +1799,19 @@ private boolean connectFail(String message) { public TransportClientFactory getDataClientFactory() { return dataClientFactory; } + + @Override + public void setupIoCrypto( + Optional ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector) { + this.ioCryptoKey = ioCryptoKey; + this.ioCryptoConf = ioCryptoConf; + this.ioCryptoInitializationVector = ioCryptoInitializationVector; + if (this.ioCryptoKey.isPresent()) { + try { + CryptoUtils.getEncipher(this.ioCryptoKey, this.ioCryptoConf, ioCryptoInitializationVector); + } catch (IOException e) { + throw new RuntimeException("Failed to init encipher", e); + } + } + } }
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index bb1e95ce9a9..3284c828afa 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -27,12 +27,14 @@ import com.google.common.util.concurrent.Uninterruptibles; import io.netty.buffer.ByteBuf; +import org.apache.commons.crypto.cipher.CryptoCipher; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; +import org.apache.celeborn.client.security.CryptoUtils; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -59,6 +61,37 @@ public static CelebornInputStream create( ConcurrentHashMap fetchExcludedWorkers, MetricsCallback metricsCallback) throws IOException { + return create( + conf, + clientFactory, + shuffleKey, + locations, + attempts, + attemptNumber, + startMapIndex, + endMapIndex, + fetchExcludedWorkers, + Optional.empty(), + null, + null, + metricsCallback); + } + + public static CelebornInputStream create( + CelebornConf conf, + TransportClientFactory clientFactory, + String shuffleKey, + PartitionLocation[] locations, + int[] attempts, + int attemptNumber, + int startMapIndex, + int endMapIndex, + ConcurrentHashMap fetchExcludedWorkers, + Optional ioCryptoKey, + Properties ioCryptoProp, + byte[] ioCryptoInitializationVector, + MetricsCallback metricsCallback) + throws IOException { if (locations == null || locations.length == 0) { return emptyInputStream; } else { @@ -72,6 +105,9 @@ public static CelebornInputStream create( startMapIndex, endMapIndex, fetchExcludedWorkers, + ioCryptoKey, + ioCryptoProp, + ioCryptoInitializationVector, metricsCallback); } } @@ -149,6 +185,9 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private boolean shuffleCompressionEnabled; private long fetchExcludedWorkerExpireTimeout; private final ConcurrentHashMap fetchExcludedWorkers; + private Optional encryptKey; + private Properties encryptProp; + private CryptoCipher decipher; private boolean containLocalRead = false; @@ -162,6 +201,9 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { int startMapIndex, int endMapIndex, ConcurrentHashMap fetchExcludedWorkers, + Optional ioCryptoKey, +
Properties ioCryptoProp, + byte[] ioCryptoInitializationVector, MetricsCallback metricsCallback) throws IOException { this.conf = conf; @@ -202,6 +244,12 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { retryWaitMs = transportConf.ioRetryWaitTimeMs(); this.callback = metricsCallback; moveToNextReader(); + + this.encryptKey = ioCryptoKey; + this.encryptProp = ioCryptoProp; + if (ioCryptoKey.isPresent()) { + decipher = CryptoUtils.getDecipher(ioCryptoKey, ioCryptoProp, ioCryptoInitializationVector); + } } private boolean skipLocation(int startMapIndex, int endMapIndex, PartitionLocation location) { @@ -570,6 +618,7 @@ private boolean fillBuffer() throws IOException { callback.incBytesRead(BATCH_HEADER_SIZE + size); if (shuffleCompressionEnabled) { // decompress data + int originalLength = decompressor.getOriginalLen(compressedBuf); if (rawDataBuf.length < originalLength) { rawDataBuf = new byte[originalLength]; @@ -578,6 +627,23 @@ } else { limit = size; } + + if (decipher != null) { + byte[] decryptBuf = new byte[limit]; + int decryptLength = CryptoUtils.decrypt(decipher, rawDataBuf, 0, limit, decryptBuf); + logger.debug( + "fetch data decryption shuffleKey: {}, mapId: {}, attemptId: {}, batchId: {}, decryptLength/originalLength: {}/{}", + shuffleKey, + mapId, + attemptId, + batchId, + decryptLength, + limit); + limit = decryptLength; + System.arraycopy(decryptBuf, 0, rawDataBuf, 0, limit); + decryptBuf = null; + } + position = 0; hasData = true; break;
diff --git a/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java b/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java new file mode 100644 index 00000000000..5ce4672d67e --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/security/CryptoUtils.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.celeborn.client.security; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import org.apache.commons.crypto.cipher.CryptoCipher; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CryptoUtils { + private static Logger logger = LoggerFactory.getLogger(CryptoUtils.class); + public static final int IV_LENGTH_IN_BYTES = 16; + public static final String COMMONS_CRYPTO_CONFIG_PREFIX = "commons.crypto."; + public static final String COMMONS_CRYPTO_CONFIG_TRANSFORMATION = + COMMONS_CRYPTO_CONFIG_PREFIX + "cipher.transformation"; + public static final String CRYPTO_ALGORITHM = "AES"; + + public static byte[] createIoCryptoInitializationVector() { + byte[] iv = new byte[IV_LENGTH_IN_BYTES]; + long initialIVStart = System.nanoTime(); + try { + CryptoRandomFactory.getCryptoRandom(new Properties()).nextBytes(iv); + } catch (GeneralSecurityException e) { + logger.warn("Failed to create crypto Initialization Vector", e); + iv = "1234567890123456".getBytes(); + } + long initialIVFinish = System.nanoTime(); + long initialIVTime = TimeUnit.NANOSECONDS.toMillis(initialIVFinish - initialIVStart); + if (initialIVTime > 2000) { + logger.warn( + "It costs {} milliseconds to create the Initialization Vector used by crypto", + initialIVTime); + } + return iv; + } + + public static CryptoCipher getEncipher( + Optional ioCryptoKey, Properties ioCryptoConf, byte[] ioCryptoInitializationVector) + throws IOException { + CryptoCipher encipher = null; + if (ioCryptoKey.isPresent()) { + SecretKeySpec keySpec = new SecretKeySpec(ioCryptoKey.get(), CRYPTO_ALGORITHM); + String transformation = (String) ioCryptoConf.get(COMMONS_CRYPTO_CONFIG_TRANSFORMATION); + encipher = + org.apache.commons.crypto.utils.Utils.getCipherInstance(transformation, ioCryptoConf); + try { + encipher.init( + Cipher.ENCRYPT_MODE, keySpec, new IvParameterSpec(ioCryptoInitializationVector)); + } catch (GeneralSecurityException e) { + throw new IOException("Failed to init encipher", e); + } + } + return encipher; + } + + public static int encrypt( + CryptoCipher encipher, byte[] input, int offset, int length, byte[] output) + throws IOException { + try { + int updateBytes = encipher.update(input, offset, length, output, 0); + int finalBytes = encipher.doFinal(input, 0, 0, output, updateBytes); + return updateBytes + finalBytes; + } catch (ShortBufferException | BadPaddingException | IllegalBlockSizeException e) { + throw new IOException("Failed to encrypt", e); + } + } + + public static CryptoCipher getDecipher( + Optional key, Properties cryptoProp, byte[] cryptoInitializationVector) + throws IOException { + CryptoCipher decipher = null; + if (key.isPresent()) { + SecretKeySpec keySpec = new SecretKeySpec(key.get(), CRYPTO_ALGORITHM); + String transformation = (String) cryptoProp.get(COMMONS_CRYPTO_CONFIG_TRANSFORMATION); + decipher = + org.apache.commons.crypto.utils.Utils.getCipherInstance(transformation, cryptoProp); + try { + decipher.init( + Cipher.DECRYPT_MODE, keySpec, new IvParameterSpec(cryptoInitializationVector)); + } catch (GeneralSecurityException e) { + throw new IOException("Failed to init decipher", e); + } + } + return decipher; + } + + public static int decrypt( + CryptoCipher decipher, byte[] input, int offset, int length, byte[] decoded) + throws IOException { + try { + return decipher.doFinal(input, offset, length, decoded, 0); + } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) { + throw new IOException("Failed to decrypt", e); + } + } +}
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 0c723832ceb..e89f300785f 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -758,6 +758,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientExcludedWorkerExpireTimeout: Long = get(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT) def clientExcludeReplicaOnFailureEnabled: Boolean = get(CLIENT_EXCLUDE_PEER_WORKER_ON_FAILURE_ENABLED) + + def sparkIoEncryptionEnabled: Boolean = get(SPARK_CLIENT_IO_ENCRYPTION_ENABLED) + def sparkIoEncryptionKey: String = get(SPARK_CLIENT_IO_ENCRYPTION_KEY) + def sparkIoEncryptionInitializationVector: String = + get(SPARK_CLIENT_IO_ENCRYPTION_INITIALIZATION_VECTOR) + def clientMrMaxPushData: Long = get(CLIENT_MR_PUSH_DATA_MAX) // ////////////////////////////////////////////////////// @@ -4239,4 +4245,28 @@ object CelebornConf extends Logging { .version("0.5.0") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("30s") + + val SPARK_CLIENT_IO_ENCRYPTION_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.io.encryption.enabled") + .categories("client") + .version("0.4.0") + .doc("Whether to enable IO encryption for the Spark client.") + .booleanConf + .createWithDefault(true) + + val SPARK_CLIENT_IO_ENCRYPTION_KEY: ConfigEntry[String] = + buildConf("celeborn.client.spark.io.encryption.key") + .categories("client") + .version("0.4.0") + .doc("IO encryption key.") + .stringConf + .createWithDefault("") + + val SPARK_CLIENT_IO_ENCRYPTION_INITIALIZATION_VECTOR: ConfigEntry[String] = + buildConf("celeborn.client.spark.io.encryption.initialization.vector") + .categories("client") + .version("0.4.0") + .doc("IO encryption initialization vector.") + .stringConf + .createWithDefault("") }
diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 7c171ed3fc7..d77cb952161 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -102,6 +102,9 @@ license: | | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | Whether to filter excluded worker when register shuffle. | 0.4.0 | | celeborn.client.slot.assign.maxWorkers | 10000 | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`.
| 0.3.1 | | celeborn.client.spark.fetch.throwsFetchFailure | false | client throws FetchFailedException instead of CelebornIOException | 0.4.0 | +| celeborn.client.spark.io.encryption.enabled | true | Whether to enable IO encryption for the Spark client. | 0.4.0 | +| celeborn.client.spark.io.encryption.initialization.vector | | IO encryption initialization vector. | 0.4.0 | +| celeborn.client.spark.io.encryption.key | | IO encryption key. | 0.4.0 | | celeborn.client.spark.push.sort.memory.threshold | 64m | When SortBasedPusher use memory over the threshold, will trigger push data. If the pipeline push feature is enabled (`celeborn.client.spark.push.sort.pipeline.enabled=true`), the SortBasedPusher will trigger a data push when the memory usage exceeds half of the threshold(by default, 32m). | 0.3.0 | | celeborn.client.spark.push.sort.pipeline.enabled | false | Whether to enable pipelining for sort based shuffle writer. If true, double buffering will be used to pipeline push | 0.3.0 | | celeborn.client.spark.push.unsafeRow.fastWrite.enabled | true | This is Celeborn's optimization on UnsafeRow for Spark and it's true by default. If you have changed UnsafeRow's memory layout set this to false. | 0.2.2 |
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index e2cb3c98a22..54de9b34dc8 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -59,6 +59,8 @@ trait SparkTestBase extends AnyFunSuite sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false") sparkConf.set(s"spark.${MASTER_ENDPOINTS.key}", masterInfo._1.rpcEnv.address.toString) sparkConf.set(s"spark.${SPARK_SHUFFLE_WRITER_MODE.key}", mode.toString) + sparkConf.set("spark.io.encryption.enabled", "true") + sparkConf.set("spark.io.crypto.cipher.transformation", "AES/CBC/PKCS5Padding") sparkConf }
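To try this end to end, here is a minimal enablement sketch in Java, mirroring the settings used in `SparkTestBase` above. The master endpoint is a placeholder value, and passing Celeborn client settings through `SparkConf` with the `spark.celeborn.` prefix is assumed from the test setup; treat the exact values as illustrative, not part of this patch.

```java
import org.apache.spark.SparkConf;

public class IoEncryptionEnablement {
  public static SparkConf build() {
    return new SparkConf()
        // Route shuffle through Celeborn.
        .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
        .set("spark.celeborn.master.endpoints", "celeborn-master:9097") // placeholder address
        // Spark-side IO encryption; the shuffle client obtains the key from Spark's SecurityManager.
        .set("spark.io.encryption.enabled", "true")
        .set("spark.io.crypto.cipher.transformation", "AES/CBC/PKCS5Padding")
        // Celeborn-side switch introduced by this patch (celeborn.client.spark.io.encryption.enabled, default true).
        .set("spark.celeborn.client.spark.io.encryption.enabled", "true");
  }
}
```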
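And a self-contained sketch of the encrypt/decrypt round trip that `CryptoUtils` performs with commons-crypto. The key and IV are hard-coded demo values, an assumption for illustration only; the real code takes the key from Spark and draws a random IV via `CryptoRandomFactory`.

```java
import java.nio.charset.StandardCharsets;
import java.util.Properties;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.commons.crypto.cipher.CryptoCipher;
import org.apache.commons.crypto.utils.Utils;

public class CryptoRoundTrip {
  public static void main(String[] args) throws Exception {
    String transformation = "AES/CBC/PKCS5Padding";
    Properties props = new Properties();
    // Demo-only 16-byte key and IV; never hard-code these in a real deployment.
    SecretKeySpec keySpec =
        new SecretKeySpec("0123456789abcdef".getBytes(StandardCharsets.UTF_8), "AES");
    IvParameterSpec ivSpec =
        new IvParameterSpec("1234567890123456".getBytes(StandardCharsets.UTF_8));

    byte[] input = "hello celeborn".getBytes(StandardCharsets.UTF_8);
    // Reserve one extra cipher block for PKCS5 padding, as pushOrMergeData does via getBlockSize().
    byte[] encrypted = new byte[input.length + 16];

    try (CryptoCipher encipher = Utils.getCipherInstance(transformation, props)) {
      encipher.init(Cipher.ENCRYPT_MODE, keySpec, ivSpec);
      // Same update-then-doFinal pattern as CryptoUtils.encrypt.
      int encLen = encipher.update(input, 0, input.length, encrypted, 0);
      encLen += encipher.doFinal(input, 0, 0, encrypted, encLen);

      byte[] decrypted = new byte[encLen];
      try (CryptoCipher decipher = Utils.getCipherInstance(transformation, props)) {
        decipher.init(Cipher.DECRYPT_MODE, keySpec, ivSpec);
        int decLen = decipher.doFinal(encrypted, 0, encLen, decrypted, 0);
        System.out.println(new String(decrypted, 0, decLen, StandardCharsets.UTF_8)); // hello celeborn
      }
    }
  }
}
```

Try-with-resources is safe here because each cipher is fully used inside its block; `CryptoUtils.getEncipher`/`getDecipher` instead hand the cipher back to the caller, which is why those helpers must leave it open rather than close it before returning.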