diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java index 0d8eec538..a8f31f79d 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java +++ b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java @@ -224,6 +224,7 @@ private void registerClient(final Rpc client) { @Override public void onSuccess(Void unused) { clients.remove(client); + client.unRegisterRpc(); if (!inShutdown.get()) { setupIdleTimeout(); } diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java index 868dc6dee..5fce16410 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java +++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java @@ -19,10 +19,11 @@ import java.io.Closeable; import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.LinkedList; -import java.util.Map; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; @@ -208,6 +209,7 @@ static Rpc createEmbedded(RpcDispatcher dispatcher) { dispatcher); Rpc rpc = new Rpc(new RSCConf(null), c, ImmediateEventExecutor.INSTANCE); rpc.dispatcher = dispatcher; + dispatcher.registerRpc(c, rpc); return rpc; } @@ -218,6 +220,10 @@ static Rpc createEmbedded(RpcDispatcher dispatcher) { private final EventExecutorGroup egroup; private volatile RpcDispatcher dispatcher; + private final Map, Method> handlers = new ConcurrentHashMap<>(); + private final Collection rpcCalls = new ConcurrentLinkedQueue(); + private volatile Rpc.MessageHeader lastHeader; + private Rpc(RSCConf config, Channel channel, EventExecutorGroup egroup) { Utils.checkArgument(channel != null); Utils.checkArgument(egroup != null); @@ -238,6 +244,166 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { }); } + /** + * For debugging purposes. + * @return The name of this Class. + */ + protected String name() { + return getClass().getSimpleName(); + } + + public void handleMsg(ChannelHandlerContext ctx, Object msg, Class handleClass, Object obj) + throws Exception { + if (lastHeader == null) { + if (!(msg instanceof MessageHeader)) { + LOG.warn("[{}] Expected RPC header, got {} instead.", name(), + msg != null ? msg.getClass().getName() : null); + throw new IllegalArgumentException(); + } + lastHeader = (MessageHeader) msg; + } else { + LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(), + lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null); + try { + switch (lastHeader.type) { + case CALL: + handleCall(ctx, msg, handleClass, obj); + break; + case REPLY: + handleReply(ctx, msg, findRpcCall(lastHeader.id)); + break; + case ERROR: + handleError(ctx, msg, findRpcCall(lastHeader.id)); + break; + default: + throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type); + } + } finally { + lastHeader = null; + } + } + } + + private void handleCall(ChannelHandlerContext ctx, Object msg, Class handleClass, Object obj) + throws Exception { + Method handler = handlers.get(msg.getClass()); + if (handler == null) { + // Try both getDeclaredMethod() and getMethod() so that we try both private methods + // of the class, and public methods of parent classes. + try { + handler = handleClass.getDeclaredMethod("handle", ChannelHandlerContext.class, + msg.getClass()); + } catch (NoSuchMethodException e) { + try { + handler = handleClass.getMethod("handle", ChannelHandlerContext.class, + msg.getClass()); + } catch (NoSuchMethodException e2) { + LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(), + msg.getClass().getName())); + writeMessage(MessageType.ERROR, Utils.stackTraceAsString(e.getCause())); + return; + } + } + handler.setAccessible(true); + handlers.put(msg.getClass(), handler); + } + + try { + Object payload = handler.invoke(obj, ctx, msg); + if (payload == null) { + payload = new NullMessage(); + } + writeMessage(MessageType.REPLY, payload); + } catch (InvocationTargetException ite) { + LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause()); + writeMessage(MessageType.ERROR, Utils.stackTraceAsString(ite.getCause())); + } + } + + private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) { + rpc.future.setSuccess(msg instanceof NullMessage ? null : msg); + } + + private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) { + if (msg instanceof String) { + LOG.warn("Received error message:{}.", msg); + rpc.future.setFailure(new RpcException((String) msg)); + } else { + String error = String.format("Received error with unexpected payload (%s).", + msg != null ? msg.getClass().getName() : null); + LOG.warn(String.format("[%s] %s", name(), error)); + rpc.future.setFailure(new IllegalArgumentException(error)); + ctx.close(); + } + } + + private void writeMessage(MessageType replyType, Object payload) { + channel.write(new MessageHeader(lastHeader.id, replyType)); + channel.writeAndFlush(payload); + } + + private OutstandingRpc findRpcCall(long id) { + for (Iterator it = rpcCalls.iterator(); it.hasNext();) { + OutstandingRpc rpc = it.next(); + if (rpc.id == id) { + it.remove(); + return rpc; + } + } + throw new IllegalArgumentException(String.format( + "Received RPC reply for unknown RPC (%d).", id)); + } + + private void registerRpcCall(long id, Promise promise, String type) { + LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type); + rpcCalls.add(new OutstandingRpc(id, promise)); + } + + private void discardRpcCall(long id) { + LOG.debug("[{}] Discarding failed RPC {}.", name(), id); + findRpcCall(id); + } + + private static class OutstandingRpc { + final long id; + final Promise future; + + @SuppressWarnings("unchecked") + OutstandingRpc(long id, Promise future) { + this.id = id; + this.future = (Promise) future; + } + } + + public void handleChannelException(ChannelHandlerContext ctx, Throwable cause) { + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause); + } else { + LOG.info(String.format("[%s] Caught exception in channel pipeline.", name()), cause); + } + + if (lastHeader != null) { + // There's an RPC waiting for a reply. Exception was most probably caught while processing + // the RPC, so send an error. + channel.write(new MessageHeader(lastHeader.id, MessageType.ERROR)); + channel.writeAndFlush(Utils.stackTraceAsString(cause)); + lastHeader = null; + } + + ctx.close(); + } + + public void handleChannelInactive() { + if (rpcCalls.size() > 0) { + LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcCalls.size()); + for (OutstandingRpc rpc : rpcCalls) { + rpc.future.cancel(true); + } + } else { + LOG.debug("Channel {} became inactive.", channel); + } + } + /** * Send an RPC call to the remote endpoint and returns a future that can be used to monitor the * operation. @@ -269,13 +435,13 @@ public void operationComplete(ChannelFuture cf) { if (!cf.isSuccess() && !promise.isDone()) { LOG.warn("Failed to send RPC, closing connection.", cf.cause()); promise.setFailure(cf.cause()); - dispatcher.discardRpc(id); + discardRpcCall(id); close(); } } }; - dispatcher.registerRpc(id, promise, msg.getClass().getName()); + registerRpcCall(id, promise, msg.getClass().getName()); channel.eventLoop().submit(new Runnable() { @Override public void run() { @@ -294,11 +460,18 @@ public Channel getChannel() { return channel; } + public void unRegisterRpc() { + if (dispatcher != null) { + dispatcher.unregisterRpc(channel); + } + } + void setDispatcher(RpcDispatcher dispatcher) { Utils.checkNotNull(dispatcher); Utils.checkState(this.dispatcher == null, "Dispatcher already set."); this.dispatcher = dispatcher; channel.pipeline().addLast("dispatcher", dispatcher); + dispatcher.registerRpc(channel, this); } @Override diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java index 0c149b01e..88744c24a 100644 --- a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java +++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java @@ -17,22 +17,15 @@ package org.apache.livy.rsc.rpc; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.util.Collection; -import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.util.concurrent.Promise; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.livy.rsc.Utils; - /** * An implementation of ChannelInboundHandler that dispatches incoming messages to an instance * method based on the method signature. @@ -49,10 +42,7 @@ public abstract class RpcDispatcher extends SimpleChannelInboundHandler private static final Logger LOG = LoggerFactory.getLogger(RpcDispatcher.class); - private final Map, Method> handlers = new ConcurrentHashMap<>(); - private final Collection rpcs = new ConcurrentLinkedQueue(); - - private volatile Rpc.MessageHeader lastHeader; + private final Map channelRpc = new ConcurrentHashMap<>(); /** * Override this to add a name to the dispatcher, for debugging purposes. @@ -62,161 +52,36 @@ protected String name() { return getClass().getSimpleName(); } - @Override - protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { - if (lastHeader == null) { - if (!(msg instanceof Rpc.MessageHeader)) { - LOG.warn("[{}] Expected RPC header, got {} instead.", name(), - msg != null ? msg.getClass().getName() : null); - throw new IllegalArgumentException(); - } - lastHeader = (Rpc.MessageHeader) msg; - } else { - LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(), - lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null); - try { - switch (lastHeader.type) { - case CALL: - handleCall(ctx, msg); - break; - case REPLY: - handleReply(ctx, msg, findRpc(lastHeader.id)); - break; - case ERROR: - handleError(ctx, msg, findRpc(lastHeader.id)); - break; - default: - throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type); - } - } finally { - lastHeader = null; - } - } + public void registerRpc(Channel channel, Rpc rpc) { + channelRpc.put(channel, rpc); } - private OutstandingRpc findRpc(long id) { - for (Iterator it = rpcs.iterator(); it.hasNext();) { - OutstandingRpc rpc = it.next(); - if (rpc.id == id) { - it.remove(); - return rpc; - } - } - throw new IllegalArgumentException(String.format( - "Received RPC reply for unknown RPC (%d).", id)); + public void unregisterRpc(Channel channel) { + channelRpc.remove(channel); } - private void handleCall(ChannelHandlerContext ctx, Object msg) throws Exception { - Method handler = handlers.get(msg.getClass()); - if (handler == null) { - // Try both getDeclaredMethod() and getMethod() so that we try both private methods - // of the class, and public methods of parent classes. - try { - handler = getClass().getDeclaredMethod("handle", ChannelHandlerContext.class, - msg.getClass()); - } catch (NoSuchMethodException e) { - try { - handler = getClass().getMethod("handle", ChannelHandlerContext.class, - msg.getClass()); - } catch (NoSuchMethodException e2) { - LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(), - msg.getClass().getName())); - writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(e.getCause())); - return; - } - } - handler.setAccessible(true); - handlers.put(msg.getClass(), handler); - } - - try { - Object payload = handler.invoke(this, ctx, msg); - if (payload == null) { - payload = new Rpc.NullMessage(); - } - writeMessage(ctx, Rpc.MessageType.REPLY, payload); - } catch (InvocationTargetException ite) { - LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause()); - writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(ite.getCause())); + private Rpc getRpc(ChannelHandlerContext ctx) { + Channel channel = ctx.channel(); + if (!channelRpc.containsKey(channel)) { + throw new IllegalArgumentException("not existed channel:" + channel); } - } - - private void writeMessage(ChannelHandlerContext ctx, Rpc.MessageType replyType, Object payload) { - ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, replyType)); - ctx.channel().writeAndFlush(payload); - } - private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) - throws Exception { - rpc.future.setSuccess(msg instanceof Rpc.NullMessage ? null : msg); + return channelRpc.get(channel); } - private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) - throws Exception { - if (msg instanceof String) { - LOG.warn("Received error message:{}.", msg); - rpc.future.setFailure(new RpcException((String) msg)); - } else { - String error = String.format("Received error with unexpected payload (%s).", - msg != null ? msg.getClass().getName() : null); - LOG.warn(String.format("[%s] %s", name(), error)); - rpc.future.setFailure(new IllegalArgumentException(error)); - ctx.close(); - } + @Override + protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + getRpc(ctx).handleMsg(ctx, msg, getClass(), this); } @Override public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause); - } else { - LOG.info("[{}] Closing channel due to exception in pipeline ({}).", name(), - cause.getMessage()); - } - - if (lastHeader != null) { - // There's an RPC waiting for a reply. Exception was most probably caught while processing - // the RPC, so send an error. - ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, Rpc.MessageType.ERROR)); - ctx.channel().writeAndFlush(Utils.stackTraceAsString(cause)); - lastHeader = null; - } - - ctx.close(); + getRpc(ctx).handleChannelException(ctx, cause); } @Override public final void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (rpcs.size() > 0) { - LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcs.size()); - for (OutstandingRpc rpc : rpcs) { - rpc.future.cancel(true); - } - } else { - LOG.debug("Channel {} became inactive.", ctx.channel()); - } + getRpc(ctx).handleChannelInactive(); super.channelInactive(ctx); } - - void registerRpc(long id, Promise promise, String type) { - LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type); - rpcs.add(new OutstandingRpc(id, promise)); - } - - void discardRpc(long id) { - LOG.debug("[{}] Discarding failed RPC {}.", name(), id); - findRpc(id); - } - - private static class OutstandingRpc { - final long id; - final Promise future; - - @SuppressWarnings("unchecked") - OutstandingRpc(long id, Promise future) { - this.id = id; - this.future = (Promise) future; - } - } - }