Skip to content

Commit

Permalink
Http2DefaultFrameWriter direct write instead of copy
Browse files Browse the repository at this point in the history
Motivation:
The Http2DefaultFrameWriter copies all contents into a buffer (or uses a CompositeBuffer in 1 case) and then writes that buffer to the socket. There is an opportunity to avoid the copy operations and write directly to the socket.

Modifications:
- Http2DefaultFrameWriter should avoid copy operations where possible.
- The Http2FrameWriter interface should be clarified to indicate that ByteBuf objects will be released.

Result:
Hopefully less allocation/copy leads to memory and throughput performance benefit.
  • Loading branch information
Scottmitch authored and nmittler committed Feb 6, 2015
1 parent abf7afc commit 8b5f2d7
Show file tree
Hide file tree
Showing 8 changed files with 418 additions and 197 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action;
import io.netty.util.concurrent.EventExecutor;

/**
* Constants and utility method used for encoding/decoding HTTP2 frames.
Expand All @@ -48,6 +51,18 @@ public final class Http2CodecUtil {
public static final short MAX_WEIGHT = 256;
public static final short MIN_WEIGHT = 1;

private static final int MAX_PADDING_LENGTH_LENGTH = 1;
public static final int DATA_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH;
public static final int HEADERS_FRAME_HEADER_LENGTH =
FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH + 1;
public static final int PRIORITY_FRAME_LENGTH = FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH;
public static final int RST_STREAM_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH;
public static final int PUSH_PROMISE_FRAME_HEADER_LENGTH =
FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH;
public static final int GO_AWAY_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + 2 * INT_FIELD_LENGTH;
public static final int WINDOW_UPDATE_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH;
public static final int CONTINUATION_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH;

public static final int SETTINGS_HEADER_TABLE_SIZE = 1;
public static final int SETTINGS_ENABLE_PUSH = 2;
public static final int SETTINGS_MAX_CONCURRENT_STREAMS = 3;
Expand Down Expand Up @@ -183,12 +198,128 @@ public static void writeUnsignedShort(int value, ByteBuf out) {
public static void writeFrameHeader(ByteBuf out, int payloadLength, byte type,
Http2Flags flags, int streamId) {
out.ensureWritable(FRAME_HEADER_LENGTH + payloadLength);
writeFrameHeaderInternal(out, payloadLength, type, flags, streamId);
}

static void writeFrameHeaderInternal(ByteBuf out, int payloadLength, byte type,
Http2Flags flags, int streamId) {
out.writeMedium(payloadLength);
out.writeByte(type);
out.writeByte(flags.value());
out.writeInt(streamId);
}

/**
* Provides the ability to associate the outcome of multiple {@link ChannelPromise}
* objects into a single {@link ChannelPromise} object.
*/
static class SimpleChannelPromiseAggregator extends DefaultChannelPromise {
private final ChannelPromise promise;
private int expectedCount;
private int successfulCount;
private int failureCount;
private boolean doneAllocating;

SimpleChannelPromiseAggregator(ChannelPromise promise, Channel c, EventExecutor e) {
super(c, e);
assert promise != null;
this.promise = promise;
}

/**
* Allocate a new promise which will be used to aggregate the overall success of this promise aggregator.
* @return A new promise which will be aggregated.
* {@code null} if {@link #doneAllocatingPromises()} was previously called.
*/
public ChannelPromise newPromise() {
if (doneAllocating) {
throw new IllegalStateException("Done allocating. No more promises can be allocated.");
}
++expectedCount;
return this;
}

/**
* Signify that no more {@link #newPromise()} allocations will be made.
* The aggregation can not be successful until this method is called.
* @return The promise that is the aggregation of all promises allocated with {@link #newPromise()}.
*/
public ChannelPromise doneAllocatingPromises() {
if (!doneAllocating) {
doneAllocating = true;
if (successfulCount == expectedCount) {
promise.setSuccess();
return super.setSuccess();
}
}
return this;
}

@Override
public boolean tryFailure(Throwable cause) {
if (allowNotificationEvent()) {
++failureCount;
if (failureCount == 1) {
promise.tryFailure(cause);
return super.tryFailure(cause);
}
// TODO: We break the interface a bit here.
// Multiple failure events can be processed without issue because this is an aggregation.
return true;
}
return false;
}

/**
* Fail this object if it has not already been failed.
* <p>
* This method will NOT throw an {@link IllegalStateException} if called multiple times
* because that may be expected.
*/
@Override
public ChannelPromise setFailure(Throwable cause) {
if (allowNotificationEvent()) {
++failureCount;
if (failureCount == 1) {
promise.setFailure(cause);
return super.setFailure(cause);
}
}
return this;
}

private boolean allowNotificationEvent() {
return successfulCount + failureCount < expectedCount;
}

@Override
public ChannelPromise setSuccess(Void result) {
if (allowNotificationEvent()) {
++successfulCount;
if (successfulCount == expectedCount && doneAllocating) {
promise.setSuccess(result);
return super.setSuccess(result);
}
}
return this;
}

@Override
public boolean trySuccess(Void result) {
if (allowNotificationEvent()) {
++successfulCount;
if (successfulCount == expectedCount && doneAllocating) {
promise.trySuccess(result);
return super.trySuccess(result);
}
// TODO: We break the interface a bit here.
// Multiple success events can be processed without issue because this is an aggregation.
return true;
}
return false;
}
}

