Skip to content

Commit

Permalink
Add websocket encoder / decoder in correct order to the pipeline when…
Browse files Browse the repository at this point in the history
… HttpServerCodec is used (netty#9386)

Motivation:

We need to ensure we place the encoder before the decoder when doing the websockets upgrade as the decoder may produce a close frame when protocol violations are detected.

Modifications:

- Correctly place encoder before decoder
- Add unit test

Result:

Fixes netty#9300
  • Loading branch information
normanmaurer authored Jul 18, 2019
1 parent dd1785b commit 26c3abc
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
return promise;
}
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
encoderName = ctx.name();
} else {
p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
Expand All @@ -27,10 +29,15 @@
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;

import java.util.Iterator;

import static io.netty.handler.codec.http.HttpVersion.*;

public class WebSocketServerHandshaker13Test {
Expand All @@ -47,8 +54,50 @@ public void testPerformOpeningHandshakeSubProtocolNotSupported() {

private static void testPerformOpeningHandshake0(boolean subProtocol) {
EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder());
new HttpObjectAggregator(42), new HttpResponseEncoder(), new HttpRequestDecoder());

if (subProtocol) {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false));
} else {
testUpgrade0(ch, new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false));
}
Assert.assertFalse(ch.finish());
}

@Test
public void testCloseReasonWithEncoderAndDecoder() {
testCloseReason0(new HttpResponseEncoder(), new HttpRequestDecoder());
}

@Test
public void testCloseReasonWithCodec() {
testCloseReason0(new HttpServerCodec());
}

private static void testCloseReason0(ChannelHandler... handlers) {
EmbeddedChannel ch = new EmbeddedChannel(
new HttpObjectAggregator(42));
ch.pipeline().addLast(handlers);
testUpgrade0(ch, new WebSocketServerHandshaker13("ws://example.com/chat", "chat",
WebSocketDecoderConfig.newBuilder().maxFramePayloadLength(4).closeOnProtocolViolation(true).build()));

ch.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[8])));
ByteBuf buffer = ch.readOutbound();
try {
ch.writeInbound(buffer);
Assert.fail();
} catch (CorruptedWebSocketFrameException expected) {
// expected
}
ReferenceCounted closeMessage = ch.readOutbound();
Assert.assertThat(closeMessage, CoreMatchers.instanceOf(ByteBuf.class));
closeMessage.release();
Assert.assertFalse(ch.finish());
}

private static void testUpgrade0(EmbeddedChannel ch, WebSocketServerHandshaker13 handshaker) {
FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat");
req.headers().set(HttpHeaderNames.HOST, "server.example.com");
req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
Expand All @@ -58,13 +107,7 @@ private static void testPerformOpeningHandshake0(boolean subProtocol) {
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat");
req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");

if (subProtocol) {
new WebSocketServerHandshaker13(
"ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false).handshake(ch, req);
} else {
new WebSocketServerHandshaker13(
"ws://example.com/chat", null, false, Integer.MAX_VALUE, false).handshake(ch, req);
}
handshaker.handshake(ch, req);

ByteBuf resBuf = ch.readOutbound();

Expand All @@ -74,8 +117,10 @@ private static void testPerformOpeningHandshake0(boolean subProtocol) {

Assert.assertEquals(
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT));
if (subProtocol) {
Assert.assertEquals("chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
Iterator<String> subProtocols = handshaker.subprotocols().iterator();
if (subProtocols.hasNext()) {
Assert.assertEquals(subProtocols.next(),
res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
} else {
Assert.assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
}
Expand Down

0 comments on commit 26c3abc

Please sign in to comment.