Skip to content

Commit

Permalink
DefaultHttp2ConnectionEncoder#writeHeaders shouldn't send GO_AWAY if …
Browse files Browse the repository at this point in the history
…stream is closed

Motivation:
DefaultHttp2ConnectionEncoder#writeHeaders attempts to find a stream object, and if one doesn't exist it tries to create one. However in the event that the local endpoint has received a RST_STREAM frame before writing the response headers we attempt to create a stream. Since this stream ID is for the incorrect endpoint we then generate a GO_AWAY for what appears to be a protocol error, but can instead be failed locally.

Modifications:
- Just fail the local promise in the above situation instead of sending a GO_AWAY

Result:
Less severe consequences if the server asynchronously sends headers after a RST_STREAM has been received.
Fixes netty#6906.
  • Loading branch information
Scottmitch committed Jun 28, 2017
1 parent 07a6419 commit bc46a99
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ public ChannelFuture writeData(final ChannelHandlerContext ctx, final int stream
// Allowed sending DATA frames in these states.
break;
default:
throw new IllegalStateException(String.format(
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + stream.state());
}
} catch (Throwable e) {
data.release();
Expand All @@ -153,7 +152,15 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str
try {
Http2Stream stream = connection.stream(streamId);
if (stream == null) {
stream = connection.local().createStream(streamId, endOfStream);
try {
stream = connection.local().createStream(streamId, endOfStream);
} catch (Http2Exception cause) {
if (connection.remote().mayHaveCreatedStream(streamId)) {
promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause));
return promise;
}
throw cause;
}
} else {
switch (stream.state()) {
case RESERVED_LOCAL:
Expand All @@ -164,8 +171,8 @@ public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int str
// Allowed sending headers in these states.
break;
default:
throw new IllegalStateException(String.format(
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " +
stream.state());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,16 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
Expand All @@ -67,9 +71,9 @@
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -502,6 +506,71 @@ public void run() throws Http2Exception {
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}

@Test
public void headersWriteForPeerStreamWhichWasResetShouldNotGoAway() throws Exception {
bootstrapEnv(1, 1, 1, 0);

final CountDownLatch serverGotRstLatch = new CountDownLatch(1);
final CountDownLatch serverWriteHeadersLatch = new CountDownLatch(1);
final AtomicReference<Throwable> serverWriteHeadersCauseRef = new AtomicReference<Throwable>();

final Http2Headers headers = dummyHeaders();
final int streamId = 3;
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), streamId, headers, CONNECTION_STREAM_ID,
DEFAULT_PRIORITY_WEIGHT, false, 0, false, newPromise());
http2Client.encoder().writeRstStream(ctx(), streamId, Http2Error.CANCEL.code(), newPromise());
http2Client.flush(ctx());
}
});

doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
if (streamId == (Integer) invocationOnMock.getArgument(1)) {
serverGotRstLatch.countDown();
}
return null;
}
}).when(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(streamId), anyLong());

assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertTrue(serverGotRstLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(streamId), eq(headers), anyInt(),
anyShort(), anyBoolean(), anyInt(), eq(false));

// Now have the server attempt to send a headers frame simulating some asynchronous work.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeHeaders(serverCtx(), streamId, headers, 0, true, serverNewPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
serverWriteHeadersCauseRef.set(future.cause());
serverWriteHeadersLatch.countDown();
}
});
http2Server.flush(serverCtx());
}
});

assertTrue(serverWriteHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
Throwable serverWriteHeadersCause = serverWriteHeadersCauseRef.get();
assertNotNull(serverWriteHeadersCause);
assertThat(serverWriteHeadersCauseRef.get(), not(instanceOf(Http2Exception.class)));

// Server should receive a RST_STREAM for stream 3.
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}

@Test
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
bootstrapEnv(1, 1, 2, 1);
Expand Down

0 comments on commit bc46a99

Please sign in to comment.