Skip to content

Commit

Permalink
Merge branch 'fix/lb-reconstruct-uri' of git://github.com/TYsewyn/spr…
Browse files Browse the repository at this point in the history
…ing-cloud-gateway into TYsewyn-fix/lb-reconstruct-uri
spencergibb committed Jan 16, 2018
2 parents 858e706 + 09ba3ef commit f106c14
Showing 2 changed files with 100 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -26,16 +26,15 @@
import org.springframework.cloud.gateway.support.NotFoundException;
import org.springframework.core.Ordered;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.addOriginalRequestUrl;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.containsEncodedQuery;

import reactor.core.publisher.Mono;

/**
* @author Spencer Gibb
* @author Tim Ysewyn
*/
public class LoadBalancerClientFilter implements GlobalFilter, Ordered {

@@ -70,15 +69,9 @@ public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
throw new NotFoundException("Unable to find instance for " + url.getHost());
}

/*URI uri = exchange.getRequest().getURI();
URI requestUrl = loadBalancer.reconstructURI(instance, uri);*/
boolean encoded = containsEncodedQuery(url);
URI requestUrl = UriComponentsBuilder.fromUri(url)
.scheme(instance.isSecure()? "https" : "http") //TODO: support websockets
.host(instance.getHost())
.port(instance.getPort())
.build(encoded)
.toUri();
URI uri = exchange.getRequest().getURI();
URI requestUrl = loadBalancer.reconstructURI(instance, uri);

log.trace("LoadBalancerClientFilter url chosen: " + requestUrl);
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUrl);
return chain.filter(exchange);
Original file line number Diff line number Diff line change
@@ -1,47 +1,124 @@
/*
* Copyright 2013-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package org.springframework.cloud.gateway.filter;

import java.net.URI;
import java.util.Collections;
import java.util.LinkedHashSet;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.cloud.client.DefaultServiceInstance;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.LoadBalancerClient;
import org.springframework.cloud.gateway.support.NotFoundException;
import org.springframework.http.HttpMethod;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;

import reactor.core.publisher.Mono;

/**
* @author Spencer Gibb
* @author Tim Ysewyn
*/
@RunWith(MockitoJUnitRunner.class)
public class LoadBalancerClientFilterTests {

private ServerWebExchange exchange;

@Mock
private GatewayFilterChain chain;

@Mock
private LoadBalancerClient loadBalancerClient;

@InjectMocks
private LoadBalancerClientFilter loadBalancerClientFilter;

@Before
public void setup() {
exchange = MockServerWebExchange.from(MockServerHttpRequest.get("loadbalancerclient.org").build());
}

@Test
public void shouldNotFilterWhenGatewayRequestUrlIsMissing() {
loadBalancerClientFilter.filter(exchange, chain);

verify(chain).filter(exchange);
verifyNoMoreInteractions(chain);
verifyZeroInteractions(loadBalancerClient);
}

@Test
public void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb() {
URI uri = UriComponentsBuilder.fromUriString("http://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);

loadBalancerClientFilter.filter(exchange, chain);

verify(chain).filter(exchange);
verifyNoMoreInteractions(chain);
verifyZeroInteractions(loadBalancerClient);
}

@Test(expected = NotFoundException.class)
public void shouldThrowExceptionWhenNoServiceInstanceIsFound() {
URI uri = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);

loadBalancerClientFilter.filter(exchange, chain);
}

@Test
public void shouldFilter() {
URI url = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, url);

ServiceInstance serviceInstance = new DefaultServiceInstance("myservice", "localhost", 8080, true);
when(loadBalancerClient.choose("myservice")).thenReturn(serviceInstance);

URI requestUrl = UriComponentsBuilder.fromUriString("https://localhost:8080").build().toUri();
when(loadBalancerClient.reconstructURI(any(ServiceInstance.class), any(URI.class))).thenReturn(requestUrl);

loadBalancerClientFilter.filter(exchange, chain);

assertThat((LinkedHashSet<URI>)exchange.getAttribute(GATEWAY_ORIGINAL_REQUEST_URL_ATTR)).contains(url);

verify(loadBalancerClient).choose("myservice");

ArgumentCaptor<URI> urlArgumentCaptor = ArgumentCaptor.forClass(URI.class);
verify(loadBalancerClient).reconstructURI(eq(serviceInstance), urlArgumentCaptor.capture());

URI uri = urlArgumentCaptor.getValue();
assertThat(uri).isNotNull();
assertThat(uri.toString()).isEqualTo("loadbalancerclient.org");

verifyNoMoreInteractions(loadBalancerClient);

assertThat((URI)exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR)).isEqualTo(requestUrl);

verify(chain).filter(exchange);
verifyNoMoreInteractions(chain);
}


@Test
public void happyPath() {
MockServerHttpRequest request = MockServerHttpRequest
@@ -126,9 +203,9 @@ private ServerWebExchange testFilter(MockServerHttpRequest request, URI uri) {

LoadBalancerClient loadBalancerClient = mock(LoadBalancerClient.class);
when(loadBalancerClient.choose("service1")).
thenReturn(new DefaultServiceInstance("service1", "service1-host1", 8081,
thenReturn(new DefaultServiceInstance("service1", "service1-host1", 8081,
false, Collections.emptyMap()));

LoadBalancerClientFilter filter = new LoadBalancerClientFilter(loadBalancerClient);
filter.filter(exchange, filterChain);

0 comments on commit f106c14

Please sign in to comment.