Skip to content

Commit

Permalink
Merge pull request spring-cloud#439 from ryanjbaxter/readbody-cachebody
Browse files Browse the repository at this point in the history
Caches body object in ReadBodyPredicate
  • Loading branch information
Ryan Baxter authored Jul 30, 2018
2 parents 84ddb00 + 94832d9 commit 6c7168d
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import java.util.Map;
import java.util.function.Predicate;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;

import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.support.CachedBodyOutputMessage;
Expand All @@ -29,6 +32,7 @@
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;

import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;

Expand All @@ -39,8 +43,10 @@
*/
public class ReadBodyPredicateFactory
extends AbstractRoutePredicateFactory<ReadBodyPredicateFactory.Config> {
protected static final Log LOGGER = LogFactory.getLog(ReadBodyPredicateFactory.class);

private static final String TEST_ATTRIBUTE = "read_body_predicate_test_attribute";
private static final String CACHE_REQUEST_BODY_OBJECT_KEY = "cachedRequestBodyObject";
private final ServerCodecConfigurer codecConfigurer;

public ReadBodyPredicateFactory(ServerCodecConfigurer codecConfigurer) {
Expand All @@ -54,17 +60,36 @@ public AsyncPredicate<ServerWebExchange> applyAsync(Config config) {
return exchange -> {
Class inClass = config.getInClass();

ServerRequest serverRequest = new DefaultServerRequest(exchange);
// TODO: flux or mono
Mono<?> modifiedBody = serverRequest.bodyToMono(inClass)
// .log("modify_request_mono", Level.INFO)
.flatMap(body -> {
// TODO: migrate to async
boolean test = config.predicate.test(body);
exchange.getAttributes().put(TEST_ATTRIBUTE, test);
return Mono.just(body);
});

Object cachedBody = exchange.getAttribute(CACHE_REQUEST_BODY_OBJECT_KEY);
Mono<?> modifiedBody;
// We can only read the body from the request once, once that happens if we try to read the body again an
// exception will be thrown. The below if/else caches the body object as a request attribute in the ServerWebExchange
// so if this filter is run more than once (due to more than one route using it) we do not try to read the
// request body multiple times
if(cachedBody != null) {
try {
boolean test = config.predicate.test(cachedBody);
exchange.getAttributes().put(TEST_ATTRIBUTE, test);
} catch(ClassCastException e) {
if(LOGGER.isDebugEnabled()) {
LOGGER.debug("Predicate test failed because class in predicate does not match the cached body object",
e);
}
}
modifiedBody = Mono.just(cachedBody);
} else {
ServerRequest serverRequest = new DefaultServerRequest(exchange);
// TODO: flux or mono
modifiedBody = serverRequest.bodyToMono(inClass)
// .log("modify_request_mono", Level.INFO)
.flatMap(body -> {
// TODO: migrate to async
exchange.getAttributes().put(CACHE_REQUEST_BODY_OBJECT_KEY, body);
boolean test = config.predicate.test(body);
exchange.getAttributes().put(TEST_ATTRIBUTE, test);
return Mono.just(body);
});
}
BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, inClass);
CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange,
exchange.getRequest().getHeaders());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright 2013-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package org.springframework.cloud.gateway.handler.predicate;


import java.util.function.Predicate;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.cloud.gateway.test.PermitAllSecurityConfiguration;
import org.springframework.cloud.netflix.ribbon.RibbonClient;
import org.springframework.cloud.netflix.ribbon.RibbonClients;
import org.springframework.cloud.netflix.ribbon.StaticServerList;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.BodyInserters;

import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ServerList;

import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

/**
* @author Ryan Baxter
*/
@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = RANDOM_PORT)
@DirtiesContext
public class ReadBodyPredicateFactoryTest {

@Autowired
private WebTestClient webClient;

@Test
public void readBodyWorks() {

Event messageEvent = new Event("message", "bar");
Event messageChannelEvent = new Event("message.channels", "bar");

webClient
.post().uri("/events").body(BodyInserters.fromObject(messageEvent))
.exchange()
.expectStatus().isOk()
.expectBody()
.jsonPath("$.headers.Hello").isEqualTo("World");

webClient
.post().uri("/events").body(BodyInserters.fromObject(messageChannelEvent))
.exchange()
.expectStatus().isOk()
.expectBody()
.jsonPath("$.headers.World").isEqualTo("Hello");

}

@EnableAutoConfiguration
@SpringBootConfiguration
@RibbonClients({
@RibbonClient(name = "message", configuration = TestRibbonConfig.class),
@RibbonClient(name = "messageChannel", configuration = TestRibbonConfig.class)
})
@Import(PermitAllSecurityConfiguration.class)
@RestController
public static class TestConfig {
@Bean
public RouteLocator routeLocator(RouteLocatorBuilder builder) {
return builder.routes()
.route(p -> p.path("/events").and().method(HttpMethod.POST).and().
readBody(Event.class, eventPredicate("message.channels")).
filters(f -> f.setPath("/messageChannel/events")).uri("lb://messageChannel"))
.route(p -> p.path("/events").and().method(HttpMethod.POST).and().
readBody(Event.class, eventPredicate("message"))
.filters(f -> f.setPath("/message/events")).uri("lb://message"))
.build();
}

private Predicate<Event> eventPredicate(String type) {
return r -> r.getFoo().equals(type);
}

@PostMapping(path = "message/events", produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
public String messageEvents(@RequestBody Event e) {
return "{\"headers\":{\"Hello\":\"World\"}}";
}

@PostMapping(path = "messageChannel/events", produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
public String messageChannelEvents(@RequestBody Event e) {
return "{\"headers\":{\"World\":\"Hello\"}}";
}
}

protected static class TestRibbonConfig {

@LocalServerPort
protected int port = 0;

@Bean
public ServerList<Server> ribbonServerList() {
return new StaticServerList<>(new Server("localhost", this.port));
}
}

}
class Event {
private String foo;
private String bar;

public Event() {}

public Event(String foo, String bar) {
this.foo = foo;
this.bar = bar;
}

public String getFoo() {
return foo;
}

public void setFoo(String foo) {
this.foo = foo;
}

public String getBar() {
return bar;
}

public void setBar(String bar) {
this.bar = bar;
}
}

0 comments on commit 6c7168d

Please sign in to comment.