Skip to content

Commit

Permalink
DefaultHttp2ConnectionEncoder writeHeaders method always send an head…
Browse files Browse the repository at this point in the history
… frame with a priority (netty#9852)

Motivation:

The current implementation delegates to writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, short weight, boolean exclusive, int padding, boolean endStream, ChannelPromise promise) that will send an header frame with the priority flag set and the default priority values even if the user didnt want too.

Modifications:

- Change DefaultHttp2ConnectionEncoder to call the correct Http2FrameWriter method depending on if the user wants to use priorities or not
- Adjust tests

Result:

Fixes netty#9842
  • Loading branch information
normanmaurer authored Dec 8, 2019
1 parent 15b6ed9 commit e875e59
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public ChannelFuture writeData(final ChannelHandlerContext ctx, final int stream
@Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endStream, ChannelPromise promise) {
return writeHeaders(ctx, streamId, headers, 0, DEFAULT_PRIORITY_WEIGHT, false, padding, endStream, promise);
return writeHeaders0(ctx, streamId, headers, false, 0, (short) 0, false, padding, endStream, promise);
}

private static boolean validateHeadersSentState(Http2Stream stream, Http2Headers headers, boolean isServer,
Expand All @@ -164,6 +164,31 @@ private static boolean validateHeadersSentState(Http2Stream stream, Http2Headers
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final int streamDependency, final short weight,
final boolean exclusive, final int padding, final boolean endOfStream, ChannelPromise promise) {
return writeHeaders0(ctx, streamId, headers, true, streamDependency,
weight, exclusive, padding, endOfStream, promise);
}

/**
* Write headers via {@link Http2FrameWriter}. If {@code hasPriority} is {@code false} it will ignore the
* {@code streamDependency}, {@code weight} and {@code exclusive} parameters.
*/
private static ChannelFuture sendHeaders(Http2FrameWriter frameWriter, ChannelHandlerContext ctx, int streamId,
Http2Headers headers, final boolean hasPriority,
int streamDependency, final short weight,
boolean exclusive, final int padding,
boolean endOfStream, ChannelPromise promise) {
if (hasPriority) {
return frameWriter.writeHeaders(ctx, streamId, headers, streamDependency,
weight, exclusive, padding, endOfStream, promise);
}
return frameWriter.writeHeaders(ctx, streamId, headers, padding, endOfStream, promise);
}

private ChannelFuture writeHeaders0(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final boolean hasPriority,
final int streamDependency, final short weight,
final boolean exclusive, final int padding,
final boolean endOfStream, ChannelPromise promise) {
try {
Http2Stream stream = connection.stream(streamId);
if (stream == null) {
Expand Down Expand Up @@ -205,8 +230,9 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str
promise = promise.unvoid();
boolean isInformational = validateHeadersSentState(stream, headers, connection.isServer(), endOfStream);

ChannelFuture future = frameWriter.writeHeaders(ctx, streamId, headers, streamDependency,
weight, exclusive, padding, endOfStream, promise);
ChannelFuture future = sendHeaders(frameWriter, ctx, streamId, headers, hasPriority, streamDependency,
weight, exclusive, padding, endOfStream, promise);

// Writing headers may fail during the encode state if they violate HPACK limits.
Throwable failureCause = future.cause();
if (failureCause == null) {
Expand Down Expand Up @@ -236,8 +262,8 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str
} else {
// Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames.
flowController.addFlowControlled(stream,
new FlowControlledHeaders(stream, headers, streamDependency, weight, exclusive, padding,
true, promise));
new FlowControlledHeaders(stream, headers, hasPriority, streamDependency,
weight, exclusive, padding, true, promise));
return promise;
}
} catch (Throwable t) {
Expand Down Expand Up @@ -519,14 +545,17 @@ public void operationComplete(ChannelFuture future) throws Exception {
*/
private final class FlowControlledHeaders extends FlowControlledBase {
private final Http2Headers headers;
private final boolean hasPriorty;
private final int streamDependency;
private final short weight;
private final boolean exclusive;

FlowControlledHeaders(Http2Stream stream, Http2Headers headers, int streamDependency, short weight,
boolean exclusive, int padding, boolean endOfStream, ChannelPromise promise) {
FlowControlledHeaders(Http2Stream stream, Http2Headers headers, boolean hasPriority,
int streamDependency, short weight, boolean exclusive,
int padding, boolean endOfStream, ChannelPromise promise) {
super(stream, padding, endOfStream, promise.unvoid());
this.headers = headers;
this.hasPriorty = hasPriority;
this.streamDependency = streamDependency;
this.weight = weight;
this.exclusive = exclusive;
Expand All @@ -552,8 +581,8 @@ public void write(ChannelHandlerContext ctx, int allowedBytes) {
// closeStreamLocal().
promise.addListener(this);

ChannelFuture f = frameWriter.writeHeaders(ctx, stream.id(), headers, streamDependency, weight, exclusive,
padding, endOfStream, promise);
ChannelFuture f = sendHeaders(frameWriter, ctx, stream.id(), headers, hasPriorty, streamDependency,
weight, exclusive, padding, endOfStream, promise);
// Writing headers may fail during the encode state if they violate HPACK limits.
Throwable failureCause = f.cause();
if (failureCause == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,26 @@ public ChannelFuture answer(InvocationOnMock in) throws Throwable {
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = (ChannelPromise) invocationOnMock.getArguments()[8];
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
ChannelPromise promise = invocationOnMock.getArgument(8);
if (streamClosed) {
fail("Stream already closed");
} else {
streamClosed = (Boolean) invocationOnMock.getArguments()[5];
streamClosed = invocationOnMock.getArgument(5);
}
return promise.setSuccess();
}
});
when(writer.writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class),
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
ChannelPromise promise = invocationOnMock.getArgument(5);
if (streamClosed) {
fail("Stream already closed");
} else {
streamClosed = invocationOnMock.getArgument(4);
}
return promise.setSuccess();
}
Expand Down Expand Up @@ -335,12 +349,12 @@ public void emptyFrameShouldSplitPadding() throws Exception {
@Test
public void writeHeadersUsingVoidPromise() throws Exception {
final Throwable cause = new RuntimeException("fake exception");
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class),
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgument(8);
ChannelPromise promise = invocationOnMock.getArgument(5);
assertFalse(promise.isVoid());
return promise.setFailure(cause);
}
Expand All @@ -349,7 +363,7 @@ public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable
// END_STREAM flag, so that a listener is added to the future.
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newVoidPromise(channel));

verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class),
anyInt(), anyBoolean(), any(ChannelPromise.class));
// When using a void promise, the error should be propagated via the channel pipeline.
verify(pipeline).fireExceptionCaught(cause);
Expand Down Expand Up @@ -377,7 +391,7 @@ public void headersWriteForUnknownStreamShouldCreateStream() throws Exception {
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
eq(false), eq(promise));
assertTrue(promise.isSuccess());
}

Expand All @@ -390,8 +404,8 @@ public void headersWriteShouldOpenStreamForPush() throws Exception {
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
assertEquals(HALF_CLOSED_REMOTE, stream(PUSH_STREAM_ID).state());
verify(writer).writeHeaders(eq(ctx), eq(PUSH_STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer).writeHeaders(eq(ctx), eq(PUSH_STREAM_ID), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
Expand All @@ -406,8 +420,8 @@ public void trailersDoNotEndStreamThrows() {
assertTrue(future.isDone());
assertFalse(future.isSuccess());

verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
Expand All @@ -425,8 +439,8 @@ public void trailersDoNotEndStreamWithDataThrows() {
assertTrue(future.isDone());
assertFalse(future.isSuccess());

verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
Expand All @@ -452,10 +466,10 @@ private void tooManyHeadersThrows(boolean eos) {
assertTrue(future.isDone());
assertFalse(future.isSuccess());

verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise2));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(true), eq(promise2));
}

@Test
Expand Down Expand Up @@ -493,13 +507,13 @@ private void infoHeadersAndTrailers(boolean eos, int infoHeaderCount) {
assertTrue(future.isDone());
assertEquals(eos, future.isSuccess());

verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise2));
verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders),
eq(0), eq(false), any(ChannelPromise.class));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise2));
if (eos) {
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise3));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(true), eq(promise3));
}
}

