Skip to content

Commit

Permalink
Returning a Bad Gateway to clients when receiving an indecipherable r…
Browse files Browse the repository at this point in the history
…esponse from upstream servers
  • Loading branch information
jekh committed Dec 19, 2016
1 parent e8af6a5 commit f1a86d4
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
Expand Down Expand Up @@ -482,7 +481,7 @@ boolean shouldSuppressInitialRequest() {

protected Future<?> execute() {
LOG.debug("Responding with CONNECT successful");
HttpResponse response = responseFor(HttpVersion.HTTP_1_1,
HttpResponse response = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1,
CONNECTION_ESTABLISHED);
response.headers().set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.KEEP_ALIVE);
ProxyUtils.addVia(response, proxyServer.getProxyAlias());
Expand Down Expand Up @@ -754,7 +753,7 @@ protected void exceptionCaught(Throwable cause) {
* descending ordering.
*
* Regarding the Javadoc of {@link HttpObjectAggregator} it's needed to have
* the {@link HttpResponseEncoder} or {@link HttpRequestEncoder} before the
* the {@link HttpResponseEncoder} or {@link io.netty.handler.codec.http.HttpRequestEncoder} before the
* {@link HttpObjectAggregator} in the {@link ChannelPipeline}.
*
* @param pipeline
Expand Down Expand Up @@ -999,7 +998,7 @@ private void writeAuthenticationRequired(String realm) {
+ "credentials (e.g., bad password), or your\n"
+ "browser doesn't understand how to supply\n"
+ "the credentials required.</p>\n" + "</body></html>\n";
DefaultFullHttpResponse response = responseFor(HttpVersion.HTTP_1_1,
FullHttpResponse response = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED, body);
HttpHeaders.setDate(response, new Date());
response.headers().set("Proxy-Authenticate",
Expand Down Expand Up @@ -1201,7 +1200,7 @@ private void stripHopByHopHeaders(HttpHeaders headers) {
*/
private boolean writeBadGateway(HttpRequest httpRequest) {
String body = "Bad Gateway: " + httpRequest.getUri();
DefaultFullHttpResponse response = responseFor(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_GATEWAY, body);
FullHttpResponse response = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_GATEWAY, body);

if (ProxyUtils.isHEAD(httpRequest)) {
// don't allow any body content in response to a HEAD request
Expand All @@ -1220,7 +1219,7 @@ private boolean writeBadGateway(HttpRequest httpRequest) {
*/
private boolean writeBadRequest(HttpRequest httpRequest) {
String body = "Bad Request to URI: " + httpRequest.getUri();
DefaultFullHttpResponse response = responseFor(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, body);
FullHttpResponse response = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, body);

if (ProxyUtils.isHEAD(httpRequest)) {
// don't allow any body content in response to a HEAD request
Expand All @@ -1240,7 +1239,7 @@ private boolean writeBadRequest(HttpRequest httpRequest) {
*/
private boolean writeGatewayTimeout(HttpRequest httpRequest) {
String body = "Gateway Timeout";
DefaultFullHttpResponse response = responseFor(HttpVersion.HTTP_1_1,
FullHttpResponse response = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.GATEWAY_TIMEOUT, body);

if (httpRequest != null && ProxyUtils.isHEAD(httpRequest)) {
Expand Down Expand Up @@ -1299,55 +1298,6 @@ private boolean respondWithShortCircuitResponse(HttpResponse httpResponse) {
return true;
}

/**
* Factory for {@link DefaultFullHttpResponse}s.
*
* @param httpVersion
* @param status
* @param body
* @return
*/
private DefaultFullHttpResponse responseFor(HttpVersion httpVersion,
HttpResponseStatus status, String body) {
byte[] bytes = body.getBytes(Charset.forName("UTF-8"));
ByteBuf content = Unpooled.copiedBuffer(bytes);
return responseFor(httpVersion, status, content, bytes.length);
}

/**
* Factory for {@link DefaultFullHttpResponse}s.
*
* @param httpVersion
* @param status
* @param body
* @param contentLength
* @return
*/
private DefaultFullHttpResponse responseFor(HttpVersion httpVersion,
HttpResponseStatus status, ByteBuf body, int contentLength) {
DefaultFullHttpResponse response = body != null ? new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, status, body)
: new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
if (body != null) {
response.headers().set(HttpHeaders.Names.CONTENT_LENGTH,
contentLength);
response.headers().set("Content-Type", "text/html; charset=UTF-8");
}
return response;
}

/**
* Factory for {@link DefaultFullHttpResponse}s.
*
* @param httpVersion
* @param status
* @return
*/
private DefaultFullHttpResponse responseFor(HttpVersion httpVersion,
HttpResponseStatus status) {
return responseFor(httpVersion, status, (ByteBuf) null, 0);
}

/**
* Identify the host and port for a request.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.channel.udt.nio.NioUdtProvider;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpObjectAggregator;
Expand All @@ -22,6 +24,8 @@
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.traffic.GlobalTrafficShapingHandler;
Expand Down Expand Up @@ -215,6 +219,19 @@ protected void read(Object msg) {
protected ConnectionState readHTTPInitial(HttpResponse httpResponse) {
LOG.debug("Received raw response: {}", httpResponse);

if (httpResponse.getDecoderResult().isFailure()) {
LOG.debug("Could not parse response from server. Decoder result: {}", httpResponse.getDecoderResult().toString());

// create a "substitute" Bad Gateway response from the server, since we couldn't understand what the actual
// response from the server was. set the keep-alive on the substitute response to false so the proxy closes
// the connection to the server, since we don't know what state the server thinks the connection is in.
FullHttpResponse substituteResponse = ProxyUtils.createFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.BAD_GATEWAY,
"Unable to parse response from server");
HttpHeaders.setKeepAlive(substituteResponse, false);
httpResponse = substituteResponse;
}

currentFilters.serverToProxyResponseReceiving();

rememberCurrentResponse(httpResponse);
Expand Down
62 changes: 62 additions & 0 deletions src/main/java/org/littleshoot/proxy/impl/ProxyUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.udt.nio.NioUdtProvider;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
Expand All @@ -21,6 +25,7 @@

import java.io.IOException;
import java.net.InetAddress;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -575,4 +580,61 @@ public static boolean isUdtAvailable() {
return false;
}
}

/**
* Creates a new {@link FullHttpResponse} with the specified String as the body contents (encoded using UTF-8).
*
* @param httpVersion HTTP version of the response
* @param status HTTP status code
* @param body body to include in the FullHttpResponse; will be UTF-8 encoded
* @return new http response object
*/
public static FullHttpResponse createFullHttpResponse(HttpVersion httpVersion,
HttpResponseStatus status,
String body) {
byte[] bytes = body.getBytes(StandardCharsets.UTF_8);
ByteBuf content = Unpooled.copiedBuffer(bytes);

return createFullHttpResponse(httpVersion, status, "text/html; charset=utf-8", content, bytes.length);
}

/**
* Creates a new {@link FullHttpResponse} with no body content
*
* @param httpVersion HTTP version of the response
* @param status HTTP status code
* @return new http response object
*/
public static FullHttpResponse createFullHttpResponse(HttpVersion httpVersion,
HttpResponseStatus status) {
return createFullHttpResponse(httpVersion, status, null, null, 0);
}

/**
* Creates a new {@link FullHttpResponse} with the specified body.
*
* @param httpVersion HTTP version of the response
* @param status HTTP status code
* @param contentType the Content-Type of the body
* @param body body to include in the FullHttpResponse; if null
* @param contentLength number of bytes to send in the Content-Length header; should equal the number of bytes in the ByteBuf
* @return new http response object
*/
public static FullHttpResponse createFullHttpResponse(HttpVersion httpVersion,
HttpResponseStatus status,
String contentType,
ByteBuf body,
int contentLength) {
DefaultFullHttpResponse response;

if (body != null) {
response = new DefaultFullHttpResponse(httpVersion, status, body);
response.headers().set(HttpHeaders.Names.CONTENT_LENGTH, contentLength);
response.headers().set(HttpHeaders.Names.CONTENT_TYPE, contentType);
} else {
response = new DefaultFullHttpResponse(httpVersion, status);
}

return response;
}
}

0 comments on commit f1a86d4

Please sign in to comment.