Skip to content

Commit

Permalink
Correctly decrement pending bytes when submitting AbstractWriteTask f…
Browse files Browse the repository at this point in the history
…ails. (netty#8349)

Motivation:

Currently we may end up in the situation that we incremented the pending bytes before submitting the AbstractWriteTask but never decrement these again if the submitting of the task fails. This may result in incorrect watermark handling.

Modifications:

- Correctly decrement pending bytes if subimitting of task fails and also ensure we recycle it correctly.
- Add unit test.

Result:

Fixes netty#8343.
  • Loading branch information
normanmaurer authored Oct 11, 2018
1 parent 0e4186c commit 652650b
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -816,13 +816,19 @@ private void write(Object msg, boolean flush, ChannelPromise promise) {
next.invokeWrite(m, promise);
}
} else {
AbstractWriteTask task;
final AbstractWriteTask task;
if (flush) {
task = WriteAndFlushTask.newInstance(next, m, promise);
} else {
task = WriteTask.newInstance(next, m, promise);
}
safeExecute(executor, task, promise, m);
if (!safeExecute(executor, task, promise, m)) {
// We failed to submit the AbstractWriteTask. We need to cancel it so we decrement the pending bytes
// and put it back in the Recycler for re-use later.
//
// See https://github.com/netty/netty/issues/8343.
task.cancel();
}
}
}

Expand Down Expand Up @@ -1002,9 +1008,10 @@ public <T> boolean hasAttr(AttributeKey<T> key) {
return channel().hasAttr(key);
}

private static void safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) {
private static boolean safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) {
try {
executor.execute(runnable);
return true;
} catch (Throwable cause) {
try {
promise.setFailure(cause);
Expand All @@ -1013,6 +1020,7 @@ private static void safeExecute(EventExecutor executor, Runnable runnable, Chann
ReferenceCountUtil.release(msg);
}
}
return false;
}
}

Expand Down Expand Up @@ -1063,20 +1071,35 @@ protected static void init(AbstractWriteTask task, AbstractChannelHandlerContext
@Override
public final void run() {
try {
// Check for null as it may be set to null if the channel is closed already
if (ESTIMATE_TASK_SIZE_ON_SUBMIT) {
ctx.pipeline.decrementPendingOutboundBytes(size);
}
decrementPendingOutboundBytes();
write(ctx, msg, promise);
} finally {
// Set to null so the GC can collect them directly
ctx = null;
msg = null;
promise = null;
handle.recycle(this);
recycle();
}
}

void cancel() {
try {
decrementPendingOutboundBytes();
} finally {
recycle();
}
}

private void decrementPendingOutboundBytes() {
if (ESTIMATE_TASK_SIZE_ON_SUBMIT) {
ctx.pipeline.decrementPendingOutboundBytes(size);
}
}

private void recycle() {
// Set to null so the GC can collect them directly
ctx = null;
msg = null;
promise = null;
handle.recycle(this);
}

protected void write(AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
ctx.invokeWrite(msg, promise);
}
Expand All @@ -1091,7 +1114,7 @@ protected WriteTask newObject(Handle<WriteTask> handle) {
}
};

private static WriteTask newInstance(
static WriteTask newInstance(
AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
WriteTask task = RECYCLER.get();
init(task, ctx, msg, promise);
Expand All @@ -1112,7 +1135,7 @@ protected WriteAndFlushTask newObject(Handle<WriteAndFlushTask> handle) {
}
};

private static WriteAndFlushTask newInstance(
static WriteAndFlushTask newInstance(
AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
WriteAndFlushTask task = RECYCLER.get();
init(task, ctx, msg, promise);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.RejectedExecutionHandlers;
import io.netty.util.concurrent.SingleThreadEventExecutor;
import org.junit.Test;

import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.RejectedExecutionException;

import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.Matchers.*;
Expand Down Expand Up @@ -355,6 +361,85 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio
safeClose(ch);
}

@Test(timeout = 5000)
public void testWriteTaskRejected() throws Exception {
final SingleThreadEventExecutor executor = new SingleThreadEventExecutor(
null, new DefaultThreadFactory("executorPool"),
true, 1, RejectedExecutionHandlers.reject()) {
@Override
protected void run() {
do {
Runnable task = takeTask();
if (task != null) {
task.run();
updateLastExecutionTime();
}
} while (!confirmShutdown());
}

@Override
protected Queue<Runnable> newTaskQueue(int maxPendingTasks) {
return super.newTaskQueue(1);
}
};
final CountDownLatch handlerAddedLatch = new CountDownLatch(1);
EmbeddedChannel ch = new EmbeddedChannel();
ch.pipeline().addLast(executor, new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
promise.setFailure(new AssertionError("Should not be called"));
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
handlerAddedLatch.countDown();
}
});

// Lets wait until we are sure the handler was added.
handlerAddedLatch.await();

final CountDownLatch executeLatch = new CountDownLatch(1);
final CountDownLatch runLatch = new CountDownLatch(1);
executor.execute(new Runnable() {
@Override
public void run() {
try {
runLatch.countDown();
executeLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
});

runLatch.await();

executor.execute(new Runnable() {
@Override
public void run() {
// Will not be executed but ensure the pending count is 1.
}
});

assertEquals(1, executor.pendingTasks());
assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes());

ByteBuf buffer = buffer(128).writeZero(128);
ChannelFuture future = ch.write(buffer);
ch.runPendingTasks();

assertTrue(future.cause() instanceof RejectedExecutionException);
assertEquals(0, buffer.refCnt());

// In case of rejected task we should not have anything pending.
assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes());
executeLatch.countDown();

safeClose(ch);
executor.shutdownGracefully();
}

private static void safeClose(EmbeddedChannel ch) {
ch.finish();
for (;;) {
Expand Down

0 comments on commit 652650b

Please sign in to comment.