Expand Down Expand Up @@ -536,10 +550,10 @@ private void tooManyHeadersWithDataThrows(boolean eos) {
assertTrue(future.isDone());
assertFalse(future.isSuccess());

verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise2));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(true), eq(promise2));
}

@Test
Expand Down Expand Up @@ -581,13 +595,13 @@ private void infoHeadersAndTrailersWithData(boolean eos, int infoHeaderCount) {
assertTrue(future.isDone());
assertEquals(eos, future.isSuccess());

verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise2));
verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders),
eq(0), eq(false), any(ChannelPromise.class));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise2));
if (eos) {
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise3));
verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(true), eq(promise3));
}
}

Expand Down Expand Up @@ -752,8 +766,7 @@ public void headersWriteShouldHalfCloseAfterOnErrorForPreCreatedStream() throws
final ChannelPromise promise = newPromise();
final Throwable ex = new RuntimeException();
// Fake an encoding error, like HPACK's HeaderListSizeException
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise)))
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(true), eq(promise)))
.thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) {
Expand All @@ -779,8 +792,7 @@ public void headersWriteShouldHalfCloseAfterOnErrorForImplicitlyCreatedStream()
final ChannelPromise promise = newPromise();
final Throwable ex = new RuntimeException();
// Fake an encoding error, like HPACK's HeaderListSizeException
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise)))
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(true), eq(promise)))
.thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) {
Expand Down Expand Up @@ -849,8 +861,8 @@ public void canWriteHeaderFrameAfterGoAwaySent() throws Exception {
goAwaySent(0);
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
Expand All @@ -868,8 +880,29 @@ public void canWriteHeaderFrameAfterGoAwayReceived() throws Http2Exception {
goAwayReceived(STREAM_ID);
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
public void headersWithNoPriority() {
writeAllFlowControlledFrames();
final int streamId = 6;
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(false), eq(promise));
}

@Test
public void headersWithPriority() {
writeAllFlowControlledFrames();
final int streamId = 6;
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 10, DEFAULT_PRIORITY_WEIGHT,
true, 1, false, promise);
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(10),
eq(DEFAULT_PRIORITY_WEIGHT), eq(true), eq(1), eq(false), eq(promise));
}

private void writeAllFlowControlledFrames() {
Expand Down
Loading

0 comments on commit e875e59

Please sign in to comment.