Skip to content

Commit

Permalink
RPC: propagate failures when they occur during deserialization.
Browse files Browse the repository at this point in the history
Before this change, a failure to deserialize an RPC reply would leave
the caller hanging because we'd never set the future.
  • Loading branch information
Mike Hearn committed Aug 24, 2018
1 parent fd8c2e4 commit 8fd4d0d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,15 @@ import org.apache.activemq.artemis.api.core.ActiveMQException
import org.apache.activemq.artemis.api.core.ActiveMQNotConnectedException
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
import org.apache.activemq.artemis.api.core.client.FailoverEventType
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Notification
import rx.Observable
import rx.subjects.UnicastSubject
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.reflect.jvm.javaMethod
Expand Down Expand Up @@ -288,56 +277,71 @@ class RPCClientProxyHandler(

// The handler for Artemis messages.
private fun artemisMessageHandler(message: ClientMessage) {
val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message)
val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME)
if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) {
log.info("Message duplication detected, discarding message")
return
fun completeExceptionally(id: InvocationId, e: Throwable, future: SettableFuture<Any?>?) {
val rpcCallSite: Throwable? = callSiteMap?.get(id)
if (rpcCallSite != null) addRpcCallSiteToThrowable(e, rpcCallSite)
future?.setException(e.cause ?: e)
}
log.debug { "Got message from RPC server $serverToClient" }
when (serverToClient) {
is RPCApi.ServerToClient.RpcReply -> {
val replyFuture = rpcReplyMap.remove(serverToClient.id)
if (replyFuture == null) {
log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
} else {
val result = serverToClient.result
when (result) {
is Try.Success -> replyFuture.set(result.value)
is Try.Failure -> {
val rpcCallSite = callSiteMap?.get(serverToClient.id)
if (rpcCallSite != null) addRpcCallSiteToThrowable(result.exception, rpcCallSite)
replyFuture.setException(result.exception)

try {
// Deserialize the reply from the server, both the wrapping metadata and the actual body of the return value.
val serverToClient: RPCApi.ServerToClient = try {
RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message)
} catch (e: RPCApi.ServerToClient.FailedToDeserializeReply) {
// Might happen if something goes wrong during mapping the response to classes, evolution, class synthesis etc.
log.error("Failed to deserialize RPC body", e)
completeExceptionally(e.id, e, rpcReplyMap.remove(e.id))
return
}
val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME)
if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) {
log.info("Message duplication detected, discarding message")
return
}
log.debug { "Got message from RPC server $serverToClient" }
when (serverToClient) {
is RPCApi.ServerToClient.RpcReply -> {
val replyFuture = rpcReplyMap.remove(serverToClient.id)
if (replyFuture == null) {
log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
} else {
val result: Try<Any?> = serverToClient.result
when (result) {
is Try.Success -> replyFuture.set(result.value)
is Try.Failure -> {
completeExceptionally(serverToClient.id, result.exception, replyFuture)
}
}
}
}
}
is RPCApi.ServerToClient.Observation -> {
val observable = observableContext.observableMap.getIfPresent(serverToClient.id)
if (observable == null) {
log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
"This may be due to an observation arriving before the server was " +
"notified of observable shutdown")
} else {
// We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
observationExecutorPool.run(serverToClient.id) { executor ->
executor.submit {
val content = serverToClient.content
if (content.isOnCompleted || content.isOnError) {
observableContext.observableMap.invalidate(serverToClient.id)
is RPCApi.ServerToClient.Observation -> {
val observable: UnicastSubject<Notification<*>>? = observableContext.observableMap.getIfPresent(serverToClient.id)
if (observable == null) {
log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
"This may be due to an observation arriving before the server was " +
"notified of observable shutdown")
} else {
// We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
observationExecutorPool.run(serverToClient.id) { executor ->
executor.submit {
val content = serverToClient.content
if (content.isOnCompleted || content.isOnError) {
observableContext.observableMap.invalidate(serverToClient.id)
}
// Add call site information on error
if (content.isOnError) {
val rpcCallSite = callSiteMap?.get(serverToClient.id)
if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
}
observable.onNext(content)
}
// Add call site information on error
if (content.isOnError) {
val rpcCallSite = callSiteMap?.get(serverToClient.id)
if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
}
observable.onNext(content)
}
}
}
}
} finally {
message.acknowledge()
}
message.acknowledge()
}

/**
Expand Down
17 changes: 15 additions & 2 deletions node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.Try
import org.apache.activemq.artemis.api.core.ActiveMQBuffer
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.management.CoreNotificationType
import org.apache.activemq.artemis.api.core.management.ManagementHelper
import org.apache.activemq.artemis.reader.MessageUtil
Expand Down Expand Up @@ -212,6 +212,11 @@ object RPCApi {
}
}

/**
* Thrown if the RPC reply body couldn't be deserialized.
*/
class FailedToDeserializeReply(val id: InvocationId, cause: Throwable) : RuntimeException("Failed to deserialize RPC reply: ${cause.message}", cause)

companion object {
private fun Any.safeSerialize(context: SerializationContext, wrap: (Throwable) -> Any) = try {
serialize(context = context)
Expand All @@ -226,10 +231,18 @@ object RPCApi {
RPCApi.ServerToClient.Tag.RPC_REPLY -> {
val id = message.invocationId(RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.")
val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id)
// The result here is a Try<> that represents the attempt to try the operation on the server side.
// If anything goes wrong with deserialisation of the response, we propagate it differently because
// we also need to pass through the invocation and dedupe IDs.
val result: Try<Any?> = try {
message.getBodyAsByteArray().deserialize(context = poolWithIdContext)
} catch (e: Exception) {
throw FailedToDeserializeReply(id, e)
}
RpcReply(
id = id,
deduplicationIdentity = deduplicationIdentity,
result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext)
result = result
)
}
RPCApi.ServerToClient.Tag.OBSERVATION -> {
Expand Down

0 comments on commit 8fd4d0d

Please sign in to comment.