Skip to content

Commit

Permalink
Removes hop-by-hop headers.
Browse files Browse the repository at this point in the history
spencergibb committed Dec 22, 2017

Verified

This commit was signed with the committer’s verified signature.
spencergibb Spencer Gibb
1 parent e44165f commit e117906
Showing 10 changed files with 238 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import java.util.List;
import java.util.function.Consumer;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.actuate.autoconfigure.web.ManagementContextConfiguration;
import org.springframework.boot.actuate.health.Health;
@@ -34,8 +35,10 @@
import org.springframework.cloud.gateway.actuate.GatewayWebfluxEndpoint;
import org.springframework.cloud.gateway.filter.ForwardRoutingFilter;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.HttpHeadersFilter;
import org.springframework.cloud.gateway.filter.NettyRoutingFilter;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.cloud.gateway.filter.RemoveHopByHopHeadersFilter;
import org.springframework.cloud.gateway.filter.RouteToRequestUrlFilter;
import org.springframework.cloud.gateway.filter.WebsocketRoutingFilter;
import org.springframework.cloud.gateway.filter.factory.AddRequestHeaderGatewayFilterFactory;
@@ -98,6 +101,7 @@
import reactor.core.publisher.Flux;
import reactor.ipc.netty.http.client.HttpClient;
import reactor.ipc.netty.http.client.HttpClientOptions;
import reactor.ipc.netty.http.client.HttpClientRequest;
import reactor.ipc.netty.resources.PoolResources;
import rx.RxReactiveStreams;

@@ -130,8 +134,9 @@ public Consumer<? super HttpClientOptions.Builder> nettyClientOptions() {
}

