[CELEBORN-1257] Adds a secured port in Celeborn Master for secure communication with LifecycleManager

### What changes were proposed in this pull request?
This adds a secured port to Celeborn Master which is used for secure communication with LifecycleManager.
This is part of adding authentication support in Celeborn (see CELEBORN-1011).

This change targets just adding the secured port to Master. The following items from the proposal are still pending:
1. Persisting the app secrets in Ratis.
2. Forwarding secrets to Workers and giving the Workers the ability to pull registration info from the Master.
3. Secured and internal ports in Workers.
4. Secured communication between workers and clients.

In addition, since we support both secured and unsecured communication for backward compatibility and seamless rolling upgrades, one more change is needed. An application that registers with the Master could still talk to the Workers over the unsecured ports, which would be a security breach. The Workers therefore need to know whether an application has registered with the Master, and for that the Master has to propagate the list of unsecured applications to the Celeborn Workers as well. This can be discussed further in https://issues.apache.org/jira/browse/CELEBORN-1261.

### Why are the changes needed?
It is needed to add authentication support to Celeborn (CELEBORN-1011).

### Does this PR introduce _any_ user-facing change?
Yes. This introduces new configurations for the Master's secured port and endpoints (`celeborn.master.secured.port`, `celeborn.master.ha.node.<id>.secured.port`, `celeborn.master.secured.endpoints`, `celeborn.master.internal.endpoints`), and `celeborn.auth.enabled` now only takes effect when `celeborn.internal.port.enabled` is also set. A configuration sketch is shown below.
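
For reference, a minimal sketch of how an application could opt in to the secured path once this lands. The keys are the ones touched by this patch; the chained calls assume `CelebornConf`'s no-arg constructor and its usual `set(key, value)` builder API:

```scala
import org.apache.celeborn.common.CelebornConf

// Illustration only: enabling the new secured path end to end.
val conf = new CelebornConf()
  // Auth only takes effect when the internal port is enabled as well.
  .set("celeborn.auth.enabled", "true")
  .set("celeborn.internal.port.enabled", "true")
  // Secured port the Master binds for LifecycleManager connections (default 19097).
  .set("celeborn.master.secured.port", "19097")
  // Secured endpoints the client-side MasterClient resolves when auth is enabled.
  .set("celeborn.master.secured.endpoints", "clb1:19097,clb2:19097,clb3:19097")
```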

### How was this patch tested?
Added a simple UT.
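
The UT itself is not shown on this page. As a rough illustration of the kind of check involved, here is a sketch against the new `CelebornConf.authEnabled` validation in this diff (ScalaTest's `AnyFunSuite` and the no-arg `CelebornConf` constructor are assumed; this is not necessarily the real test):

```scala
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.common.CelebornConf

class AuthConfSketchSuite extends AnyFunSuite {

  test("auth requires the internal port to be enabled") {
    val conf = new CelebornConf()
      .set("celeborn.auth.enabled", "true")
      .set("celeborn.internal.port.enabled", "false")
    // authEnabled throws when auth is requested without the internal port.
    intercept[IllegalArgumentException] {
      conf.authEnabled
    }
  }

  test("auth is effective when both flags are set") {
    val conf = new CelebornConf()
      .set("celeborn.auth.enabled", "true")
      .set("celeborn.internal.port.enabled", "true")
    assert(conf.authEnabled)
  }
}
```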

Closes apache#2281 from otterc/CELEBORN-1257.

Authored-by: Chandni Singh <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
otterc authored and waitinfuture committed Feb 6, 2024
1 parent c3b129d commit ab4c0bc
Showing 14 changed files with 444 additions and 47 deletions.
@@ -17,7 +17,9 @@

package org.apache.celeborn.client

import java.lang.{Byte => JByte}
import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
@@ -41,12 +43,14 @@ import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP
import org.apache.celeborn.common.protocol.message.ControlMessages._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc._
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.security.{ClientSaslContextBuilder, RpcSecurityContext, RpcSecurityContextBuilder}
import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
@@ -108,6 +112,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
.build().asInstanceOf[Cache[Int, ByteBuffer]]

private val mockDestroyFailure = conf.testMockDestroySlotsFailure
private val authEnabled = conf.authEnabled

@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, ShufflePartitionLocationInfo] =
@@ -159,7 +164,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends

logInfo(s"Starting LifecycleManager on ${rpcEnv.address}")

private val masterClient = new MasterClient(rpcEnv, conf)
private var masterRpcEnvInUse = rpcEnv
private var workerRpcEnvInUse = rpcEnv
if (authEnabled) {
logInfo(s"Authentication is enabled; setting up master and worker RPC environments")
val appSecret = createSecret()
val registrationInfo = new RegistrationInfo()
masterRpcEnvInUse =
RpcEnv.create(
RpcNameConstants.LIFECYCLE_MANAGER_MASTER_SYS,
lifecycleHost,
0,
conf,
createRpcSecurityContext(
appSecret,
addClientRegistrationBootstrap = true,
Some(registrationInfo)))
workerRpcEnvInUse =
RpcEnv.create(
RpcNameConstants.LIFECYCLE_MANAGER_WORKER_SYS,
lifecycleHost,
0,
conf,
createRpcSecurityContext(appSecret))
}

private val masterClient = new MasterClient(masterRpcEnvInUse, conf, false)
val commitManager = new CommitManager(appUniqueId, conf, this)
val workerStatusTracker = new WorkerStatusTracker(conf, this)
private val heartbeater =
@@ -214,6 +244,36 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
rpcEnv.shutdown()
rpcEnv.awaitTermination()
}
if (authEnabled) {
if (masterRpcEnvInUse != null) {
masterRpcEnvInUse.shutdown()
masterRpcEnvInUse.awaitTermination()
}
if (workerRpcEnvInUse != null) {
workerRpcEnvInUse.shutdown()
workerRpcEnvInUse.awaitTermination()
}
}
}

/**
* Creates security context for external RPC endpoint.
*/
def createRpcSecurityContext(
appSecret: String,
addClientRegistrationBootstrap: Boolean = false,
registrationInfo: Option[RegistrationInfo] = None): Option[RpcSecurityContext] = {
val clientSaslContextBuilder = new ClientSaslContextBuilder()
.withAddRegistrationBootstrap(addClientRegistrationBootstrap)
.withAppId(appUniqueId)
.withSaslUser(appUniqueId)
.withSaslPassword(appSecret)
if (registrationInfo.isDefined) {
clientSaslContextBuilder.withRegistrationInfo(registrationInfo.get)
}
val rpcSecurityContext = new RpcSecurityContextBuilder()
.withClientSaslContext(clientSaslContextBuilder.build()).build()
Some(rpcSecurityContext)
}

def getUserIdentifier: UserIdentifier = {
@@ -356,7 +416,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
connectFailedWorkers: ShuffleFailedWorkers): Unit = {
val futures = new util.LinkedList[(Future[RpcEndpointRef], WorkerInfo)]()
slots.asScala foreach { case (workerInfo, _) =>
val future = rpcEnv.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
val future = workerRpcEnvInUse.asyncSetupEndpointRefByAddr(RpcEndpointAddress(
RpcAddress.apply(workerInfo.host, workerInfo.rpcPort),
WORKER_EP))
futures.add((future, workerInfo))
@@ -1065,7 +1125,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
s" ${destroyWorkerInfo.readableAddress()}, init according to partition info")
try {
if (!workerStatusTracker.workerExcluded(destroyWorkerInfo)) {
destroyWorkerInfo.endpoint = rpcEnv.setupEndpointRef(
destroyWorkerInfo.endpoint = workerRpcEnvInUse.setupEndpointRef(
RpcAddress.apply(destroyWorkerInfo.host, destroyWorkerInfo.rpcPort),
WORKER_EP)
} else {
@@ -1573,4 +1633,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
heartbeater.stop()
super.stop()
}

private def createSecret(): String = {
val bits = 256
val rnd = new SecureRandom()
val secretBytes = new Array[Byte](bits / JByte.SIZE)
rnd.nextBytes(secretBytes)
JavaUtils.bytesToString(ByteBuffer.wrap(secretBytes))
}
}
@@ -58,10 +58,15 @@ public class MasterClient {

private final AtomicReference<RpcEndpointRef> rpcEndpointRef;
private final ExecutorService oneWayMessageSender;
private final CelebornConf conf;
private final boolean isWorker;
private String masterEndpointName;

public MasterClient(RpcEnv rpcEnv, CelebornConf conf) {
public MasterClient(RpcEnv rpcEnv, CelebornConf conf, boolean isWorker) {
this.rpcEnv = rpcEnv;
this.masterEndpoints = Arrays.asList(conf.masterEndpoints());
this.conf = conf;
this.isWorker = isWorker;
this.masterEndpoints = resolveMasterEndpoints();
Collections.shuffle(this.masterEndpoints);
this.maxRetries = Math.max(masterEndpoints.size(), conf.masterClientMaxRetries());
this.rpcTimeout = conf.masterClientRpcAskTimeout();
@@ -250,7 +255,7 @@ private RpcEndpointRef setupEndpointRef(String endpoint) {
RpcEndpointRef endpointRef = null;
try {
endpointRef =
rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), RpcNameConstants.MASTER_EP);
rpcEnv.setupEndpointRef(RpcAddress.fromHostAndPort(endpoint), masterEndpointName);
} catch (Exception e) {
// Catch all exceptions. Because we don't care whether this exception is IOException or
// TimeoutException or other exceptions, so we just try to connect to host:port, if fail,
@@ -259,4 +264,26 @@ private RpcEndpointRef setupEndpointRef(String endpoint) {
}
return endpointRef;
}

private List<String> resolveMasterEndpoints() {
if (isWorker) {
// For worker, we should use the internal endpoints if internal port is enabled.
if (conf.internalPortEnabled()) {
masterEndpointName = RpcNameConstants.MASTER_INTERNAL_EP;
return Arrays.asList(conf.masterInternalEndpoints());
} else {
masterEndpointName = RpcNameConstants.MASTER_EP;
return Arrays.asList(conf.masterEndpoints());
}
} else {
// This is for client, so we should use the secured endpoints if auth is enabled.
if (conf.authEnabled()) {
masterEndpointName = RpcNameConstants.MASTER_SECURED_EP;
return Arrays.asList(conf.masterSecuredEndpoints());
} else {
masterEndpointName = RpcNameConstants.MASTER_EP;
return Arrays.asList(conf.masterEndpoints());
}
}
}
}
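
For context, a hedged sketch of the two ways the new `isWorker` flag is expected to be used. The `false` call mirrors the LifecycleManager change in this diff; the `true` call for Workers is an assumption based on `resolveMasterEndpoints` (the Worker-side wiring is a later item in the proposal):

```scala
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.rpc.RpcEnv

// Hypothetical illustration of which endpoint set each caller ends up on.
object MasterClientUsageSketch {
  def build(rpcEnv: RpcEnv, conf: CelebornConf): Unit = {
    // LifecycleManager / client side (as in this diff): isWorker = false.
    // With celeborn.auth.enabled it resolves celeborn.master.secured.endpoints
    // and talks to MasterSecuredEndpoint; otherwise it falls back to MasterEndpoint.
    val fromLifecycleManager = new MasterClient(rpcEnv, conf, false)

    // Worker side (assumed, not part of this diff): isWorker = true.
    // With celeborn.internal.port.enabled it resolves celeborn.master.internal.endpoints
    // and talks to MasterInternalEndpoint; otherwise MasterEndpoint.
    val fromWorker = new MasterClient(rpcEnv, conf, true)
  }
}
```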
@@ -21,9 +21,11 @@ public class RpcNameConstants {
// For Master
public static String MASTER_SYS = "Master";
public static String MASTER_INTERNAL_SYS = "MasterInternal";
public static String MASTER_SECURED_SYS = "MasterSecured";
// Master Endpoint Name
public static String MASTER_EP = "MasterEndpoint";
public static String MASTER_INTERNAL_EP = "MasterInternalEndpoint";
public static String MASTER_SECURED_EP = "MasterSecuredEndpoint";

// For Worker
public static String WORKER_SYS = "Worker";
@@ -32,6 +34,8 @@ public class RpcNameConstants {

// For LifecycleManager
public static String LIFECYCLE_MANAGER_SYS = "LifecycleManager";
public static String LIFECYCLE_MANAGER_MASTER_SYS = "LifecycleManagerMasterSys";
public static String LIFECYCLE_MANAGER_WORKER_SYS = "LifecycleManagerWorkerSys";
public static String LIFECYCLE_MANAGER_EP = "LifecycleManagerEndpoint";

// For Shuffle Client
103 changes: 92 additions & 11 deletions common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.concurrent.duration._
import scala.util.Try

import org.apache.celeborn.common.CelebornConf.MASTER_INTERNAL_ENDPOINTS
import org.apache.celeborn.common.identity.{DefaultIdentityProvider, IdentityProvider}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.internal.config._
@@ -1133,22 +1134,52 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
// //////////////////////////////////////////////////////
// Authentication //
// //////////////////////////////////////////////////////
def authEnabled: Boolean = get(AUTH_ENABLED)
def authEnabled: Boolean = {
val authEnabled = get(AUTH_ENABLED)
val internalPortEnabled = get(INTERNAL_PORT_ENABLED)
if (authEnabled && !internalPortEnabled) {
throw new IllegalArgumentException(
s"${AUTH_ENABLED.key} is true, but ${INTERNAL_PORT_ENABLED.key} is false")
}
return authEnabled && internalPortEnabled
}

def haMasterNodeSecuredPort(nodeId: String): Int = {
val key = HA_MASTER_NODE_SECURED_PORT.key.replace("<id>", nodeId)
getInt(key, HA_MASTER_NODE_SECURED_PORT.defaultValue.get)
}

def masterSecuredPort: Int = get(MASTER_SECURED_PORT)

def masterSecuredEndpoints: Array[String] =
get(MASTER_SECURED_ENDPOINTS).toArray.map { endpoint =>
Utils.parseHostPort(endpoint.replace("<localhost>", Utils.localHostName(this))) match {
case (host, 0) => s"$host:${HA_MASTER_NODE_SECURED_PORT.defaultValue.get}"
case (host, port) => s"$host:$port"
}
}

// //////////////////////////////////////////////////////
// Internal Port //
// //////////////////////////////////////////////////////
def internalPortEnabled: Boolean = get(INTERNAL_PORT_ENABLED)

def masterInternalEndpoints: Array[String] =
get(MASTER_INTERNAL_ENDPOINTS).toArray.map { endpoint =>
Utils.parseHostPort(endpoint.replace("<localhost>", Utils.localHostName(this))) match {
case (host, 0) => s"$host:${HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get}"
case (host, port) => s"$host:$port"
}
}

// //////////////////////////////////////////////////////
// Rack Resolver //
// //////////////////////////////////////////////////////
def rackResolverRefreshInterval = get(RACKRESOLVER_REFRESH_INTERVAL)

def haMasterNodeInternalPort(nodeId: String): Int = {
val key = HA_MASTER_NODE_INTERNAL_PORT.key.replace("<id>", nodeId)
val legacyKey = HA_MASTER_NODE_INTERNAL_PORT.alternatives.head._1.replace("<id>", nodeId)
getInt(key, getInt(legacyKey, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get))
getInt(key, HA_MASTER_NODE_INTERNAL_PORT.defaultValue.get)
}

def masterInternalPort: Int = get(MASTER_INTERNAL_PORT)
@@ -4497,14 +4528,6 @@
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")

val AUTH_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.auth.enabled")
.categories("auth")
.version("0.5.0")
.doc("Whether to enable authentication.")
.booleanConf
.createWithDefault(false)

val INTERNAL_PORT_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.internal.port.enabled")
.categories("master", "worker")
@@ -4516,6 +4539,15 @@
.booleanConf
.createWithDefault(false)

val AUTH_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.auth.enabled")
.categories("auth")
.version("0.5.0")
.doc("Whether to enable authentication. Authentication will be enabled only when " +
s"${INTERNAL_PORT_ENABLED.key} is enabled as well.")
.booleanConf
.createWithDefault(false)

val MASTER_INTERNAL_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.internal.port")
.categories("master")
Expand All @@ -4536,11 +4568,60 @@ object CelebornConf extends Logging {
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(8097)

val MASTER_INTERNAL_ENDPOINTS: ConfigEntry[Seq[String]] =
buildConf("celeborn.master.internal.endpoints")
.categories("worker")
.doc("Endpoints of master nodes just for celeborn workers to connect, allowed pattern " +
"is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:8097,clb2:8097,clb3:8097`. " +
"If the port is omitted, 8097 will be used.")
.version("0.5.0")
.stringConf
.toSequence
.checkValue(
endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess),
"Allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`")
.createWithDefaultString(s"<localhost>:8097")

val RACKRESOLVER_REFRESH_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.master.rackResolver.refresh.interval")
.categories("master")
.version("0.5.0")
.doc("Interval for refreshing the node rack information periodically.")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")

val MASTER_SECURED_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.secured.port")
.categories("master", "auth")
.version("0.5.0")
.doc(
"Secured port on the master where clients connect.")
.intConf
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(19097)

val HA_MASTER_NODE_SECURED_PORT: ConfigEntry[Int] =
buildConf("celeborn.master.ha.node.<id>.secured.port")
.categories("ha", "auth")
.doc(
"Secured port for the clients to bind to a master node <id> in HA mode.")
.version("0.5.0")
.intConf
.checkValue(p => p >= 1024 && p < 65535, "Invalid port")
.createWithDefault(19097)

val MASTER_SECURED_ENDPOINTS: ConfigEntry[Seq[String]] =
buildConf("celeborn.master.secured.endpoints")
.categories("client", "auth")
.doc("Endpoints of master nodes for celeborn client to connect for secured communication, allowed pattern " +
"is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:19097,clb2:19097,clb3:19097`. " +
"If the port is omitted, 19097 will be used.")
.version("0.5.0")
.stringConf
.toSequence
.checkValue(
endpoints => endpoints.map(_ => Try(Utils.parseHostPort(_))).forall(_.isSuccess),
"Allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`")
.createWithDefaultString(s"<localhost>:19097")

}
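
To illustrate the defaulting behaviour of the new endpoint getters, a small sketch of what `masterSecuredEndpoints` is expected to return. This assumes `Utils.parseHostPort` reports port 0 when no port is given (as the `case (host, 0)` branch above implies), plus the no-arg `CelebornConf` constructor and chained `set`:

```scala
import org.apache.celeborn.common.CelebornConf

object SecuredEndpointsSketch extends App {
  val conf = new CelebornConf()
    .set("celeborn.master.secured.endpoints", "clb1,clb2:20000")

  // "clb1" carries no port, so the HA_MASTER_NODE_SECURED_PORT default (19097)
  // is filled in, while explicit ports are kept as-is.
  // Expected output: clb1:19097,clb2:20000
  println(conf.masterSecuredEndpoints.mkString(","))
}
```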