From ab4c0bc85bcdd7add792f8bcdc0f6110b088b9e3 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 6 Feb 2024 14:53:28 +0800 Subject: [PATCH] [CELEBORN-1257] Adds a secured port in Celeborn Master for secure communication with LifecycleManager ### What changes were proposed in this pull request? This adds a secured port to Celeborn Master, which is used for secure communication with LifecycleManager. This is part of adding authentication support in Celeborn (see CELEBORN-1011). This change targets just adding the secured port to Master. The following items from the proposal are still pending: 1. Persisting the app secrets in Ratis. 2. Forwarding secrets to Workers and enabling the workers to pull registration info from the Master. 3. Secured and internal port in Workers. 4. Secured communication between workers and clients. In addition, since we are supporting both secured and unsecured communication for backward compatibility and seamless rolling upgrades, there is an additional change needed. An app that registers with the Master can try to talk to the workers on unsecured ports, which is a security breach. So, the workers need to know whether or not an app registered with the Master, and for that the Master has to propagate the list of unsecured apps to the Celeborn workers as well. We can discuss this further in https://issues.apache.org/jira/browse/CELEBORN-1261 ### Why are the changes needed? It is needed for adding authentication support to Celeborn (CELEBORN-1011). ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Added a simple UT. Closes #2281 from otterc/CELEBORN-1257. 
Authored-by: Chandni Singh Signed-off-by: zky.zhoukeyong --- .../celeborn/client/LifecycleManager.scala | 74 +++++++++++- .../celeborn/common/client/MasterClient.java | 33 +++++- .../common/protocol/RpcNameConstants.java | 4 + .../apache/celeborn/common/CelebornConf.scala | 103 ++++++++++++++-- .../common/client/MasterClientSuiteJ.java | 14 +-- docs/configuration/client.md | 1 + docs/configuration/ha.md | 1 + docs/configuration/master.md | 1 + docs/configuration/worker.md | 1 + .../service/deploy/master/Master.scala | 110 ++++++++++++++---- .../deploy/master/MasterArguments.scala | 12 ++ .../deploy/master/SecuredRpcEndpoint.scala | 91 +++++++++++++++ .../service/deploy/master/MasterSuite.scala | 43 ++++++- .../service/deploy/worker/Worker.scala | 3 +- 14 files changed, 444 insertions(+), 47 deletions(-) create mode 100644 master/src/main/scala/org/apache/celeborn/service/deploy/master/SecuredRpcEndpoint.scala diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 48b22734341..ad23d1dbc9c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -17,7 +17,9 @@ package org.apache.celeborn.client +import java.lang.{Byte => JByte} import java.nio.ByteBuffer +import java.security.SecureRandom import java.util import java.util.{function, List => JList} import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, ScheduledFuture, TimeUnit} @@ -41,12 +43,14 @@ import org.apache.celeborn.common.client.MasterClient import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} +import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo import org.apache.celeborn.common.protocol._ 
import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP import org.apache.celeborn.common.protocol.message.ControlMessages._ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc._ import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} +import org.apache.celeborn.common.security.{ClientSaslContextBuilder, RpcSecurityContext, RpcSecurityContextBuilder} import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ @@ -108,6 +112,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends .build().asInstanceOf[Cache[Int, ByteBuffer]] private val mockDestroyFailure = conf.testMockDestroySlotsFailure + private val authEnabled = conf.authEnabled @VisibleForTesting def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] = @@ -159,7 +164,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logInfo(s"Starting LifecycleManager on ${rpcEnv.address}") - private val masterClient = new MasterClient(rpcEnv, conf) + private var masterRpcEnvInUse = rpcEnv + private var workerRpcEnvInUse = rpcEnv + if (authEnabled) { + logInfo(s"Authentication is enabled; setting up master and worker RPC environments") + val appSecret = createSecret() + val registrationInfo = new RegistrationInfo() + masterRpcEnvInUse = + RpcEnv.create( + RpcNameConstants.LIFECYCLE_MANAGER_MASTER_SYS, + lifecycleHost, + 0, + conf, + createRpcSecurityContext( + appSecret, + addClientRegistrationBootstrap = true, + Some(registrationInfo))) + workerRpcEnvInUse = + RpcEnv.create( + RpcNameConstants.LIFECYCLE_MANAGER_WORKER_SYS, + lifecycleHost, + 0, + conf, + createRpcSecurityContext(appSecret)) + } + + private val masterClient = new MasterClient(masterRpcEnvInUse, conf, false) val 
commitManager = new CommitManager(appUniqueId, conf, this) val workerStatusTracker = new WorkerStatusTracker(conf, this) private val heartbeater = @@ -214,6 +244,36 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends rpcEnv.shutdown() rpcEnv.awaitTermination() } + if (authEnabled) { + if (masterRpcEnvInUse != null) { + masterRpcEnvInUse.shutdown() + masterRpcEnvInUse.awaitTermination() + } + if (workerRpcEnvInUse != null) { + workerRpcEnvInUse.shutdown() + workerRpcEnvInUse.awaitTermination() + } + } + } + + /** + * Creates security context for external RPC endpoint. + */ + def createRpcSecurityContext( + appSecret: String, + addClientRegistrationBootstrap: Boolean = false, + registrationInfo: Option[RegistrationInfo] = None): Option[RpcSecurityContext] = { + val clientSaslContextBuilder = new ClientSaslContextBuilder() + .withAddRegistrationBootstrap(addClientRegistrationBootstrap) + .withAppId(appUniqueId) + .withSaslUser(appUniqueId) + .withSaslPassword(appSecret) + if (registrationInfo.isDefined) { + clientSaslContextBuilder.withRegistrationInfo(registrationInfo.get) + } + val rpcSecurityContext = new RpcSecurityContextBuilder() + .withClientSaslContext(clientSaslContextBuilder.build()).build() + Some(rpcSecurityContext) } def getUserIdentifier: UserIdentifier = { @@ -356,7 +416,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends connectFailedWorkers: ShuffleFailedWorkers): Unit = { val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]() slots.asScala foreach { case (workerInfo, _) => - val future = rpcEnv.asyncSetupEndpointRefByAddr(RpcEndpointAddress( + val future = workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress( RpcAddress.apply(workerInfo.host, workerInfo.rpcPort), WORKER_EP)) futures.add((future, workerInfo)) @@ -1065,7 +1125,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends s" ${destroyWorkerInfo.readableAddress()}, init 
according to partition info") try { if (!workerStatusTracker.workerExcluded(destroyWorkerInfo)) { - destroyWorkerInfo.endpoint = rpcEnv.setupEndpointRef( + destroyWorkerInfo.endpoint = workerRpcEnvInUse.setupEndpointRef( RpcAddress.apply(destroyWorkerInfo.host, destroyWorkerInfo.rpcPort), WORKER_EP) } else { @@ -1573,4 +1633,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends heartbeater.stop() super.stop() } + + private def createSecret(): String = { + val bits = 256 + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + JavaUtils.bytesToString(ByteBuffer.wrap(secretBytes)) + } } diff --git a/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java b/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java index 83a72747781..90db88a1662 100644 --- a/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java +++ b/common/src/main/java/org/apache/celeborn/common/client/MasterClient.java @@ -58,10 +58,15 @@ public class MasterClient { private final AtomicReference rpcEndpointRef; private final ExecutorService oneWayMessageSender; + private final CelebornConf conf; + private final boolean isWorker; + private String masterEndpointName; - public MasterClient(RpcEnv rpcEnv, CelebornConf conf) { + public MasterClient(RpcEnv rpcEnv, CelebornConf conf, boolean isWorker) { this.rpcEnv = rpcEnv; - this.masterEndpoints = Arrays.asList(conf.masterEndpoints()); + this.conf = conf; + this.isWorker = isWorker; + this.masterEndpoints = resolveMasterEndpoints(); Collections.shuffle(this.masterEndpoints); this.maxRetries = Math.max(masterEndpoints.size(), conf.masterClientMaxRetries()); this.rpcTimeout = conf.masterClientRpcAskTimeout(); @@ -250,7 +255,7 @@ private RpcEndpointRef setupEndpointRef(String endpoint) { RpcEndpointRef endpointRef = null; try { endpointRef = - rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), 
RpcNameConstants.MASTER_EP); + rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), masterEndpointName); } catch (Exception e) { // Catch all exceptions. Because we don't care whether this exception is IOException or // TimeoutException or other exceptions, so we just try to connect to host:port, if fail, @@ -259,4 +264,26 @@ private RpcEndpointRef setupEndpointRef(String endpoint) { } return endpointRef; } + + private List resolveMasterEndpoints() { + if (isWorker) { + // For worker, we should use the internal endpoints if internal port is enabled. + if (conf.internalPortEnabled()) { + masterEndpointName = RpcNameConstants.MASTER_INTERNAL_EP; + return Arrays.asList(conf.masterInternalEndpoints()); + } else { + masterEndpointName = RpcNameConstants.MASTER_EP; + return Arrays.asList(conf.masterEndpoints()); + } + } else { + // This is for client, so we should use the secured endpoints if auth is enabled. + if (conf.authEnabled()) { + masterEndpointName = RpcNameConstants.MASTER_SECURED_EP; + return Arrays.asList(conf.masterSecuredEndpoints()); + } else { + masterEndpointName = RpcNameConstants.MASTER_EP; + return Arrays.asList(conf.masterEndpoints()); + } + } + } } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/RpcNameConstants.java b/common/src/main/java/org/apache/celeborn/common/protocol/RpcNameConstants.java index 6c3898fdc76..fed7b827b48 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/RpcNameConstants.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/RpcNameConstants.java @@ -21,9 +21,11 @@ public class RpcNameConstants { // For Master public static String MASTER_SYS = "Master"; public static String MASTER_INTERNAL_SYS = "MasterInternal"; + public static String MASTER_SECURED_SYS = "MasterSecured"; // Master Endpoint Name public static String MASTER_EP = "MasterEndpoint"; public static String MASTER_INTERNAL_EP = "MasterInternalEndpoint"; + public static String MASTER_SECURED_EP = 
"MasterSecuredEndpoint"; // For Worker public static String WORKER_SYS = "Worker"; @@ -32,6 +34,8 @@ public class RpcNameConstants { // For LifecycleManager public static String LIFECYCLE_MANAGER_SYS = "LifecycleManager"; + public static String LIFECYCLE_MANAGER_MASTER_SYS = "LifecycleManagerMasterSys"; + public static String LIFECYCLE_MANAGER_WORKER_SYS = "LifecycleManagerWorkerSys"; public static String LIFECYCLE_MANAGER_EP = "LifecycleManagerEndpoint"; // For Shuffle Client diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 822fe545966..c4aad0306e4 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.concurrent.duration._ import scala.util.Try +import org.apache.celeborn.common.CelebornConf.MASTER_INTERNAL_ENDPOINTS import org.apache.celeborn.common.identity.{DefaultIdentityProvider, IdentityProvider} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.internal.config._ @@ -1133,13 +1134,44 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se // ////////////////////////////////////////////////////// // Authentication // // ////////////////////////////////////////////////////// - def authEnabled: Boolean = get(AUTH_ENABLED) + def authEnabled: Boolean = { + val authEnabled = get(AUTH_ENABLED) + val internalPortEnabled = get(INTERNAL_PORT_ENABLED) + if (authEnabled && !internalPortEnabled) { + throw new IllegalArgumentException( + s"${AUTH_ENABLED.key} is true, but ${INTERNAL_PORT_ENABLED.key} is false") + } + return authEnabled && internalPortEnabled + } + + def haMasterNodeSecuredPort(nodeId: String): Int = { + val key = HA_MASTER_NODE_SECURED_PORT.key.replace("", nodeId) + getInt(key, HA_MASTER_NODE_SECURED_PORT.defaultValue.get) + } 
+ + def masterSecuredPort: Int = get(MASTER_SECURED_PORT) + + def masterSecuredEndpoints: Array[String] = + get(MASTER_SECURED_ENDPOINTS).toArray.map { endpoint => + Utils.parseHostPort(endpoint.replace("", Utils.localHostName(this))) match { + case (host, 0) => s"$host:${HA_MASTER_NODE_SECURED_PORT.defaultValue.get}" + case (host, port) => s"$host:$port" + } + } // ////////////////////////////////////////////////////// // Internal Port // // ////////////////////////////////////////////////////// def internalPortEnabled: Boolean = get(INTERNAL_PORT_ENABLED) + def masterInternalEndpoints: Array[String] = + get(MASTER_INTERNAL_ENDPOINTS).toArray.map { endpoint => + Utils.parseHostPort(endpoint.replace("", Utils.localHostName(this))) match { + case (host, 0) => s"$host:${HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get}" + case (host, port) => s"$host:$port" + } + } + // ////////////////////////////////////////////////////// // Rack Resolver // // ////////////////////////////////////////////////////// @@ -1147,8 +1179,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def haMasterNodeInternalPort(nodeId: String): Int = { val key = HA_MASTER_NODE_INTERNAL_PORT.key.replace("", nodeId) - val legacyKey = HA_MASTER_NODE_INTERNAL_PORT.alternatives.head._1.replace("", nodeId) - getInt(key, getInt(legacyKey, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get)) + getInt(key, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get) } def masterInternalPort: Int = get(MASTER_INTERNAL_PORT) @@ -4497,14 +4528,6 @@ object CelebornConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("30s") - val AUTH_ENABLED: ConfigEntry[Boolean] = - buildConf("celeborn.auth.enabled") - .categories("auth") - .version("0.5.0") - .doc("Whether to enable authentication.") - .booleanConf - .createWithDefault(false) - val INTERNAL_PORT_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.internal.port.enabled") .categories("master", "worker") @@ -4516,6 
+4539,15 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val AUTH_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.auth.enabled") + .categories("auth") + .version("0.5.0") + .doc("Whether to enable authentication. Authentication will be enabled only when " + + s"${INTERNAL_PORT_ENABLED.key} is enabled as well.") + .booleanConf + .createWithDefault(false) + val MASTER_INTERNAL_PORT: ConfigEntry[Int] = buildConf("celeborn.master.internal.port") .categories("master") @@ -4536,6 +4568,20 @@ object CelebornConf extends Logging { .checkValue(p => p >= 1024 && p < 65535, "Invalid port") .createWithDefault(8097) + val MASTER_INTERNAL_ENDPOINTS: ConfigEntry[Seq[String]] = + buildConf("celeborn.master.internal.endpoints") + .categories("worker") + .doc("Endpoints of master nodes just for celeborn workers to connect, allowed pattern " + + "is: `:[,:]*`, e.g. `clb1:8097,clb2:8097,clb3:8097`. " + + "If the port is omitted, 8097 will be used.") + .version("0.5.0") + .stringConf + .toSequence + .checkValue( + endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess), + "Allowed pattern is: `:[,:]*`") + .createWithDefaultString(s":8097") + val RACKRESOLVER_REFRESH_INTERVAL: ConfigEntry[Long] = buildConf("celeborn.master.rackResolver.refresh.interval") .categories("master") @@ -4543,4 +4589,39 @@ object CelebornConf extends Logging { .doc("Interval for refreshing the node rack information periodically.") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("30s") + + val MASTER_SECURED_PORT: ConfigEntry[Int] = + buildConf("celeborn.master.secured.port") + .categories("master", "auth") + .version("0.5.0") + .doc( + "Secured port on the master where clients connect.") + .intConf + .checkValue(p => p >= 1024 && p < 65535, "Invalid port") + .createWithDefault(19097) + + val HA_MASTER_NODE_SECURED_PORT: ConfigEntry[Int] = + buildConf("celeborn.master.ha.node..secured.port") + .categories("ha", "auth") + .doc( + "Secured 
port for the clients to bind to a master node in HA mode.") + .version("0.5.0") + .intConf + .checkValue(p => p >= 1024 && p < 65535, "Invalid port") + .createWithDefault(19097) + + val MASTER_SECURED_ENDPOINTS: ConfigEntry[Seq[String]] = + buildConf("celeborn.master.secured.endpoints") + .categories("client", "auth") + .doc("Endpoints of master nodes for celeborn client to connect for secured communication, allowed pattern " + + "is: `:[,:]*`, e.g. `clb1:19097,clb2:19097,clb3:19097`. " + + "If the port is omitted, 19097 will be used.") + .version("0.5.0") + .stringConf + .toSequence + .checkValue( + endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess), + "Allowed pattern is: `:[,:]*`") + .createWithDefaultString(s":19097") + } diff --git a/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java index 4a84f7c5f52..dacb32c4148 100644 --- a/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/client/MasterClientSuiteJ.java @@ -79,7 +79,7 @@ public void testSendOneWayMessageWithoutHA() throws Exception { }); prepareForRpcEnvWithoutHA(); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromApplication message = Mockito.mock(HeartbeatFromApplication.class); try { @@ -106,7 +106,7 @@ public void testSendOneWayMessageWithoutHAWithRetry() throws Exception { }); prepareForRpcEnvWithoutHA(); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromApplication message = Mockito.mock(HeartbeatFromApplication.class); try { @@ -132,7 +132,7 @@ public void testSendOneWayMessageWithHA() throws Exception { return Future$.MODULE$.successful(response); }); - MasterClient client = new MasterClient(rpcEnv, conf); + 
MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromApplication message = Mockito.mock(HeartbeatFromApplication.class); try { @@ -152,7 +152,7 @@ public void testSendMessageWithoutHA() { prepareForEndpointRefWithoutRetry(() -> Future$.MODULE$.successful(mockResponse)); prepareForRpcEnvWithoutHA(); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromWorker message = Mockito.mock(HeartbeatFromWorker.class); HeartbeatFromWorkerResponse response = null; @@ -174,7 +174,7 @@ public void testSendMessageWithoutHAWithRetry() { prepareForEndpointRefWithRetry(numTries, () -> Future$.MODULE$.successful(mockResponse)); prepareForRpcEnvWithoutHA(); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromWorker message = Mockito.mock(HeartbeatFromWorker.class); HeartbeatFromWorkerResponse response = null; @@ -195,7 +195,7 @@ public void testSendMessageWithHA() { prepareForRpcEnvWithHA(() -> Future$.MODULE$.successful(mockResponse)); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromWorker message = Mockito.mock(HeartbeatFromWorker.class); HeartbeatFromWorkerResponse response = null; @@ -254,7 +254,7 @@ private void checkOneMasterDownInHA(Exception causedByException) { .when(rpcEnv) .setupEndpointRef(Mockito.any(RpcAddress.class), Mockito.anyString()); - MasterClient client = new MasterClient(rpcEnv, conf); + MasterClient client = new MasterClient(rpcEnv, conf, false); HeartbeatFromWorker message = Mockito.mock(HeartbeatFromWorker.class); HeartbeatFromWorkerResponse response = null; diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 284edf1573c..3b4b8836f42 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -110,6 +110,7 @@ license: | | 
celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold | 2147483647 | Celeborn will only accept shuffle of partition number lower than this configuration value. | 0.3.0 | celeborn.shuffle.forceFallback.numPartitionsThreshold | | celeborn.client.spark.shuffle.writer | HASH | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer | | celeborn.master.endpoints | <localhost>:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | | +| celeborn.master.secured.endpoints | <localhost>:19097 | Endpoints of master nodes for celeborn client to connect for secured communication, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:19097,clb2:19097,clb3:19097`. If the port is omitted, 19097 will be used. | 0.5.0 | | | celeborn.storage.availableTypes | HDD | Enabled storages. Available options: MEMORY,HDD,SSD,HDFS. Note: HDD and SSD would be treated as identical. | 0.3.0 | celeborn.storage.activeTypes | | celeborn.storage.hdfs.dir | <undefined> | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 | | diff --git a/docs/configuration/ha.md b/docs/configuration/ha.md index faef0916e94..c6ab6aeeb18 100644 --- a/docs/configuration/ha.md +++ b/docs/configuration/ha.md @@ -24,6 +24,7 @@ license: | | celeborn.master.ha.node.<id>.internal.port | 8097 | Internal port for the workers and other masters to bind to a master node in HA mode. | 0.5.0 | | | celeborn.master.ha.node.<id>.port | 9097 | Port to bind of master node in HA mode. 
| 0.3.0 | celeborn.ha.master.node.<id>.port | | celeborn.master.ha.node.<id>.ratis.port | 9872 | Ratis port to bind of master node in HA mode. | 0.3.0 | celeborn.ha.master.node.<id>.ratis.port | +| celeborn.master.ha.node.<id>.secured.port | 19097 | Secured port for the clients to bind to a master node in HA mode. | 0.5.0 | | | celeborn.master.ha.ratis.raft.rpc.type | netty | RPC type for Ratis, available options: netty, grpc. | 0.3.0 | celeborn.ha.master.ratis.raft.rpc.type | | celeborn.master.ha.ratis.raft.server.storage.dir | /tmp/ratis | | 0.3.0 | celeborn.ha.master.ratis.raft.server.storage.dir | diff --git a/docs/configuration/master.md b/docs/configuration/master.md index 091c362f52e..a7d7fa29262 100644 --- a/docs/configuration/master.md +++ b/docs/configuration/master.md @@ -45,6 +45,7 @@ license: | | celeborn.master.internal.port | 8097 | Internal port on the master where both workers and other master nodes connect. | 0.5.0 | | | celeborn.master.port | 9097 | Port for master to bind. | 0.2.0 | | | celeborn.master.rackResolver.refresh.interval | 30s | Interval for refreshing the node rack information periodically. | 0.5.0 | | +| celeborn.master.secured.port | 19097 | Secured port on the master where clients connect. | 0.5.0 | | | celeborn.master.slot.assign.extraSlots | 2 | Extra slots number when master assign slots. | 0.3.0 | celeborn.slots.assign.extraSlots | | celeborn.master.slot.assign.loadAware.diskGroupGradient | 0.1 | This value means how many more workload will be placed into a faster disk group than a slower group. 
| 0.3.0 | celeborn.slots.assign.loadAware.diskGroupGradient | | celeborn.master.slot.assign.loadAware.fetchTimeWeight | 1.0 | Weight of average fetch time when calculating ordering in load-aware assignment strategy | 0.3.0 | celeborn.slots.assign.loadAware.fetchTimeWeight | diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index c792840f0c4..0424e940f4f 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -35,6 +35,7 @@ license: | | celeborn.internal.port.enabled | false | Whether to create a internal port on Masters/Workers for inter-Masters/Workers communication. This is beneficial when SASL authentication is enforced for all interactions between clients and Celeborn Services, but the services can exchange messages without being subject to SASL authentication. | 0.5.0 | | | celeborn.master.endpoints | <localhost>:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | | | celeborn.master.estimatedPartitionSize.minSize | 8mb | Ignore partition size smaller than this configuration of partition size for estimation. | 0.3.0 | celeborn.shuffle.minPartitionSizeToEstimate | +| celeborn.master.internal.endpoints | <localhost>:8097 | Endpoints of master nodes just for celeborn workers to connect, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:8097,clb2:8097,clb3:8097`. If the port is omitted, 8097 will be used. | 0.5.0 | | | celeborn.shuffle.chunk.size | 8m | Max chunk size of reducer's merged shuffle data. For example, if a reducer's shuffle data is 128M and the data will need 16 fetch chunk requests to fetch. | 0.2.0 | | | celeborn.storage.availableTypes | HDD | Enabled storages. Available options: MEMORY,HDD,SSD,HDFS. Note: HDD and SSD would be treated as identical. 
| 0.3.0 | celeborn.storage.activeTypes | | celeborn.storage.hdfs.dir | <undefined> | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 | | diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index b6441d8e0d4..b4bcfc49bb5 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -38,11 +38,13 @@ import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{DiskInfo, WorkerInfo, WorkerStatus} import org.apache.celeborn.common.metrics.MetricsSystem import org.apache.celeborn.common.metrics.source.{JVMCPUSource, JVMSource, ResourceConsumptionSource, SystemMiscSource, ThreadPoolSource} +import org.apache.celeborn.common.network.sasl.SecretRegistryImpl import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode} import org.apache.celeborn.common.protocol.message.ControlMessages._ import org.apache.celeborn.common.quota.{QuotaManager, ResourceConsumption} import org.apache.celeborn.common.rpc._ +import org.apache.celeborn.common.security.{RpcSecurityContextBuilder, ServerSaslContextBuilder} import org.apache.celeborn.common.util.{CelebornHadoopUtils, CollectionUtils, JavaUtils, PbSerDeUtils, ThreadUtils, Utils} import org.apache.celeborn.server.common.{HttpService, Service} import org.apache.celeborn.service.deploy.master.clustermeta.SingleMasterMetaManager @@ -97,6 +99,29 @@ private[celeborn] class Master( } private val rackResolver = new CelebornRackResolver(conf) + private val authEnabled = conf.authEnabled + private val secretRegistry = new SecretRegistryImpl() + // Visible for testing + private[master] var securedRpcEnv: RpcEnv = _ + if (authEnabled) { + val externalSecurityContext = new RpcSecurityContextBuilder() + 
.withServerSaslContext( + new ServerSaslContextBuilder() + .withAddRegistrationBootstrap(true) + .withSecretRegistry(secretRegistry).build()).build() + + securedRpcEnv = RpcEnv.create( + RpcNameConstants.MASTER_SECURED_SYS, + masterArgs.host, + masterArgs.host, + masterArgs.securedPort, + conf, + Math.max(64, Runtime.getRuntime.availableProcessors()), + Some(externalSecurityContext)) + logInfo( + s"Secure port enabled ${masterArgs.securedPort} for secured RPC.") + } + private val statusSystem = if (conf.haEnabled) { val sys = new HAMasterMetaManager(internalRpcEnvInUse, conf, rackResolver) @@ -228,6 +253,16 @@ private[celeborn] class Master( internalRpcEndpoint) } + // Visible for testing + private[master] var securedRpcEndpoint: RpcEndpoint = _ + private var securedRpcEndpointRef: RpcEndpointRef = _ + if (authEnabled) { + securedRpcEndpoint = new SecuredRpcEndpoint(this, securedRpcEnv, conf) + securedRpcEndpointRef = securedRpcEnv.setupEndpoint( + RpcNameConstants.MASTER_SECURED_EP, + securedRpcEndpoint) + } + // start threads to check timeout for workers and applications override def onStart(): Unit = { if (!threadsStarted.compareAndSet(false, true)) { @@ -351,16 +386,22 @@ private[celeborn] class Master( requestId, shouldResponse) => logDebug(s"Received heartbeat from app $appId") - executeWithLeaderChecker( - context, - handleHeartbeatFromApplication( + if (checkAuthStatus(appId, context)) { + // TODO: [CELEBORN-1261] For the workers to be able to check whether an auth-enabled app is talking to it on + // unsecured port, Master will need to maintain a list of unauthenticated apps and send it to workers. + // This wasn't part of the original proposal because that proposal didn't target the Celeborn server to support + // both secured and unsecured communication. 
+ executeWithLeaderChecker( context, - appId, - totalWritten, - fileCount, - needCheckedWorkerList, - requestId, - shouldResponse)) + handleHeartbeatFromApplication( + context, + appId, + totalWritten, + fileCount, + needCheckedWorkerList, + requestId, + shouldResponse)) + } case pbRegisterWorker: PbRegisterWorker => val requestId = pbRegisterWorker.getRequestId @@ -394,23 +435,32 @@ private[celeborn] class Master( // keep it for compatible reason context.reply(ReleaseSlotsResponse(StatusCode.SUCCESS)) - case requestSlots @ RequestSlots(_, _, _, _, _, _, _, _, _, _, _) => - logTrace(s"Received RequestSlots request $requestSlots.") - executeWithLeaderChecker(context, handleRequestSlots(context, requestSlots)) + case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _, _) => + if (checkAuthStatus(applicationId, context)) { + // TODO: [CELEBORN-1261] + logTrace(s"Received RequestSlots request $requestSlots.") + executeWithLeaderChecker(context, handleRequestSlots(context, requestSlots)) + } case pb: PbUnregisterShuffle => val applicationId = pb.getAppId val shuffleId = pb.getShuffleId val requestId = pb.getRequestId - logDebug(s"Received UnregisterShuffle request $requestId, $applicationId, $shuffleId") - executeWithLeaderChecker( - context, - handleUnregisterShuffle(context, applicationId, shuffleId, requestId)) + if (checkAuthStatus(applicationId, context)) { + // TODO: [CELEBORN-1261] + logDebug(s"Received UnregisterShuffle request $requestId, $applicationId, $shuffleId") + executeWithLeaderChecker( + context, + handleUnregisterShuffle(context, applicationId, shuffleId, requestId)) + } case ApplicationLost(appId, requestId) => - logDebug( - s"Received ApplicationLost request $requestId, $appId from ${context.senderAddress}.") - executeWithLeaderChecker(context, handleApplicationLost(context, appId, requestId)) + if (context.senderAddress.equals(self.address) || checkAuthStatus(appId, context)) { + // TODO: [CELEBORN-1261] + logDebug( + s"Received 
ApplicationLost request $requestId, $appId from ${context.senderAddress}.") + executeWithLeaderChecker(context, handleApplicationLost(context, appId, requestId)) + } case HeartbeatFromWorker( host, @@ -475,6 +525,8 @@ private[celeborn] class Master( handleWorkerLost(context, host, rpcPort, pushPort, fetchPort, replicatePort, requestId)) case CheckQuota(userIdentifier) => + // TODO: CheckQuota doesn't have application id in it, so we can't check auth status here. + // Will have to add application id to CheckQuota message to check auth status. executeWithLeaderChecker(context, handleCheckQuota(userIdentifier, context)) case _: PbCheckWorkersAvailable => @@ -907,7 +959,7 @@ private[celeborn] class Master( } } - private def handleHeartbeatFromApplication( + private[master] def handleHeartbeatFromApplication( context: RpcCallContext, appId: String, totalWritten: Long, @@ -1007,7 +1059,7 @@ private[celeborn] class Master( resourceConsumption } - private def handleCheckQuota( + private[master] def handleCheckQuota( userIdentifier: UserIdentifier, context: RpcCallContext): Unit = { val userResourceConsumption = handleResourceConsumption(userIdentifier) @@ -1037,6 +1089,16 @@ private[celeborn] class Master( }.asJava } + private def checkAuthStatus(appId: String, context: RpcCallContext): Boolean = { + if (conf.authEnabled && secretRegistry.isRegistered(appId)) { + context.sendFailure(new SecurityException( + s"Auth enabled application $appId sending messages on unsecured port!")) + false + } else { + true + } + } + override def getMasterGroupInfo: String = { val sb = new StringBuilder sb.append("====================== Master Group INFO ==============================\n") @@ -1242,6 +1304,9 @@ private[celeborn] class Master( if (conf.internalPortEnabled) { internalRpcEnvInUse.awaitTermination() } + if (authEnabled) { + securedRpcEnv.awaitTermination() + } } override def stop(exitKind: Int): Unit = synchronized { @@ -1251,6 +1316,9 @@ private[celeborn] class Master( if 
(conf.internalPortEnabled) { internalRpcEnvInUse.stop(internalRpcEndpointRef) } + if (authEnabled) { + securedRpcEnv.stop(securedRpcEndpointRef) + } super.stop(exitKind) logInfo("Master stopped.") stopped = true diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala index 701116a2e2e..95dda30b419 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala @@ -28,6 +28,7 @@ class MasterArguments(args: Array[String], conf: CelebornConf) { private var _host: Option[String] = None private var _port: Option[Int] = None private var _internalPort: Option[Int] = None + private var _securedPort: Option[Int] = None private var _propertiesFile: Option[String] = None private var _masterClusterInfo: Option[MasterClusterInfo] = None @@ -47,11 +48,15 @@ class MasterArguments(args: Array[String], conf: CelebornConf) { _internalPort = _internalPort.orElse { if (conf.internalPortEnabled) Some(conf.haMasterNodeInternalPort(localNode.nodeId)) else None } + _securedPort = _securedPort.orElse { + if (conf.authEnabled) Some(conf.haMasterNodeSecuredPort(localNode.nodeId)) else None + } _masterClusterInfo = Some(clusterInfo) } else { _host = _host.orElse(Some(conf.masterHost)) _port = _port.orElse(Some(conf.masterPort)) _internalPort = _internalPort.orElse(Some(conf.masterInternalPort)) + _securedPort = _securedPort.orElse(Some(conf.masterSecuredPort)) } def host: String = _host.get @@ -60,6 +65,8 @@ class MasterArguments(args: Array[String], conf: CelebornConf) { def internalPort: Int = _internalPort.get + def securedPort: Int = _securedPort.get + def masterClusterInfo: Option[MasterClusterInfo] = _masterClusterInfo @tailrec @@ -77,6 +84,10 @@ class MasterArguments(args: Array[String], conf: CelebornConf) { _internalPort = 
Some(value) parse(tail) + case ("--secured-port") :: IntParam(value) :: tail => + _securedPort = Some(value) + parse(tail) + case "--properties-file" :: value :: tail => _propertiesFile = Some(value) parse(tail) @@ -102,6 +113,7 @@ class MasterArguments(args: Array[String], conf: CelebornConf) { | -h HOST, --host HOST Hostname to listen on | -p PORT, --port PORT Port to listen on (default: 9097) | --internal-port PORT Internal port for internal communication (default: 8097) + | --secured-port PORT Secured port for secured communication (default: 19097) | --properties-file FILE Path to a custom Celeborn properties file, | default is conf/celeborn-defaults.conf. |""".stripMargin) diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/SecuredRpcEndpoint.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/SecuredRpcEndpoint.scala new file mode 100644 index 00000000000..e8b1eb96698 --- /dev/null +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/SecuredRpcEndpoint.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.service.deploy.master + +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.protocol.PbUnregisterShuffle +import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, CheckQuota, HeartbeatFromApplication, RequestSlots} +import org.apache.celeborn.common.rpc._ + +/** + * Secured RPC endpoint used by the Master to communicate with the Clients. + */ +private[celeborn] class SecuredRpcEndpoint( + val master: Master, + override val rpcEnv: RpcEnv, + val conf: CelebornConf) + extends RpcEndpoint with Logging { + + // start threads to check timeout for workers and applications + override def onStart(): Unit = { + master.onStart() + } + + override def onStop(): Unit = { + master.onStop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + logDebug(s"Client $address got disconnected.") + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case HeartbeatFromApplication( + appId, + totalWritten, + fileCount, + needCheckedWorkerList, + requestId, + shouldResponse) => + logDebug(s"Received heartbeat from app $appId") + master.executeWithLeaderChecker( + context, + master.handleHeartbeatFromApplication( + context, + appId, + totalWritten, + fileCount, + needCheckedWorkerList, + requestId, + shouldResponse)) + + case requestSlots @ RequestSlots(_, _, _, _, _, _, _, _, _, _, _) => + logTrace(s"Received RequestSlots request $requestSlots.") + master.executeWithLeaderChecker(context, master.handleRequestSlots(context, requestSlots)) + + case pb: PbUnregisterShuffle => + val applicationId = pb.getAppId + val shuffleId = pb.getShuffleId + val requestId = pb.getRequestId + logDebug(s"Received UnregisterShuffle request $requestId, $applicationId, $shuffleId") + master.executeWithLeaderChecker( + context, + master.handleUnregisterShuffle(context, applicationId, shuffleId, requestId)) + + 
case ApplicationLost(appId, requestId) => + logDebug( + s"Received ApplicationLost request $requestId, $appId from ${context.senderAddress}.") + master.executeWithLeaderChecker( + context, + master.handleApplicationLost(context, appId, requestId)) + + case CheckQuota(userIdentifier) => + master.executeWithLeaderChecker(context, master.handleCheckQuota(userIdentifier, context)) + } +} diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala index 2d3e6691297..62e733691f4 100644 --- a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala @@ -18,13 +18,14 @@ package org.apache.celeborn.service.deploy.master import com.google.common.io.Files -import org.mockito.Mockito.mock +import org.mockito.Mockito.{mock, times, verify} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.{PbCheckForWorkerTimeout, PbRegisterWorker} +import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, HeartbeatFromApplication} import org.apache.celeborn.common.util.{CelebornExitKind, Utils} class MasterSuite extends AnyFunSuite @@ -96,4 +97,44 @@ class MasterSuite extends AnyFunSuite master.rpcEnv.shutdown() master.internalRpcEnvInUse.shutdown() } + + test("test secured port receives") { + val conf = new CelebornConf() + conf.set(CelebornConf.HA_ENABLED.key, "false") + conf.set(CelebornConf.HA_MASTER_RATIS_STORAGE_DIR.key, getTmpDir()) + conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, getTmpDir()) + conf.set(CelebornConf.METRICS_ENABLED.key, "true") + conf.set(CelebornConf.INTERNAL_PORT_ENABLED.key, "true") + conf.set(CelebornConf.AUTH_ENABLED.key, 
"true") + + val args = + Array("-h", "localhost", "-p", "9097", "--internal-port", "8097", "--secured-port", "19097") + + val masterArgs = new MasterArguments(args, conf) + val master = new Master(conf, masterArgs) + new Thread() { + override def run(): Unit = { + master.initialize() + } + }.start() + Thread.sleep(5000L) + master.securedRpcEndpoint.receiveAndReply( + mock(classOf[org.apache.celeborn.common.rpc.RpcCallContext])) + .applyOrElse( + HeartbeatFromApplication("appId", 0L, 0L, null), + (_: Any) => fail("Unexpected message")) + master.securedRpcEndpoint.receiveAndReply( + mock(classOf[org.apache.celeborn.common.rpc.RpcCallContext])) + .applyOrElse(ApplicationLost("appId"), (_: Any) => fail("Unexpected message")) + + assertThrows[scala.MatchError] { + master.securedRpcEndpoint.receiveAndReply( + mock(classOf[org.apache.celeborn.common.rpc.RpcCallContext]))( + PbRegisterWorker.newBuilder().build()) + } + master.stop(CelebornExitKind.EXIT_IMMEDIATELY) + master.rpcEnv.shutdown() + master.internalRpcEnvInUse.shutdown() + master.securedRpcEnv.shutdown() + } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index 46ff96f4fbe..30993f3c339 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -237,7 +237,8 @@ private[celeborn] class Worker( val shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, CommitInfo]] = JavaUtils.newConcurrentHashMap[String, ConcurrentHashMap[Long, CommitInfo]]() - private val masterClient = new MasterClient(rpcEnv, conf) + // TODO: pass the internal rpc env here when internal port is added to the worker. + private val masterClient = new MasterClient(rpcEnv, conf, true) // (workerInfo -> last connect timeout timestamp) val unavailablePeers: ConcurrentHashMap[WorkerInfo, Long] =