Skip to content

Commit

Permalink
CorsHandler to respect http connection (keep-alive) header.
Browse files Browse the repository at this point in the history
Motivation:

The CorsHandler currently closes the channel when it responds to a preflight (OPTIONS)
request or in the event of a short circuit due to failed validation.

Especially in an environment where there's a proxy in front of the service this causes
unnecessary connection churn.

Modifications:

CorsHandler now uses HttpUtil to determine if the connection should be closed
after responding and to set the Connection header on the response.

Result:

Channel will stay open when the CorsHandler responds unless the client specifies otherwise
or the protocol version is HTTP/1.0
  • Loading branch information
willblackie authored and normanmaurer committed Sep 6, 2016
1 parent dfa3bbb commit e3aca1f
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.handler.codec.http.cors;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
Expand All @@ -24,6 +25,7 @@
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand Down Expand Up @@ -80,7 +82,7 @@ private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest
setPreflightHeaders(response);
}
release(request);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
respond(ctx, request, response);
}

/**
Expand Down Expand Up @@ -203,8 +205,22 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann
}

private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE);
release(request);
respond(ctx, request, new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN));
}

private static void respond(
final ChannelHandlerContext ctx,
final HttpRequest request,
final HttpResponse response) {

final boolean keepAlive = HttpUtil.isKeepAlive(request);

HttpUtil.setKeepAlive(response, keepAlive);

final ChannelFuture future = ctx.writeAndFlush(response);
if (!keepAlive) {
future.addListener(ChannelFutureListener.CLOSE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;

import java.util.Arrays;
Expand All @@ -35,10 +38,13 @@
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
import static io.netty.handler.codec.http.HttpMethod.DELETE;
import static io.netty.handler.codec.http.HttpMethod.GET;
Expand Down Expand Up @@ -288,15 +294,118 @@ public void shortCurcuitNonCorsRequest() {
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
}

@Test
public void shortCurcuitWithConnectionKeepAliveShouldStayOpen() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
request.headers().set(CONNECTION, KEEP_ALIVE);

assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(true));

assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(FORBIDDEN));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void shortCurcuitWithoutConnectionShouldStayOpen() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");

assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(true));

assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(FORBIDDEN));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void shortCurcuitWithConnectionCloseShouldClose() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
request.headers().set(CONNECTION, CLOSE);

assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(false));

assertThat(channel.isOpen(), is(false));
assertThat(response.status(), is(FORBIDDEN));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void preflightRequestShouldReleaseRequest() {
final CorsConfig config = forOrigin("http://localhost:8888")
.preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
.build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1");
channel.writeInbound(request);
final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1", null);
assertThat(channel.writeInbound(request), is(false));
assertThat(request.refCnt(), is(0));
assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void preflightRequestWithConnectionKeepAliveShouldStayOpen() throws Exception {

final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", KEEP_ALIVE);
assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(true));

assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(OK));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void preflightRequestWithoutConnectionShouldStayOpen() throws Exception {

final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null);
assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(true));

assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(OK));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
public void preflightRequestWithConnectionCloseShouldClose() throws Exception {

final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", CLOSE);
assertThat(channel.writeInbound(request), is(false));
final HttpResponse response = channel.readOutbound();
assertThat(HttpUtil.isKeepAlive(response), is(false));

assertThat(channel.isOpen(), is(false));
assertThat(response.status(), is(OK));
assertThat(ReferenceCountUtil.release(response), is(true));
assertThat(channel.finish(), is(false));
}

@Test
Expand All @@ -305,8 +414,10 @@ public void forbiddenShouldReleaseRequest() {
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler());
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
channel.writeInbound(request);
assertThat(channel.writeInbound(request), is(false));
assertThat(request.refCnt(), is(0));
assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true));
assertThat(channel.finish(), is(false));
}

private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
Expand All @@ -331,23 +442,31 @@ private static HttpResponse simpleRequest(final CorsConfig config,
if (requestHeaders != null) {
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
}
channel.writeInbound(httpRequest);
assertThat(channel.writeInbound(httpRequest), is(false));
return (HttpResponse) channel.readOutbound();
}

private static HttpResponse preflightRequest(final CorsConfig config,
final String origin,
final String requestHeaders) {
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
channel.writeInbound(optionsRequest(origin, requestHeaders));
return (HttpResponse) channel.readOutbound();
assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false));
HttpResponse response = channel.readOutbound();
assertThat(channel.finish(), is(false));
return response;
}

private static FullHttpRequest optionsRequest(final String origin, final String requestHeaders) {
private static FullHttpRequest optionsRequest(final String origin,
final String requestHeaders,
final AsciiString connection) {
final FullHttpRequest httpRequest = createHttpRequest(OPTIONS);
httpRequest.headers().set(ORIGIN, origin);
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString());
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
if (connection != null) {
httpRequest.headers().set(CONNECTION, connection);
}

return httpRequest;
}

Expand Down

0 comments on commit e3aca1f

Please sign in to comment.