diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala index b24290fe7e6..09cdd08e256 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/Inbox.scala @@ -106,7 +106,7 @@ private[celeborn] class Inbox( } } while (true) { - safelyCall(endpoint) { + safelyCall(endpoint, endpointRef.name) { message match { case RpcMessage(_sender, content, context) => try { @@ -218,7 +218,21 @@ private[celeborn] class Inbox( /** * Calls action closure, and calls the endpoint's onError function in the case of exceptions. */ - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + private def safelyCall( + endpoint: RpcEndpoint, + endpointRefName: String)(action: => Unit): Unit = { + def dealWithFatalError(fatal: Throwable): Unit = { + inbox.synchronized { + assert(numActiveThreads > 0, "The number of active threads should be positive.") + // Should reduce the number of active threads before throw the error. + numActiveThreads -= 1 + } + logError( + s"An error happened while processing message in the inbox for $endpointRefName", + fatal) + throw fatal + } + try action catch { case NonFatal(e) => @@ -230,8 +244,18 @@ private[celeborn] class Inbox( } else { logError("Ignoring error", ee) } + case fatal: Throwable => + dealWithFatalError(fatal) } + case fatal: Throwable => + dealWithFatalError(fatal) } } + // exposed only for testing + def getNumActiveThreads: Int = { + inbox.synchronized { + inbox.numActiveThreads + } + } } diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala new file mode 100644 index 00000000000..9bc3220857f --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcAddressSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.exception.CelebornException + +class RpcAddressSuite extends CelebornFunSuite { + + test("hostPort") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.host === "1.2.3.4") + assert(address.port === 1234) + assert(address.hostPort === "1.2.3.4:1234") + } + + test("fromCelebornURL") { + val address = RpcAddress.fromCelebornURL("celeborn://1.2.3.4:1234") + assert(address.host === "1.2.3.4") + assert(address.port === 1234) + } + + test("fromCelebornURL: a typo url") { + val e = intercept[CelebornException] { + RpcAddress.fromCelebornURL("celeborn://1.2. 3.4:1234") + } + assert("Invalid master URL: celeborn://1.2. 3.4:1234" === e.getMessage) + } + + test("fromCelebornURL: invalid scheme") { + val e = intercept[CelebornException] { + RpcAddress.fromCelebornURL("invalid://1.2.3.4:1234") + } + assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) + } + + test("toCelebornURL") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.toCelebornURL === "celeborn://1.2.3.4:1234") + } + +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala new file mode 100644 index 00000000000..4843cf744e3 --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/RpcEnvSuite.scala @@ -0,0 +1,855 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc + +import java.io.NotSerializableException +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} + +import scala.collection.JavaConverters.collectionAsScalaIterableConverter +import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ + +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, never, verify} +import org.scalatest.concurrent.Eventually._ + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.exception.CelebornException +import org.apache.celeborn.common.util.ThreadUtils + +/** + * Common tests for an RpcEnv implementation. + */ +abstract class RpcEnvSuite extends CelebornFunSuite { + + var env: RpcEnv = _ + + def createCelebornConf(): CelebornConf = { + new CelebornConf() + } + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = createCelebornConf() + env = createRpcEnv(conf, "local", 0) + + } + + override def afterAll(): Unit = { + try { + if (env != null) { + env.shutdown() + } + } finally { + super.afterAll() + } + } + + def createRpcEnv(conf: CelebornConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint( + "send-locally", + new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint( + "send-remotely", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + val newRpcEndpointRef = rpcEndpointRef.askSync[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askSync[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint( + "ask-locally", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => + context.reply(msg) + } + }) + val reply = rpcEndpointRef.askSync[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint( + "ask-remotely", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => + context.reply(msg) + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askSync[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("ask a message timeout") { + env.setupEndpoint( + "ask-timeout", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => + Thread.sleep(100) + context.reply(msg) + } + }) + + val conf = createCelebornConf() + val shortProp = "celeborn.rpc.short.timeout" + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") + try { + val e = intercept[RpcTimeoutException] { + rpcEndpointRef.askSync[String]("hello", new RpcTimeout(1.millisecond, shortProp)) + } + // The Celeborn exception cause should be a RpcTimeoutException with message indicating the + // controlling timeout property + assert(e.isInstanceOf[RpcTimeoutException]) + assert(e.getMessage.contains(shortProp)) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint( + "onError-onStart", + new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint( + "onError-onStop", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint( + "onError-receive", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(e.getMessage === "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint( + "self-onStart", + new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + }) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint( + "self-receive", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + self + callSelfSuccessfully = true + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully) + } + } + + test("self: call in onStop") { + @volatile var selfOption: Option[RpcEndpointRef] = null + + val endpointRef = env.setupEndpoint( + "self-onStop", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + selfOption = Option(self) + } + }) + + env.stop(endpointRef) + + eventually(timeout(5.seconds), interval(10.milliseconds)) { + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption.isEmpty) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for (i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupEndpoint( + s"receive-in-sequence-$i", + new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run(): Unit = { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint( + "stop-reentrant", + new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + // Calling stop twice should only trigger onStop once. + assert(onStopCount == 1) + } + } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint( + "sendWithReply", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val f = endpointRef.ask[String]("Hi") + val ack = ThreadUtils.awaitResult(f, 5.seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint( + "sendWithReply-remotely", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.ask[String]("hello") + val ack = ThreadUtils.awaitResult(f, 5.seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint( + "sendWithReply-error", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.sendFailure(new CelebornException("Oops")) + } + }) + + val f = endpointRef.ask[String]("Hi") + val e = intercept[CelebornException] { + ThreadUtils.awaitResult(f, 5.seconds) + } + assert("Oops" === e.getCause.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint( + "sendWithReply-remotely-error", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new CelebornException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.ask[String]("hello") + val e = intercept[CelebornException] { + ThreadUtils.awaitResult(f, 5.seconds) + } + assert("Oops" === e.getCause.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + /** + * Setup an [[RpcEndpoint]] to collect all network events. + * + * @return the [[RpcEndpointRef]] and a `ConcurrentLinkedQueue` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] + val ref = _env.setupEndpoint( + "network-events-non-client", + new ThreadSafeRpcEndpoint { + override val rpcEnv = _env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events.add("receive" -> m) + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events.add("onConnected" -> remoteAddress) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events.add("onDisconnected" -> remoteAddress) + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events.add("onNetworkError" -> remoteAddress) + } + + }) + (ref, events) + } + + test("network events in sever RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(createCelebornConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(createCelebornConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + assert(events.contains(("onConnected", serverEnv2.address))) + } + + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) + } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() + } + } + + test("network events in sever RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + } + + clientEnv.shutdown() + clientEnv.awaitTermination() + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } + + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(createCelebornConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(createCelebornConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + assert(events.contains(("onConnected", serverEnv.address))) + } + + serverEnv.shutdown() + serverEnv.awaitTermination() + + eventually(timeout(5.seconds), interval(5.milliseconds)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } + + test("sendWithReply: unserializable error") { + env.setupEndpoint( + "sendWithReply-unserializable-error", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0, clientMode = true) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = + anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.ask[String]("hello") + val e = intercept[CelebornException] { + ThreadUtils.awaitResult(f, 1.second) + } + assert(e.getCause.isInstanceOf[NotSerializableException]) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("port conflict") { + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", env.address.port) + try { + assert(anotherEnv.address.port != env.address.port) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + private def testSend(conf: CelebornConf): Unit = { + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) + + try { + @volatile var message: String = null + localEnv.setupEndpoint( + "send-authentication", + new RpcEndpoint { + override val rpcEnv = localEnv + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "send-authentication") + rpcEndpointRef.send("hello") + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert("hello" === message) + } + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + + private def testAsk(conf: CelebornConf): Unit = { + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) + + try { + localEnv.setupEndpoint( + "ask-authentication", + new RpcEndpoint { + override val rpcEnv = localEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => + context.reply(msg) + } + }) + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") + val reply = rpcEndpointRef.askSync[String]("hello") + assert("hello" === reply) + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + + test("construct RpcTimeout with conf property") { + val conf = new CelebornConf() + + val testProp = "celeborn.ask.test.timeout" + val testDurationSeconds = 30 + val secondaryProp = "celeborn.ask.secondary.timeout" + + conf.set(testProp, s"${testDurationSeconds}s") + conf.set(secondaryProp, "100s") + + // Construct RpcTimeout with a single property + val rt1 = RpcTimeout(conf, testProp) + assert(testDurationSeconds === rt1.duration.toSeconds) + + // Construct RpcTimeout with prioritized list of properties + val rt2 = RpcTimeout(conf, Seq("celeborn.ask.invalid.timeout", testProp, secondaryProp), "1s") + assert(testDurationSeconds === rt2.duration.toSeconds) + + // Construct RpcTimeout with default value, + val defaultProp = "celeborn.ask.default.timeout" + val defaultDurationSeconds = 1 + val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s") + assert(defaultDurationSeconds === rt3.duration.toSeconds) + assert(rt3.timeoutProp.contains(defaultProp)) + + // Try to construct RpcTimeout with an unconfigured property + intercept[NoSuchElementException] { + RpcTimeout(conf, "celeborn.ask.invalid.timeout") + } + } + + test("ask a message timeout on Future using RpcTimeout") { + case class NeverReply(msg: String) + + val rpcEndpointRef = env.setupEndpoint( + "ask-future", + new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.reply(msg) + case _: NeverReply => + } + }) + + val longTimeout = new RpcTimeout(1.second, "celeborn.rpc.long.timeout") + val shortTimeout = new RpcTimeout(10.milliseconds, "celeborn.rpc.short.timeout") + + // Ask with immediate response, should complete successfully + val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout) + val reply1 = longTimeout.awaitResult(fut1) + assert("hello" === reply1) + + // Ask with a delayed response and wait for response immediately that should timeout + val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout) + val reply2 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut2) + }.getMessage + + // RpcTimeout.awaitResult should have added the property to the TimeoutException message + assert(reply2.contains(shortTimeout.timeoutProp)) + + // Ask with delayed response and allow the Future to timeout before ThreadUtils.awaitResult + val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + + // scalastyle:off awaitresult + // Allow future to complete with failure using plain Await.result, this will return + // once the future is complete to verify addMessageIfTimeout was invoked + val reply3 = + intercept[RpcTimeoutException] { + Await.result(fut3, 2.seconds) + }.getMessage + // scalastyle:on awaitresult + + // When the future timed out, the recover callback should have used + // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message + assert(reply3.contains(shortTimeout.timeoutProp)) + + // Use RpcTimeout.awaitResult to process Future, since it has already failed with + // RpcTimeoutException, the same RpcTimeoutException should be thrown + val reply4 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut3) + }.getMessage + + // Ensure description is not in message twice after addMessageIfTimeout and awaitResult + assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) + } + + test("RpcEnv.shutdown should not fire onDisconnected events") { + env.setupEndpoint( + "test_ep_11212023", + new RpcEndpoint { + override val rpcEnv: RpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply(m) + } + }) + + val anotherEnv = createRpcEnv(createCelebornConf(), "remote", 0) + val endpoint = mock(classOf[RpcEndpoint]) + anotherEnv.setupEndpoint("test_ep_11212023", endpoint) + + val ref = anotherEnv.setupEndpointRef(env.address, "test_ep_11212023") + // Make sure the connect is set up + assert(ref.askSync[String]("hello") === "hello") + anotherEnv.shutdown() + anotherEnv.awaitTermination() + + env.stop(ref) + + verify(endpoint).onStop() + verify(endpoint, never()).onDisconnected(any()) + verify(endpoint, never()).onNetworkError(any(), any()) + } +} + +case class Register(ref: RpcEndpointRef) + +class UnserializableClass + +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala new file mode 100644 index 00000000000..9d343b03738 --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/TestRpcEndpoint.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc + +import scala.collection.mutable.ArrayBuffer + +import org.scalactic.TripleEquals +import org.scalatest.Assertions._ + +class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals { + + override val rpcEnv: RpcEnv = null + + @volatile private var receiveMessages = ArrayBuffer[Any]() + + @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]() + + @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]() + + @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]() + + @volatile private var started = false + + @volatile private var stopped = false + + override def receive: PartialFunction[Any, Unit] = { + case message: Any => receiveMessages += message + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message: Any => receiveAndReplyMessages += message + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + onConnectedMessages += remoteAddress + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + onNetworkErrorMessages += cause -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + onDisconnectedMessages += remoteAddress + } + + def numReceiveMessages: Int = receiveMessages.size + + override def onStart(): Unit = { + started = true + } + + override def onStop(): Unit = { + stopped = true + } + + def verifyStarted(): Unit = { + assert(started, "RpcEndpoint is not started") + } + + def verifyStopped(): Unit = { + assert(stopped, "RpcEndpoint is not stopped") + } + + def verifyReceiveMessages(expected: Seq[Any]): Unit = { + assert(receiveMessages === expected) + } + + def verifySingleReceiveMessage(message: Any): Unit = { + verifyReceiveMessages(List(message)) + } + + def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = { + assert(receiveAndReplyMessages === expected) + } + + def verifySingleReceiveAndReplyMessage(message: Any): Unit = { + verifyReceiveAndReplyMessages(List(message)) + } + + def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnConnectedMessages(List(remoteAddress)) + } + + def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onConnectedMessages === expected) + } + + def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = { + verifyOnDisconnectedMessages(List(remoteAddress)) + } + + def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = { + assert(onDisconnectedMessages === expected) + } + + def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = { + verifyOnNetworkErrorMessages(List(cause -> remoteAddress)) + } + + def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = { + assert(onNetworkErrorMessages === expected) + } +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala new file mode 100644 index 00000000000..a8bc826dd4b --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/InboxSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc.netty + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.rpc.{RpcAddress, TestRpcEndpoint} + +class InboxSuite extends CelebornFunSuite { + + test("post") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(rpcEnvRef, endpoint) + val message = OneWayMessage(null, "hi") + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveMessage("hi") + + inbox.stop() + inbox.process(dispatcher) + assert(inbox.isEmpty) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: with reply") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val inbox = new Inbox(rpcEnvRef, endpoint) + val message = RpcMessage(null, "hi", null) + inbox.post(message) + inbox.process(dispatcher) + assert(inbox.isEmpty) + + endpoint.verifySingleReceiveAndReplyMessage("hi") + } + + test("post: multiple threads") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val numDroppedMessages = new AtomicInteger(0) + val inbox = new Inbox(rpcEnvRef, endpoint) { + override def onDrop(message: InboxMessage): Unit = { + numDroppedMessages.incrementAndGet() + } + } + + val exitLatch = new CountDownLatch(10) + + for (_ <- 0 until 10) { + new Thread { + override def run(): Unit = { + for (_ <- 0 until 100) { + val message = OneWayMessage(null, "hi") + inbox.post(message) + } + exitLatch.countDown() + } + }.start() + } + // Try to process some messages + inbox.process(dispatcher) + inbox.stop() + // After `stop` is called, further messages will be dropped. However, while `stop` is called, + // some messages may be post to Inbox, so process them here. + inbox.process(dispatcher) + assert(inbox.isEmpty) + + exitLatch.await(30, TimeUnit.SECONDS) + + assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) + endpoint.verifyStarted() + endpoint.verifyStopped() + } + + test("post: Associated") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(rpcEnvRef, endpoint) + inbox.post(RemoteProcessConnected(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnConnectedMessage(remoteAddress) + } + + test("post: Disassociated") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + + val inbox = new Inbox(rpcEnvRef, endpoint) + inbox.post(RemoteProcessDisconnected(remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnDisconnectedMessage(remoteAddress) + } + + test("post: AssociationError") { + val endpoint = new TestRpcEndpoint + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + + val remoteAddress = RpcAddress("localhost", 11111) + val cause = new RuntimeException("Oops") + + val inbox = new Inbox(rpcEnvRef, endpoint) + inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) + inbox.process(dispatcher) + + endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress) + } + + test("should reduce the number of active threads when fatal error happens") { + val endpoint = mock(classOf[TestRpcEndpoint]) + when(endpoint.receive).thenThrow(new OutOfMemoryError()) + val rpcEnvRef = mock(classOf[NettyRpcEndpointRef]) + val dispatcher = mock(classOf[Dispatcher]) + val inbox = new Inbox(rpcEnvRef, endpoint) + inbox.post(OneWayMessage(null, "hi")) + intercept[OutOfMemoryError] { + inbox.process(dispatcher) + } + assert(inbox.getNumActiveThreads === 0) + } +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala new file mode 100644 index 00000000000..1c3fe6ef4f2 --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcAddressSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc.netty + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.rpc.RpcEndpointAddress + +class NettyRpcAddressSuite extends CelebornFunSuite { + + test("toString") { + val addr = new RpcEndpointAddress("localhost", 12345, "test") + assert(addr.toString === "celeborn://test@localhost:12345") + } + + test("toString for client mode") { + val addr = RpcEndpointAddress(null, "test") + assert(addr.toString === "celeborn-client://test") + } +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala new file mode 100644 index 00000000000..8afbf598050 --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnvSuite.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc.netty + +import java.util.concurrent.ExecutionException + +import scala.concurrent.duration._ + +import org.mockito.Mockito.mock +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} + +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.exception.CelebornException +import org.apache.celeborn.common.network.client.TransportClient +import org.apache.celeborn.common.rpc._ +import org.apache.celeborn.common.util.ThreadUtils + +class NettyRpcEnvSuite extends RpcEnvSuite with TimeLimits { + + implicit private val signaler: Signaler = ThreadSignaler + + override def createRpcEnv( + conf: CelebornConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { + val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port, 0) + new NettyRpcEnvFactory().create(config) + } + + test("non-existent endpoint") { + val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString + val e = intercept[CelebornException] { + env.setupEndpointRef(env.address, "nonexist-endpoint") + } + assert(e.getCause.isInstanceOf[RpcEndpointNotFoundException]) + assert(e.getCause.getMessage.contains(uri)) + } + + test("advertise address different from bind address") { + val celebornConf = createCelebornConf() + val config = RpcEnvConfig(celebornConf, "test", "localhost", "example.com", 0, 0) + val env = new NettyRpcEnvFactory().create(config) + try { + assert(env.address.hostPort.startsWith("example.com:")) + } finally { + env.shutdown() + } + } + + test("RequestMessage serialization") { + def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = { + assert(expected.senderAddress === actual.senderAddress) + assert(expected.receiver === actual.receiver) + assert(expected.content === actual.content) + } + + val nettyEnv = env.asInstanceOf[NettyRpcEnv] + val client = mock(classOf[TransportClient]) + val senderAddress = RpcAddress("localhost", 12345) + val receiverAddress = RpcEndpointAddress("localhost", 54321, "test") + val receiver = new NettyRpcEndpointRef(nettyEnv.celebornConf, receiverAddress, nettyEnv) + + val msg = new RequestMessage(senderAddress, receiver, "foo") + assertRequestMessageEquals( + msg, + RequestMessage(nettyEnv, client, msg.serialize(nettyEnv))) + + val msg2 = new RequestMessage(null, receiver, "foo") + assertRequestMessageEquals( + msg2, + RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv))) + + val msg3 = new RequestMessage(senderAddress, receiver, null) + assertRequestMessageEquals( + msg3, + RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv))) + } + + test("StackOverflowError should be sent back and Dispatcher should survive") { + val numUsableCores = 2 + val conf = createCelebornConf() + val config = RpcEnvConfig( + conf, + "test", + "localhost", + "localhost", + 0, + numUsableCores) + val anotherEnv = new NettyRpcEnvFactory().create(config) + anotherEnv.setupEndpoint( + "StackOverflowError", + new RpcEndpoint { + override val rpcEnv = anotherEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // scalastyle:off throwerror + case msg: String => throw new StackOverflowError + // scalastyle:on throwerror + case num: Int => context.reply(num) + } + }) + + val rpcEndpointRef = env.setupEndpointRef(anotherEnv.address, "StackOverflowError") + try { + // Send `numUsableCores` messages to trigger `numUsableCores` `StackOverflowError`s + for (_ <- 0 until numUsableCores) { + val e = intercept[CelebornException] { + rpcEndpointRef.askSync[String]("hello") + } + // The root cause `e.getCause.getCause` because it is boxed by Scala Promise. + assert(e.getCause.isInstanceOf[ExecutionException]) + assert(e.getCause.getCause.isInstanceOf[StackOverflowError]) + } + failAfter(10.seconds) { + assert(rpcEndpointRef.askSync[Int](100) === 100) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } +} diff --git a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala new file mode 100644 index 00000000000..3f6ebb0e6c5 --- /dev/null +++ b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/NettyRpcHandlerSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.rpc.netty + +import java.net.InetSocketAddress +import java.nio.ByteBuffer + +import io.netty.channel.Channel +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.network.client.{TransportClient, TransportResponseHandler} +import org.apache.celeborn.common.rpc.RpcAddress + +class NettyRpcHandlerSuite extends CelebornFunSuite { + + val env = mock(classOf[NettyRpcEnv]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) + .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null)) + + test("receive") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.channelActive(client) + + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) + } + + test("connectionTerminated") { + val dispatcher = mock(classOf[Dispatcher]) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + + val channel = mock(classOf[Channel]) + val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.channelActive(client) + + when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) + nettyRpcHandler.channelInactive(client) + + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) + verify(dispatcher, times(1)).postToAll( + RemoteProcessDisconnected(RpcAddress("localhost", 40000))) + } + +}