Skip to content

Commit

Permalink
[netty#4533] Ensure replacement of decoder is delayed after finishHan…
Browse files Browse the repository at this point in the history
…dshake() is called

Motivation:

If the user calls handshake.finishHandshake() we need to ensure that the user has the chance to setup the pipeline before any WebSocketFrames are read. Because of this we need
to delay the removal of the HttpRequestDecoder.

Modifications:

- Remove the HttpRequestDecoder via the EventLoop and so delay it which gives the user a chance to setup the pipeline after finishHandshake() completes
- Add unit test for this.

Result:

Less surpising and correct behaviour even if the http response and websocket frame are received in one read operation.
  • Loading branch information
normanmaurer committed Feb 4, 2016
1 parent ef0e053 commit 7a56294
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.StringUtil;

import java.net.URI;
Expand Down Expand Up @@ -243,7 +244,7 @@ public final void finishHandshake(Channel channel, FullHttpResponse response) {

setHandshakeComplete();

ChannelPipeline p = channel.pipeline();
final ChannelPipeline p = channel.pipeline();
// Remove decompressor from pipeline if its in use
HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
if (decompressor != null) {
Expand All @@ -263,13 +264,38 @@ public final void finishHandshake(Channel channel, FullHttpResponse response) {
throw new IllegalStateException("ChannelPipeline does not contain " +
"a HttpRequestEncoder or HttpClientCodec");
}
p.replace(ctx.name(), "ws-decoder", newWebsocketDecoder());
final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
codec.removeOutboundHandler();

p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());

// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.eventLoop().execute(new OneTimeTask() {
@Override
public void run() {
p.remove(codec);
}
});
} else {
if (p.get(HttpRequestEncoder.class) != null) {
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
p.remove(HttpRequestEncoder.class);
}
p.replace(ctx.name(),
"ws-decoder", newWebsocketDecoder());
final ChannelHandlerContext context = ctx;
p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());

// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.eventLoop().execute(new OneTimeTask() {
@Override
public void run() {
p.remove(context.handler());
}
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,27 @@
*/
package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test;

import java.net.URI;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public abstract class WebSocketClientHandshakerTest {
protected abstract WebSocketClientHandshaker newHandshaker(URI uri);
Expand All @@ -36,4 +51,105 @@ public void testRawPath() {
request.release();
}
}

@Test(timeout = 3000)
public void testHttpResponseAndFrameInSameBuffer() {
testHttpResponseAndFrameInSameBuffer(false);
}

@Test(timeout = 3000)
public void testHttpResponseAndFrameInSameBufferCodec() {
testHttpResponseAndFrameInSameBuffer(true);
}

private void testHttpResponseAndFrameInSameBuffer(boolean codec) {
String url = "ws://localhost:9999/ws";
final WebSocketClientHandshaker shaker = newHandshaker(URI.create(url));
final WebSocketClientHandshaker handshaker = new WebSocketClientHandshaker(
shaker.uri(), shaker.version(), null, EmptyHttpHeaders.INSTANCE, Integer.MAX_VALUE) {
@Override
protected FullHttpRequest newHandshakeRequest() {
return shaker.newHandshakeRequest();
}

@Override
protected void verify(FullHttpResponse response) {
// Not do any verification, so we not need to care sending the correct headers etc in the test,
// which would just make things more complicated.
}

@Override
protected WebSocketFrameDecoder newWebsocketDecoder() {
return shaker.newWebsocketDecoder();
}

@Override
protected WebSocketFrameEncoder newWebSocketEncoder() {
return shaker.newWebSocketEncoder();
}
};

byte[] data = new byte[24];
ThreadLocalRandom.current().nextBytes(data);

// Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these
// to test the actual handshaker.
WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(url, null, false);
WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(shaker.newHandshakeRequest());
EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(),
socketServerHandshaker.newWebsocketDecoder());
assertTrue(websocketChannel.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data))));

byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII);

CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
compositeByteBuf.addComponent(Unpooled.wrappedBuffer(bytes));
compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + bytes.length);
for (;;) {
ByteBuf frameBytes = websocketChannel.readOutbound();
if (frameBytes == null) {
break;
}
compositeByteBuf.addComponent(frameBytes);
compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + frameBytes.readableBytes());
}

EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE),
new SimpleChannelInboundHandler<FullHttpResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
handshaker.finishHandshake(ctx.channel(), msg);
ctx.pipeline().remove(this);
}
});
if (codec) {
ch.pipeline().addFirst(new HttpClientCodec());
} else {
ch.pipeline().addFirst(new HttpRequestEncoder(), new HttpResponseDecoder());
}
// We need to first write the request as HttpClientCodec will fail if we receive a response before a request
// was written.
shaker.handshake(ch).syncUninterruptibly();
for (;;) {
// Just consume the bytes, we are not interested in these.
ByteBuf buf = ch.readOutbound();
if (buf == null) {
break;
}
buf.release();
}
assertTrue(ch.writeInbound(compositeByteBuf));
assertTrue(ch.finish());

BinaryWebSocketFrame frame = ch.readInbound();
ByteBuf expect = Unpooled.wrappedBuffer(data);
try {
assertEquals(expect, frame.content());
assertTrue(frame.isFinalFragment());
assertEquals(0, frame.rsv());
} finally {
expect.release();
frame.release();
}
}
}

0 comments on commit 7a56294

Please sign in to comment.