Skip to content

Commit

Permalink
[netty#4754] Correctly detect websocket upgrade
Browse files Browse the repository at this point in the history
Motivation:

If the Connection header contains multiple values (which is valid) we fail to detect a websocket upgrade

Modification:

- Add new method which allows to check if a header field contains a specific value (and also respect multiple header values)
- Use this method to detect handshake

Result:

Correct detect handshake if Connection header contains multiple values (seperated by ',').
  • Loading branch information
normanmaurer committed Feb 4, 2016
1 parent a0758e7 commit 7ef6db3
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.Headers;
import io.netty.util.AsciiString;
import io.netty.util.internal.StringUtil;

import java.text.ParseException;
import java.util.Calendar;
Expand Down Expand Up @@ -1565,6 +1566,48 @@ public boolean contains(String name, String value, boolean ignoreCase) {
return false;
}

/**
* Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise.
* This also handles multiple values that are seperated with a {@code ,}.
* <p>
* If {@code ignoreCase} is {@code true} then a case insensitive compare is done on the value.
* @param name the name of the header to find
* @param value the value of the header to find
* @param ignoreCase {@code true} then a case insensitive compare is run to compare values.
* otherwise a case sensitive compare is run to compare values.
*/
public boolean containsValue(CharSequence name, CharSequence value, boolean ignoreCase) {
List<String> values = getAll(name);
if (values.isEmpty()) {
return false;
}

for (String v: values) {
if (contains(v, value, ignoreCase)) {
return true;
}
}
return false;
}

private static boolean contains(String value, CharSequence expected, boolean ignoreCase) {
String[] parts = StringUtil.split(value, ',');
if (ignoreCase) {
for (String s: parts) {
if (AsciiString.contentEqualsIgnoreCase(expected, s.trim())) {
return true;
}
}
} else {
for (String s: parts) {
if (AsciiString.contentEquals(expected, s.trim())) {
return true;
}
}
}
return false;
}

/**
* {@link Headers#get(Object)} and convert the result to a {@link String}.
* @param name the name of the header to retrieve
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ protected void verify(FullHttpResponse response) {
+ upgrade);
}

CharSequence connection = headers.get(HttpHeaderNames.CONNECTION);
if (!HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection)) {
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ connection);
+ headers.get(HttpHeaderNames.CONNECTION));
}

ByteBuf challenge = response.content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ protected void verify(FullHttpResponse response) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}

CharSequence connection = headers.get(HttpHeaderNames.CONNECTION);
if (!HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION));
}

CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ protected void verify(FullHttpResponse response) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}

CharSequence connection = headers.get(HttpHeaderNames.CONNECTION);
if (!HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION));
}

CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ protected void verify(FullHttpResponse response) {
throw new WebSocketHandshakeException("Invalid handshake response upgrade: " + upgrade);
}

CharSequence connection = headers.get(HttpHeaderNames.CONNECTION);
if (!HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(connection)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: " + connection);
if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) {
throw new WebSocketHandshakeException("Invalid handshake response connection: "
+ headers.get(HttpHeaderNames.CONNECTION));
}

CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public WebSocketServerHandshaker00(String webSocketURL, String subprotocols, int
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {

// Serve the WebSocket handshake request.
if (!HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION))
if (!req.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)
|| !HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE))) {
throw new WebSocketHandshakeException("not a WebSocket handshake request: missing upgrade");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public WebSocketClientExtensionHandler(WebSocketClientExtensionHandshaker... ext

@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade((HttpRequest) msg)) {
if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) {
HttpRequest request = (HttpRequest) msg;
String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);

Expand All @@ -83,7 +83,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg)
if (msg instanceof HttpResponse) {
HttpResponse response = (HttpResponse) msg;

if (WebSocketExtensionUtil.isWebsocketUpgrade(response)) {
if (WebSocketExtensionUtil.isWebsocketUpgrade(response.headers())) {
String extensionsHeader = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);

if (extensionsHeader != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.util.internal.StringUtil;

import java.util.ArrayList;
Expand All @@ -37,12 +37,9 @@ public final class WebSocketExtensionUtil {

private static final Pattern PARAMETER = Pattern.compile("^([^=]+)(=[\\\"]?([^\\\"]+)[\\\"]?)?$");

static boolean isWebsocketUpgrade(HttpMessage httpMessage) {
if (httpMessage == null) {
throw new NullPointerException("httpMessage");
}
return httpMessage.headers().contains(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true) &&
httpMessage.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true);
static boolean isWebsocketUpgrade(HttpHeaders headers) {
return headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true) &&
headers.contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true);
}

public static List<WebSocketExtensionData> extractExtensions(String extensionHeader) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg)
if (msg instanceof HttpRequest) {
HttpRequest request = (HttpRequest) msg;

if (WebSocketExtensionUtil.isWebsocketUpgrade(request)) {
if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);

if (extensionsHeader != null) {
Expand Down Expand Up @@ -105,7 +105,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg)
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpResponse &&
WebSocketExtensionUtil.isWebsocketUpgrade((HttpResponse) msg) && validExtensions != null) {
WebSocketExtensionUtil.isWebsocketUpgrade(((HttpResponse) msg).headers()) && validExtensions != null) {
HttpResponse response = (HttpResponse) msg;
String headerValue = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx.extensions;

import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import org.junit.Test;

import static org.junit.Assert.*;

public class WebSocketExtensionUtilTest {

@Test
public void testIsWebsocketUpgrade() {
HttpHeaders headers = new DefaultHttpHeaders();
assertFalse(WebSocketExtensionUtil.isWebsocketUpgrade(headers));

headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
assertFalse(WebSocketExtensionUtil.isWebsocketUpgrade(headers));

headers.add(HttpHeaderNames.CONNECTION, "Keep-Alive, Upgrade");
assertTrue(WebSocketExtensionUtil.isWebsocketUpgrade(headers));
}
}

0 comments on commit 7ef6db3

Please sign in to comment.