/**
* Fails the given promise with the cause and then re-throws the cause.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public interface Http2DataWriter {
*
* @param ctx the context to use for writing.
* @param streamId the stream for which to send the frame.
* @param data the payload of the frame.
* @param data the payload of the frame. This will be released by this method.
* @param padding the amount of padding to be added to the end of the frame
* @param endStream indicates if this is the last frame to be sent for the stream.
* @param promise the promise for the write.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ ChannelFuture writeSettings(ChannelHandlerContext ctx, Http2Settings settings,
* @param ctx the context to use for writing.
* @param ack indicates whether this is an ack of a PING frame previously received from the
* remote endpoint.
* @param data the payload of the frame.
* @param data the payload of the frame. This will be released by this method.
* @param promise the promise for the write.
* @return the future for the write.
*/
Expand All @@ -156,7 +156,7 @@ ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int prom
* @param ctx the context to use for writing.
* @param lastStreamId the last known stream of this endpoint.
* @param errorCode the error code, if the connection was abnormally terminated.
* @param debugData application-defined debug data.
* @param debugData application-defined debug data. This will be released by this method.
* @param promise the promise for the write.
* @return the future for the write.
*/
Expand All @@ -183,7 +183,7 @@ ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, int streamId,
* @param frameType the frame type identifier.
* @param streamId the stream for which to send the frame.
* @param flags the flags to write for this frame.
* @param payload the payload to write for this frame.
* @param payload the payload to write for this frame. This will be released by this method.
* @param promise the promise for the write.
* @return the future for the write.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.ChannelPromiseAggregator;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;

/**
* Translates HTTP/1.x object writes into HTTP/2 frames.
Expand Down Expand Up @@ -65,6 +65,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
FullHttpMessage httpMsg = (FullHttpMessage) msg;
boolean hasData = httpMsg.content().isReadable();
boolean httpMsgNeedRelease = true;
SimpleChannelPromiseAggregator promiseAggregator = null;
try {
// Provide the user the opportunity to specify the streamId
int streamId = getStreamId(httpMsg.headers());
Expand All @@ -74,18 +75,20 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
Http2ConnectionEncoder encoder = encoder();

if (hasData) {
ChannelPromiseAggregator promiseAggregator = new ChannelPromiseAggregator(promise);
ChannelPromise headerPromise = ctx.newPromise();
ChannelPromise dataPromise = ctx.newPromise();
promiseAggregator.add(headerPromise, dataPromise);
encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, headerPromise);
promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise());
httpMsgNeedRelease = false;
encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, dataPromise);
encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise());
promiseAggregator.doneAllocatingPromises();
} else {
encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise);
}
} catch (Throwable t) {
promise.tryFailure(t);
if (promiseAggregator == null) {
promise.tryFailure(t);
} else {
promiseAggregator.setFailure(t);
}
} finally {
if (httpMsgNeedRelease) {
httpMsg.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
*/
package io.netty.handler.codec.http2;