@Bean
public NettyRoutingFilter routingFilter(HttpClient httpClient) {
return new NettyRoutingFilter(httpClient);
public NettyRoutingFilter routingFilter(HttpClient httpClient,
ObjectProvider<List<HttpHeadersFilter>> headersFilters) {
return new NettyRoutingFilter(httpClient, headersFilters);
}

@Bean
@@ -205,8 +210,15 @@ public SecureHeadersProperties secureHeadersProperties() {
return new SecureHeadersProperties();
}

// GlobalFilter beans
// HttpHeaderFilter beans

@Bean
public RemoveHopByHopHeadersFilter removeHopByHopHeadersFilter() {
return new RemoveHopByHopHeadersFilter();
}


// GlobalFilter beans
@Bean
public RouteToRequestUrlFilter routeToRequestUrlFilter() {
return new RouteToRequestUrlFilter();
@@ -224,8 +236,10 @@ public WebSocketService webSocketService() {
}

@Bean
public WebsocketRoutingFilter websocketRoutingFilter(WebSocketClient webSocketClient, WebSocketService webSocketService) {
return new WebsocketRoutingFilter(webSocketClient, webSocketService);
public WebsocketRoutingFilter websocketRoutingFilter(WebSocketClient webSocketClient,
WebSocketService webSocketService,
ObjectProvider<List<HttpHeadersFilter>> headersFilters) {
return new WebsocketRoutingFilter(webSocketClient, webSocketService, headersFilters);
}

/*@Bean
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.springframework.cloud.gateway.filter;

import org.springframework.http.HttpHeaders;

import java.util.List;

@FunctionalInterface
public interface HttpHeadersFilter {

HttpHeaders filter(HttpHeaders original);

static HttpHeaders filter(List<HttpHeadersFilter> filters, HttpHeaders original) {
HttpHeaders filtered = original;
if (filters != null) {
for (HttpHeadersFilter filter: filters) {
filtered = filter.filter(filtered);
}
}
return filtered;
}
}
Original file line number Diff line number Diff line change
@@ -18,7 +18,9 @@
package org.springframework.cloud.gateway.filter;

import java.net.URI;
import java.util.List;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
@@ -47,9 +49,12 @@
public class NettyRoutingFilter implements GlobalFilter, Ordered {

private final HttpClient httpClient;
private final ObjectProvider<List<HttpHeadersFilter>> headersFilters;

public NettyRoutingFilter(HttpClient httpClient) {
public NettyRoutingFilter(HttpClient httpClient,
ObjectProvider<List<HttpHeadersFilter>> headersFilters) {
this.httpClient = httpClient;
this.headersFilters = headersFilters;
}

@Override
@@ -72,15 +77,22 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
final HttpMethod method = HttpMethod.valueOf(request.getMethod().toString());
final String url = requestUrl.toString();

HttpHeaders filtered = HttpHeadersFilter.filter(this.headersFilters.getIfAvailable(),
request.getHeaders());

final DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders();
request.getHeaders().forEach(httpHeaders::set);
filtered.forEach(httpHeaders::set);

String transferEncoding = request.getHeaders().getFirst(HttpHeaders.TRANSFER_ENCODING);
boolean chunkedTransfer = "chunked".equalsIgnoreCase(transferEncoding);

boolean preserveHost = exchange.getAttributeOrDefault(PRESERVE_HOST_HEADER_ATTRIBUTE, false);

return this.httpClient.request(method, url, req -> {
final HttpClientRequest proxyRequest = req.options(NettyPipeline.SendOptions::flushOnEach)
.failOnClientError(false)
.headers(httpHeaders);
.headers(httpHeaders)
.chunkedTransfer(chunkedTransfer)
.failOnClientError(false);

if (preserveHost) {
String host = request.getHeaders().getFirst(HttpHeaders.HOST);
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
// until the WebHandler is run
return chain.filter(exchange).then(Mono.defer(() -> {
HttpClientResponse clientResponse = exchange.getAttribute(CLIENT_RESPONSE_ATTR);
// HttpClientResponse clientResponse = getAttribute(exchange, CLIENT_RESPONSE_ATTR, HttpClientResponse.class);

if (clientResponse == null) {
return Mono.empty();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.springframework.cloud.gateway.filter;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;

public class RemoveHopByHopHeadersFilter implements HttpHeadersFilter, Ordered {

public static final Set<String> HEADERS_REMOVED_ON_REQUEST =
new HashSet<>(Arrays.asList(
"connection",
"keep-alive",
"transfer-encoding",
"te",
"trailer",
"proxy-authorization",
"proxy-authenticate",
"x-application-context",
"upgrade"
// these two are not listed in https://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-7.1.3
//"proxy-connection",
// "content-length",
));

@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}

@Override
public HttpHeaders filter(HttpHeaders original) {
HttpHeaders filtered = new HttpHeaders();
List<String> connection = original.getConnection();
Set<String> toFilter = new HashSet<>(connection);
toFilter.addAll(HEADERS_REMOVED_ON_REQUEST);

original.entrySet().stream()
.filter(entry -> !toFilter.contains(entry.getKey().toLowerCase()))
.forEach(entry -> filtered.addAll(entry.getKey(), entry.getValue()));

return filtered;
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package org.springframework.cloud.gateway.filter;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.reactive.socket.server.support.HandshakeWebSocketService;
import org.springframework.web.server.ServerWebExchange;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
@@ -29,15 +29,14 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered {

private final WebSocketClient webSocketClient;
private final WebSocketService webSocketService;

public WebsocketRoutingFilter(WebSocketClient webSocketClient) {
this(webSocketClient, new HandshakeWebSocketService());
}
private final ObjectProvider<List<HttpHeadersFilter>> headersFilters;

public WebsocketRoutingFilter(WebSocketClient webSocketClient,
WebSocketService webSocketService) {
WebSocketService webSocketService,
ObjectProvider<List<HttpHeadersFilter>> headersFilters) {
this.webSocketClient = webSocketClient;
this.webSocketService = webSocketService;
this.headersFilters = headersFilters;
}

@Override
@@ -55,8 +54,33 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
}
setAlreadyRouted(exchange);


HttpHeaders headers = exchange.getRequest().getHeaders();
HttpHeaders filtered = HttpHeadersFilter.filter(getHeadersFilters(),
headers);

List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL);

return this.webSocketService.handleRequest(exchange,
new ProxyWebSocketHandler(requestUrl, this.webSocketClient, exchange.getRequest().getHeaders()));
new ProxyWebSocketHandler(requestUrl, this.webSocketClient,
filtered, protocols));
}

private List<HttpHeadersFilter> getHeadersFilters() {
List<HttpHeadersFilter> filters = this.headersFilters.getIfAvailable();
if (filters == null) {
filters = new ArrayList<>();
}

filters.add(original -> {
HttpHeaders filtered = new HttpHeaders();
original.entrySet().stream()
.filter(entry -> !entry.getKey().toLowerCase().startsWith("sec-websocket"))
.forEach(header -> filtered.addAll(header.getKey(), header.getValue()));
return filtered;
});

return filters;
}

private static class ProxyWebSocketHandler implements WebSocketHandler {
@@ -66,19 +90,10 @@ private static class ProxyWebSocketHandler implements WebSocketHandler {
private final HttpHeaders headers;
private final List<String> subProtocols;

public ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers) {
public ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers, List<String> protocols) {
this.client = client;
this.url = url;
this.headers = new HttpHeaders();//headers;
//TODO: better strategy to filter these headers?
headers.entrySet().forEach(header -> {
if (!header.getKey().toLowerCase().startsWith("sec-websocket")
&& !header.getKey().equalsIgnoreCase("upgrade")
&& !header.getKey().equalsIgnoreCase("connection")) {
this.headers.addAll(header.getKey(), header.getValue());
}
});
List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL);
this.headers = headers;
if (protocols != null) {
this.subProtocols = protocols;
} else {
Original file line number Diff line number Diff line change
@@ -38,4 +38,23 @@ public GatewayFilter apply() {
return chain.filter(exchange);
};
}

/*public static class RequestMutator implements ProxyRequestMutator<HttpClientRequest> {
@Override
public void mutate(ServerWebExchange exchange, HttpClientRequest request) {
boolean preserveHost = exchange.getAttributeOrDefault(PRESERVE_HOST_HEADER_ATTRIBUTE, false);
if (preserveHost) {
String host = exchange.getRequest().getHeaders().getFirst(HttpHeaders.HOST);
if (StringUtils.isEmpty(host)) {
List<String> hosts = exchange.getAttribute(ORIGINAL_HOST_HEADER_ATTRIBUTE);
if (!CollectionUtils.isEmpty(hosts)) {
host = hosts.get(0);
}
}
if (!StringUtils.isEmpty(host)) {
request.header(HttpHeaders.HOST, host);
}
}
}
}*/
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2013-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package org.springframework.cloud.gateway.filter;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.cloud.gateway.filter.RemoveHopByHopHeadersFilter.HEADERS_REMOVED_ON_REQUEST;

/**
* @author Spencer Gibb
*/
public class RemoveHopByHopHeadersFilterTests {

@Test
public void happyPath() {
MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
.get("http://localhost/get");

HEADERS_REMOVED_ON_REQUEST.forEach(header -> builder.header(header, header+"1"));

testFilter(builder.build());
}

@Test
public void caseInsensitive() {
MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
.get("http://localhost/get");

HEADERS_REMOVED_ON_REQUEST.forEach(header -> builder.header(header.toLowerCase(), header+"1"));

testFilter(builder.build());
}

@Test
public void removesHeadersListedInConnectionHeader() {
MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
.get("http://localhost/get");

builder.header(HttpHeaders.CONNECTION, "upgrade", "keep-alive");
builder.header(HttpHeaders.UPGRADE, "WebSocket");
builder.header("Keep-Alive", "timeout:5");

testFilter(builder.build(), "upgrade", "keep-alive");
}

private void testFilter(MockServerHttpRequest request, String... additionalHeaders) {
RemoveHopByHopHeadersFilter filter = new RemoveHopByHopHeadersFilter();
HttpHeaders headers = filter.filter(request.getHeaders());

Set<String> toRemove = new HashSet<>(HEADERS_REMOVED_ON_REQUEST);
toRemove.addAll(Arrays.asList(additionalHeaders));
assertThat(headers).doesNotContainKeys(toRemove.toArray(new String[0]));
}
}
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
public class PreserveHostHeaderGatewayFilterFactoryTests extends BaseWebClientTests {

@Test
public void setRequestHeaderFilterWorks() {
public void preserveHostHeaderGatewayFilterFactoryWorks() {
testClient.get().uri("/headers")
.header("Host", "www.preservehostheader.org")
.exchange()
Original file line number Diff line number Diff line change
@@ -24,9 +24,11 @@
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
import org.junit.runners.model.Statement;
import org.springframework.cloud.gateway.filter.RemoveHopByHopHeadersFilterTests;
import org.springframework.cloud.gateway.filter.factory.AddRequestHeaderGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.AddRequestParameterGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.HystrixGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.PreserveHostHeaderGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.RedirectToGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.RemoveNonProxyHeadersGatewayFilterFactoryTests;
import org.springframework.cloud.gateway.filter.factory.RemoveRequestHeaderGatewayFilterFactoryTests;
@@ -67,6 +69,7 @@
PostTests.class,
ForwardTests.class,
WebSocketIntegrationTests.class,
RemoveHopByHopHeadersFilterTests.class,
// FilterFactory Tests
RemoveNonProxyHeadersGatewayFilterFactoryTests.class,
RemoveResponseHeaderGatewayFilterFactoryTests.class,
@@ -86,6 +89,7 @@
PrincipalNameKeyResolverIntegrationTests.class,
RedisRateLimiterTests.class,
RouteDefinitionRouteLocatorTests.class,
PreserveHostHeaderGatewayFilterFactoryTests.class,
// PredicateFactory Tests
MethodRoutePredicateFactoryTests.class,
HostRoutePredicateFactoryTests.class,

0 comments on commit e117906

Please sign in to comment.