[CELEBORN-1150] Support IO encryption for Spark
### What changes were proposed in this pull request?
1. Support Spark's IO encryption (`spark.io.encryption.enabled`) for shuffle data going through Celeborn: resolve the encryption key and commons-crypto cipher configuration from Spark, generate an initialization vector at shuffle registration on the driver, and propagate it to readers and writers via the shuffle handle.

### Why are the changes needed?
So that jobs running with `spark.io.encryption.enabled=true` keep their shuffle data encrypted when Celeborn replaces Spark's built-in shuffle service.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
GitHub Actions (GA) and manual testing on a cluster.
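
For reference, the feature rides on Spark's own IO-encryption settings. A minimal configuration sketch (only the first flag must be set explicitly; the other values are Spark's documented defaults, shown for context and not introduced by this patch):

    spark.io.encryption.enabled              true
    spark.io.encryption.keySizeBits          128
    spark.io.crypto.cipher.transformation    AES/CTR/NoPadding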

Closes apache#2135 from FMX/B1150.

Authored-by: mingji <[email protected]>
Signed-off-by: mingji <[email protected]>
FMX committed Dec 19, 2023
1 parent 7a58b91 commit 4dacf72
Showing 18 changed files with 498 additions and 22 deletions.
1 change: 1 addition & 0 deletions client-spark/spark-2-shaded/pom.xml
@@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
@@ -18,15 +18,20 @@
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

import scala.Int;
import scala.Option;

import org.apache.spark.*;
import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.util.Utils;
@@ -35,6 +40,7 @@

import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
@@ -99,7 +105,29 @@ private SortShuffleManager sortShuffleManager() {
return _sortShuffleManager;
}

private void initializeLifecycleManager(String appId) {
private Properties getIoCryptoConf() {
if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties();
Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
cryptoConf.put(
CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
return cryptoConf;
}

private Optional<byte[]> getIoCryptoKey() {
if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty();
Option<byte[]> key = SparkEnv.get().securityManager().getIOEncryptionKey();
return key.isEmpty() ? Optional.empty() : Optional.ofNullable(key.get());
}

private byte[] getIoCryptoInitializationVector() {
if (!celebornConf.sparkIoEncryptionEnabled()) return null;
return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
? CryptoUtils.createIoCryptoInitializationVector()
: null;
}

private void initializeLifecycleManager(String appId, byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
@@ -126,7 +154,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager(appUniqueId);
byte[] iv = getIoCryptoInitializationVector();
initializeLifecycleManager(appUniqueId, iv);

lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
@@ -146,7 +175,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
numMaps,
dependency);
dependency,
iv);
}
}

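Aside: the key, crypto Properties, and IV gathered by the Spark 2 client's shuffle manager above are only collected here; the actual stream wiring lives in Celeborn's client and is not part of this diff. As a rough sketch of what commons-crypto (the dependency newly added to the shaded jar) does with these three pieces on the write path, assuming an AES/CTR transformation and a hypothetical wrapEncrypted helper name:

    import java.io.IOException;
    import java.io.OutputStream;
    import java.util.Properties;
    import javax.crypto.spec.IvParameterSpec;
    import javax.crypto.spec.SecretKeySpec;
    import org.apache.commons.crypto.stream.CryptoOutputStream;

    // Sketch only: encrypts bytes before they leave the task. The key comes
    // from getIoCryptoKey(), the Properties from getIoCryptoConf(), and the
    // IV from the shuffle handle.
    static OutputStream wrapEncrypted(
        OutputStream raw, byte[] key, byte[] iv, Properties cryptoConf) throws IOException {
      return new CryptoOutputStream(
          "AES/CTR/NoPadding", cryptoConf, raw,
          new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    }

With AES/CTR the ciphertext is the same length as the plaintext, which matters for a shuffle service that tracks block sizes.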
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
numMappers: Int,
dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle(shuffleId, numMappers, dependency)
dependency: ShuffleDependency[K, V, C],
val ioCryptoInitializationVector: Array[Byte])
extends BaseShuffleHandle(shuffleId, numMappers, dependency) {
def this(
appUniqueId: String,
lifecycleManagerHost: String,
lifecycleManagerPort: Int,
userIdentifier: UserIdentifier,
shuffleId: Int,
throwsFetchFailure: Boolean,
numMappers: Int,
dependency: ShuffleDependency[K, V, C]) = this(
appUniqueId,
lifecycleManagerHost,
lifecycleManagerPort,
userIdentifier,
shuffleId,
throwsFetchFailure,
numMappers,
dependency,
null)
}
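Two things are worth noting about the handle change above. First, the auxiliary constructor preserves the original eight-argument signature, so pre-existing call sites (and the tests below) compile unchanged and simply carry a null IV. Second, shipping the IV inside the handle means every task of a shuffle sees the same IV, because the handle is what Spark serializes out to executors. The IV itself comes from CryptoUtils.createIoCryptoInitializationVector(), whose body is not in this diff; for an AES/CTR cipher it would plausibly be a random 16-byte block, along these lines:

    // Hypothetical sketch only; Celeborn's actual CryptoUtils is not shown here.
    // AES has a 16-byte block size, so CTR mode needs a 16-byte IV.
    byte[] iv = new byte[16];
    new java.security.SecureRandom().nextBytes(iv);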
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.util.{Optional, Properties}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

@@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn

import java.util.{Optional, Properties}

import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
@@ -34,7 +36,9 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
shuffleIdTracker: ExecutorShuffleIdTracker)
shuffleIdTracker: ExecutorShuffleIdTracker,
ioCryptoKey: Optional[Array[Byte]],
ioCryptoConf: Properties)
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -44,7 +48,9 @@
context,
conf,
metrics,
shuffleIdTracker) {
shuffleIdTracker,
ioCryptoKey,
ioCryptoConf) {

override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
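The CelebornColumnarShuffleReader above only forwards the key and crypto Properties down to the base CelebornShuffleReader. Mirroring the write-path sketch earlier, the read side would wrap each fetched stream with commons-crypto's CryptoInputStream; again a sketch under the same assumptions, not code from this patch:

    import java.io.IOException;
    import java.io.InputStream;
    import java.util.Properties;
    import javax.crypto.spec.IvParameterSpec;
    import javax.crypto.spec.SecretKeySpec;
    import org.apache.commons.crypto.stream.CryptoInputStream;

    // Sketch only: decrypts shuffle blocks as they are fetched. The IV must be
    // the one recorded in the CelebornShuffleHandle at registration time.
    static InputStream wrapDecrypted(
        InputStream raw, byte[] key, byte[] iv, Properties cryptoConf) throws IOException {
      return new CryptoInputStream(
          "AES/CTR/NoPadding", cryptoConf, raw,
          new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    }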
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn

import java.util.Optional

import org.apache.spark.{ShuffleDependency, SparkConf}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
@@ -55,7 +57,9 @@ class CelebornColumnarShuffleReaderSuite {
null,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
new ExecutorShuffleIdTracker(),
Optional.empty(),
null)
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -78,6 +82,7 @@
0,
false,
10,
null,
null),
0,
10,
@@ -86,7 +91,9 @@
null,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
new ExecutorShuffleIdTracker(),
Optional.empty(),
null)
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
1 change: 1 addition & 0 deletions client-spark/spark-3-shaded/pom.xml
@@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
@@ -18,13 +18,17 @@
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.spark.*;
import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.sql.internal.SQLConf;
@@ -34,6 +38,7 @@

import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
@@ -130,7 +135,32 @@ private SortShuffleManager sortShuffleManager() {
return _sortShuffleManager;
}

private void initializeLifecycleManager() {
private Properties getIoCryptoConf() {
if (!celebornConf.sparkIoEncryptionEnabled()) return new Properties();
Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
cryptoConf.put(
CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
return cryptoConf;
}

private Optional<byte[]> getIoCryptoKey() {
if (!celebornConf.sparkIoEncryptionEnabled()) return Optional.empty();
return SparkEnv.get()
.securityManager()
.getIOEncryptionKey()
.map(key -> Optional.ofNullable(key))
.getOrElse(() -> Optional.empty());
}

private byte[] getIoCryptoInitializationVector() {
if (!celebornConf.sparkIoEncryptionEnabled()) return null;
return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
? CryptoUtils.createIoCryptoInitializationVector()
: null;
}

private void initializeLifecycleManager(byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
@@ -158,7 +188,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager();
byte[] iv = getIoCryptoInitializationVector();
initializeLifecycleManager(iv);

lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
@@ -187,7 +218,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
dependency.rdd().getNumPartitions(),
dependency);
dependency,
iv);
}
}

@@ -242,7 +274,10 @@ public <K, V> ShuffleWriter<K, V> getWriter(
h.lifecycleManagerHost(),
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier());
h.userIdentifier(),
getIoCryptoKey(),
getIoCryptoConf(),
h.ioCryptoInitializationVector());
int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);

@@ -371,7 +406,9 @@ public <K, C> ShuffleReader<K, C> getCelebornShuffleReader(
context,
celebornConf,
metrics,
shuffleIdTracker);
shuffleIdTracker,
getIoCryptoKey(),
getIoCryptoConf());
} else {
return new CelebornShuffleReader<>(
h,
@@ -382,7 +419,9 @@
context,
celebornConf,
metrics,
shuffleIdTracker);
shuffleIdTracker,
getIoCryptoKey(),
getIoCryptoConf());
}
}

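Note that the two modules convert Scala's Option to Java's Optional differently: the Spark 2 manager uses an explicit isEmpty check, while the Spark 3 manager above chains map/getOrElse. Both are equivalent; on a Scala 2.13 classpath the stdlib conversion would also work, e.g.:

    // Equivalent one-liner on Scala 2.13 (sketch; the code above avoids it,
    // presumably to stay compatible with Scala 2.12 builds as well):
    java.util.Optional<byte[]> key = scala.jdk.javaapi.OptionConverters.toJava(
        org.apache.spark.SparkEnv.get().securityManager().getIOEncryptionKey());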
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn;

import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.atomic.LongAdder;

import scala.Tuple2;
@@ -219,7 +221,9 @@ public static <K, V, C> HashBasedShuffleWriter<K, V, C> createColumnarHashBasedS
TaskContext.class,
CelebornConf.class,
ShuffleReadMetricsReporter.class,
ExecutorShuffleIdTracker.class);
ExecutorShuffleIdTracker.class,
Optional.class,
Properties.class);

public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
CelebornShuffleHandle<K, ?, C> handle,
@@ -230,7 +234,9 @@ public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
TaskContext context,
CelebornConf conf,
ShuffleReadMetricsReporter metrics,
ExecutorShuffleIdTracker shuffleIdTracker) {
ExecutorShuffleIdTracker shuffleIdTracker,
Optional<byte[]> ioCryptoKey,
Properties ioCryptoConf) {
return COLUMNAR_SHUFFLE_READER_CONSTRUCTOR_BUILDER
.build()
.invoke(
@@ -243,7 +249,9 @@ public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
context,
conf,
metrics,
shuffleIdTracker);
shuffleIdTracker,
ioCryptoKey,
ioCryptoConf);
}

// Added in SPARK-32920, for Spark 3.2 and above
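SparkUtils builds the columnar reader reflectively, presumably so the core client has no compile-time dependency on the optional columnar-shuffle module; the change above just appends Optional.class and Properties.class to the expected parameter list. With plain JDK reflection the equivalent lookup would look roughly like this (a sketch: it assumes the four int range parameters hidden by the diff fold, and Celeborn's actual constructor-builder API is not shown here):

    // Parameter types must match the Scala constructor exactly, including the
    // two new trailing parameters for the key and the crypto Properties.
    // Throws ReflectiveOperationException if the columnar module is absent.
    java.lang.reflect.Constructor<?> ctor =
        Class.forName("org.apache.spark.shuffle.celeborn.CelebornColumnarShuffleReader")
            .getConstructor(
                CelebornShuffleHandle.class, int.class, int.class, int.class, int.class,
                TaskContext.class, CelebornConf.class, ShuffleReadMetricsReporter.class,
                ExecutorShuffleIdTracker.class, java.util.Optional.class, java.util.Properties.class);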
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
val numMappers: Int,
dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle(shuffleId, dependency)
dependency: ShuffleDependency[K, V, C],
val ioCryptoInitializationVector: Array[Byte])
extends BaseShuffleHandle(shuffleId, dependency) {
def this(
appUniqueId: String,
lifecycleManagerHost: String,
lifecycleManagerPort: Int,
userIdentifier: UserIdentifier,
shuffleId: Int,
throwsFetchFailure: Boolean,
numMappers: Int,
dependency: ShuffleDependency[K, V, C]) = this(
appUniqueId,
lifecycleManagerHost,
lifecycleManagerPort,
userIdentifier,
shuffleId,
throwsFetchFailure,
numMappers,
dependency,
null)
}
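Same pattern as the Spark 2 handle: the auxiliary constructor defaults the IV to null so that callers which never enable encryption are untouched. A quick usage sketch from Java, with placeholder arguments (userIdentifier and dependency are assumed to exist in scope):

    // Legacy 8-arg construction still compiles; the IV field is simply null.
    CelebornShuffleHandle<Integer, String, String> handle =
        new CelebornShuffleHandle<>(
            "app-1", "lifecycle-manager-host", 9097, userIdentifier, 0, false, 10, dependency);
    assert handle.ioCryptoInitializationVector() == null;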