import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.TooLongFrameException;
Expand All @@ -28,7 +29,6 @@
import io.netty.handler.codec.http.HttpStatusClass;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap;
import static io.netty.util.internal.ObjectUtil.*;

/**
* This adapter provides just header/data events from the HTTP message flow defined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,33 @@
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT;
import static io.netty.handler.codec.http2.Http2TestUtil.as;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.EventExecutor;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
* Integration tests for {@link DefaultHttp2FrameReader} and {@link DefaultHttp2FrameWriter}.
Expand All @@ -43,6 +54,8 @@ public class DefaultHttp2FrameIOTest {
private DefaultHttp2FrameReader reader;
private DefaultHttp2FrameWriter writer;
private ByteBufAllocator alloc;
private CountDownLatch latch;
private ByteBuf buffer;

@Mock
private ChannelHandlerContext ctx;
Expand All @@ -53,13 +66,56 @@ public class DefaultHttp2FrameIOTest {
@Mock
private ChannelPromise promise;

@Mock
private Channel channel;

@Mock
private EventExecutor executor;

@Before
public void setup() {
MockitoAnnotations.initMocks(this);

alloc = UnpooledByteBufAllocator.DEFAULT;
buffer = alloc.buffer();
latch = new CountDownLatch(1);

when(executor.inEventLoop()).thenReturn(true);
when(ctx.alloc()).thenReturn(alloc);
when(ctx.channel()).thenReturn(channel);
when(ctx.executor()).thenReturn(executor);
doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock invocation) throws Throwable {
return new DefaultChannelPromise(channel, executor);
}
}).when(ctx).newPromise();

doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock in) throws Throwable {
latch.countDown();
return promise;
}
}).when(promise).setSuccess();

doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock in) throws Throwable {
if (in.getArguments()[0] instanceof ByteBuf) {
ByteBuf tmp = (ByteBuf) in.getArguments()[0];
try {
buffer.writeBytes(tmp);
} finally {
tmp.release();
}
}
if (in.getArguments()[1] instanceof ChannelPromise) {
return ((ChannelPromise) in.getArguments()[1]).setSuccess();
}
return null;
}
}).when(ctx).write(any(), any(ChannelPromise.class));

reader = new DefaultHttp2FrameReader();
writer = new DefaultHttp2FrameWriter();
Expand Down Expand Up @@ -452,10 +508,9 @@ public void continuedPushPromiseWithPaddingShouldRoundtrip() throws Exception {
}
}

private ByteBuf captureWrite() {
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(ctx).write(captor.capture(), eq(promise));
return captor.getValue();
private ByteBuf captureWrite() throws InterruptedException {
assertTrue(latch.await(2, TimeUnit.SECONDS));
return buffer;
}

private ByteBuf dummyData() {
Expand All @@ -471,9 +526,8 @@ private static Http2Headers dummyBinaryHeaders() {
}

private static Http2Headers dummyHeaders() {
return new DefaultHttp2Headers().method(as("GET")).scheme(as("https"))
.authority(as("example.org")).path(as("/some/path"))
.add(as("accept"), as("*/*"));
return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")).authority(as("example.org"))
.path(as("/some/path")).add(as("accept"), as("*/*"));
}

private static Http2Headers largeHeaders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
* Testing the {@link HttpToHttp2ConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames
*/
public class HttpToHttp2ConnectionHandlerTest {
private static final int WAIT_TIME_SECONDS = 5;
private static final int WAIT_TIME_SECONDS = 500;

@Mock
private Http2FrameListener clientListener;
Expand Down

0 comments on commit 8b5f2d7

Please sign in to comment.