Skip to content

Commit

Permalink
WebSocket enhancements
Browse files Browse the repository at this point in the history
- Refactoring and adding suggestions from Norman and Vibul.
  • Loading branch information
danbev committed Sep 9, 2012
1 parent c6436ad commit 150e8b4
Show file tree
Hide file tree
Showing 4 changed files with 506 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;

import static io.netty.handler.codec.http.HttpHeaders.isKeepAlive;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.logging.InternalLogger;
import io.netty.logging.InternalLoggerFactory;

/**
* Handles the HTTP handshake (the HTTP Upgrade request)
*/
class WebSocketServerHandshakeHandler extends ChannelInboundMessageHandlerAdapter<HttpRequest> {

private static final InternalLogger logger =
InternalLoggerFactory.getInstance(WebSocketServerHandshakeHandler.class);
private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;

public WebSocketServerHandshakeHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions;
}

@Override
public void messageReceived(final ChannelHandlerContext ctx, HttpRequest req) throws Exception {
if (req.getMethod() != GET) {
sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
return;
}

final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
getWebSocketLocation(req, websocketPath), subprotocols, allowExtensions);
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
if (handshaker == null) {
wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.channel());
} else {
try {
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
handshakeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
ctx.fireExceptionCaught(future.cause());
}
}
});
WebSocketServerProtocolHandler.setHandshaker(ctx, handshaker);
ctx.pipeline().replace(this, "WS403Responder",
WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
} catch (WebSocketHandshakeException e) {
ctx.fireExceptionCaught(e);
}
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.error("Exception Caught", cause);
ctx.close();
}

private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
ChannelFuture f = ctx.channel().write(res);
if (!isKeepAlive(req) || res.getStatus().getCode() != 200) {
f.addListener(ChannelFutureListener.CLOSE);
}
}

private static String getWebSocketLocation(HttpRequest req, String path) {
return "ws://" + req.getHeader(HttpHeaders.Names.HOST) + path;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;

import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;

/**
* Handles WebSocket control frames (Close, Ping, Pong) and data frames (Text and Binary) are passed
* to the next handler in the pipeline.
*/
public class WebSocketServerProtocolHandler extends ChannelInboundMessageHandlerAdapter<WebSocketFrame> {

private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
new AttributeKey<WebSocketServerHandshaker>(WebSocketServerHandshaker.class.getName());

private final String websocketPath;
private final String subprotocols;
private final boolean allowExtensions;

public WebSocketServerProtocolHandler(String websocketPath) {
this(websocketPath, null, false);
}

public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
this(websocketPath, subprotocols, false);
}

public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
this.websocketPath = websocketPath;
this.subprotocols = subprotocols;
this.allowExtensions = allowExtensions;
}

@Override
public void afterAdd(ChannelHandlerContext ctx) {
// Add the WebSocketHandshakeHandler before this one.
ctx.pipeline().addBefore(ctx.name(), WebSocketServerHandshakeHandler.class.getName(),
new WebSocketServerHandshakeHandler(websocketPath, subprotocols, allowExtensions));
}

@Override
public void messageReceived(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
if (frame instanceof CloseWebSocketFrame) {
WebSocketServerHandshaker handshaker = WebSocketServerProtocolHandler.getHandshaker(ctx);
handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
return;
} else if (frame instanceof PingWebSocketFrame) {
ctx.channel().write(new PongWebSocketFrame(frame.getBinaryData()));
return;
}

ctx.nextInboundMessageBuffer().add(frame);
ctx.fireInboundBufferUpdated();
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
try {
if (cause instanceof WebSocketHandshakeException) {
DefaultHttpResponse response = new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.BAD_REQUEST);
response.setContent(Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
ctx.channel().write(response);
}
} finally {
ctx.close();
}
}

static WebSocketServerHandshaker getHandshaker(ChannelHandlerContext ctx) {
return ctx.attr(HANDSHAKER_ATTR_KEY).get();
}

static void setHandshaker(ChannelHandlerContext ctx, WebSocketServerHandshaker handshaker) {
ctx.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
}

static ChannelHandler forbiddenHttpRequestResponder() {
return new ChannelInboundMessageHandlerAdapter<Object>() {
@Override
public void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof WebSocketFrame)) {
DefaultHttpResponse response = new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
ctx.channel().write(response);
} else {
ctx.nextInboundMessageBuffer().add(msg);
ctx.fireInboundBufferUpdated();
}
}
};
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version
* 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http.websocketx;

import static io.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.HttpHeaders.Names;

public class WebSocketRequestBuilder {

private HttpVersion httpVersion;
private HttpMethod method;
private String uri;
private String host;
private String upgrade;
private String connection;
private String key;
private String origin;
private WebSocketVersion version;

public WebSocketRequestBuilder httpVersion(HttpVersion httpVersion) {
this.httpVersion = httpVersion;
return this;
}

public WebSocketRequestBuilder method(HttpMethod method) {
this.method = method;
return this;
}

public WebSocketRequestBuilder uri(String uri) {
this.uri = uri;
return this;
}

public WebSocketRequestBuilder host(String host) {
this.host = host;
return this;
}

public WebSocketRequestBuilder upgrade(String upgrade) {
this.upgrade = upgrade;
return this;
}

public WebSocketRequestBuilder connection(String connection) {
this.connection = connection;
return this;
}

public WebSocketRequestBuilder key(String key) {
this.key = key;
return this;
}

public WebSocketRequestBuilder origin(String origin) {
this.origin = origin;
return this;
}

public WebSocketRequestBuilder version13() {
this.version = WebSocketVersion.V13;
return this;
}

public WebSocketRequestBuilder version8() {
this.version = WebSocketVersion.V08;
return this;
}

public WebSocketRequestBuilder version00() {
this.version = null;
return this;
}

public WebSocketRequestBuilder noVersion() {
return this;
}

public HttpRequest build() {
HttpRequest req = new DefaultHttpRequest(httpVersion, method, uri);
if (host != null) {
req.setHeader(Names.HOST, host);
}
if (upgrade != null) {
req.setHeader(Names.UPGRADE, upgrade);
}
if (connection != null) {
req.setHeader(Names.CONNECTION, connection);
}
if (key != null) {
req.setHeader(Names.SEC_WEBSOCKET_KEY, key);
}
if (origin != null) {
req.setHeader(Names.SEC_WEBSOCKET_ORIGIN, origin);
}
if (version != null) {
req.setHeader(Names.SEC_WEBSOCKET_VERSION, version.toHttpHeaderValue());
}
return req;
}

public static HttpRequest sucessful() {
return new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
.method(HttpMethod.GET)
.uri("/test")
.host("server.example.com")
.upgrade(WEBSOCKET.toLowerCase())
.key("dGhlIHNhbXBsZSBub25jZQ==")
.origin("http://example.com")
.version13()
.build();
}
}
Loading

0 comments on commit 150e8b4

Please sign in to comment.