From 0ce254dc48fe7d60451b60a7784bb93b858cb434 Mon Sep 17 00:00:00 2001 From: Kusal Kithul-Godage Date: Tue, 20 Dec 2022 22:06:20 +1100 Subject: [PATCH] WW-5270 Rework and fix Struts filter cleanup --- .../struts2/dispatcher/PrepareOperations.java | 105 +++++++++++------- .../filter/StrutsPrepareAndExecuteFilter.java | 3 +- .../filter/StrutsPrepareFilter.java | 6 + .../dispatcher/PrepareOperationsTest.java | 40 +++++++ 4 files changed, 114 insertions(+), 40 deletions(-) diff --git a/core/src/main/java/org/apache/struts2/dispatcher/PrepareOperations.java b/core/src/main/java/org/apache/struts2/dispatcher/PrepareOperations.java index a25e372cff..d8e4bf5705 100644 --- a/core/src/main/java/org/apache/struts2/dispatcher/PrepareOperations.java +++ b/core/src/main/java/org/apache/struts2/dispatcher/PrepareOperations.java @@ -47,18 +47,43 @@ public class PrepareOperations { /** * Maintains per-request override of devMode configuration. */ - private static ThreadLocal devModeOverride = new InheritableThreadLocal<>(); + private static final ThreadLocal devModeOverride = new InheritableThreadLocal<>(); - private Dispatcher dispatcher; + private final Dispatcher dispatcher; private static final String STRUTS_ACTION_MAPPING_KEY = "struts.actionMapping"; private static final String NO_ACTION_MAPPING = "noActionMapping"; - public static final String CLEANUP_RECURSION_COUNTER = "__cleanup_recursion_counter"; + private static final String PREPARE_COUNTER = "__prepare_recursion_counter"; + private static final String WRAP_COUNTER = "__wrap_recursion_counter"; public PrepareOperations(Dispatcher dispatcher) { this.dispatcher = dispatcher; } + /** + * Should be called by {@link org.apache.struts2.dispatcher.filter.StrutsPrepareFilter} to track how many times this + * request has been filtered. + */ + public void trackRecursion(HttpServletRequest request) { + incrementRecursionCounter(request, PREPARE_COUNTER); + } + + /** + * Cleans up request. When paired with {@link #trackRecursion}, only cleans up once the first filter instance has + * completed, preventing cleanup by recursive filter calls - i.e. before the request is completely processed. + */ + public void cleanupRequest(final HttpServletRequest request) { + decrementRecursionCounter(request, PREPARE_COUNTER, () -> { + try { + dispatcher.cleanUpRequest(request); + } finally { + ActionContext.clear(); + Dispatcher.setInstance(null); + devModeOverride.remove(); + } + }); + } + /** * Creates the action context and initializes the thread local * @@ -69,12 +94,6 @@ public PrepareOperations(Dispatcher dispatcher) { */ public ActionContext createActionContext(HttpServletRequest request, HttpServletResponse response) { ActionContext ctx; - int counter = 1; - Integer oldCounter = (Integer) request.getAttribute(CLEANUP_RECURSION_COUNTER); - if (oldCounter != null) { - counter = oldCounter + 1; - } - ActionContext oldContext = ActionContext.getContext(); if (oldContext != null) { // detected existing context, so we are probably in a forward @@ -87,35 +106,9 @@ public ActionContext createActionContext(HttpServletRequest request, HttpServlet ctx = ActionContext.of(stack.getContext()).bind(); } } - request.setAttribute(CLEANUP_RECURSION_COUNTER, counter); return ctx; } - /** - * Cleans up a request of thread locals - * - * @param request servlet request - */ - public void cleanupRequest(HttpServletRequest request) { - Integer counterVal = (Integer) request.getAttribute(CLEANUP_RECURSION_COUNTER); - if (counterVal != null) { - counterVal -= 1; - request.setAttribute(CLEANUP_RECURSION_COUNTER, counterVal); - if (counterVal > 0 ) { - LOG.debug("skipping cleanup counter={}", counterVal); - return; - } - } - // always clean up the thread request, even if an action hasn't been executed - try { - dispatcher.cleanUpRequest(request); - } finally { - ActionContext.clear(); - Dispatcher.setInstance(null); - devModeOverride.remove(); - } - } - /** * Assigns the dispatcher to the dispatcher thread local */ @@ -135,14 +128,15 @@ public void setEncodingAndLocale(HttpServletRequest request, HttpServletResponse /** * Wraps the request with the Struts wrapper that handles multipart requests better + * Also tracks additional calls to this method on the same request. * - * @param oldRequest servlet request + * @param request servlet request * * @return The new request, if there is one * @throws ServletException on any servlet related error */ - public HttpServletRequest wrapRequest(HttpServletRequest oldRequest) throws ServletException { - HttpServletRequest request = oldRequest; + public HttpServletRequest wrapRequest(HttpServletRequest request) throws ServletException { + incrementRecursionCounter(request, WRAP_COUNTER); try { // Wrap request first, just in case it is multipart/form-data // parameters might not be accessible through before encoding (ww-1278) @@ -154,6 +148,14 @@ public HttpServletRequest wrapRequest(HttpServletRequest oldRequest) throws Serv return request; } + /** + * Should be called after whenever {@link #wrapRequest} is called. Ensures the request is only cleaned up at the + * instance it was initially wrapped in the case of multiple wrap calls - i.e. filter recursion. + */ + public void cleanupWrappedRequest(final HttpServletRequest request) { + decrementRecursionCounter(request, WRAP_COUNTER, () -> dispatcher.cleanUpRequest(request)); + } + /** * Finds and optionally creates an {@link ActionMapping}. It first looks in the current request to see if one * has already been found, otherwise, it creates it and stores it in the request. No mapping will be created in the @@ -260,7 +262,6 @@ public static Boolean getDevModeOverride() /** * Clear any override of the static devMode value being applied to the current thread. - * * This can be useful for any situation where {@link #overrideDevMode(boolean)} might be called * in a flow where {@link #cleanupRequest(javax.servlet.http.HttpServletRequest)} does not get called. * May be very situational (such as some unit tests), but may have other utility as well. @@ -269,4 +270,30 @@ public static void clearDevModeOverride() { devModeOverride.remove(); // Remove current thread's value, enxure next read returns it to initialValue (typically null). } + /** + * Helper method to potentially count recursive executions with a request attribute. Should be used in conjunction + * with {@link #decrementRecursionCounter}. + */ + public static void incrementRecursionCounter(HttpServletRequest request, String attributeName) { + Integer setCounter = (Integer) request.getAttribute(attributeName); + if (setCounter == null) { + setCounter = 0; + } + request.setAttribute(attributeName, ++setCounter); + } + + /** + * Helper method to count execution completions with a request attribute, and optionally execute some code + * (e.g. cleanup) once all recursive executions have completed. Should be used in conjunction with + * {@link #incrementRecursionCounter}. + */ + public static void decrementRecursionCounter(HttpServletRequest request, String attributeName, Runnable runnable) { + Integer setCounter = (Integer) request.getAttribute(attributeName); + if (setCounter != null) { + request.setAttribute(attributeName, --setCounter); + } + if ((setCounter == null || setCounter == 0) && runnable != null) { + runnable.run(); + } + } } diff --git a/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareAndExecuteFilter.java b/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareAndExecuteFilter.java index 359f5ae8ac..194ae33c47 100644 --- a/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareAndExecuteFilter.java +++ b/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareAndExecuteFilter.java @@ -118,6 +118,7 @@ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) HttpServletResponse response = (HttpServletResponse) res; try { + prepare.trackRecursion(request); String uri = RequestUtils.getUri(request); if (isRequestExcluded(request)) { LOG.trace("Request: {} is excluded from handling by Struts, passing request to other filters", uri); @@ -155,7 +156,7 @@ private void handleRequest(FilterChain chain, HttpServletRequest request, HttpSe execute.executeAction(wrappedRequest, response, mapping); } } finally { - prepare.cleanupRequest(wrappedRequest); + prepare.cleanupWrappedRequest(wrappedRequest); } } diff --git a/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareFilter.java b/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareFilter.java index e8a446cca7..9c3f434a72 100644 --- a/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareFilter.java +++ b/core/src/main/java/org/apache/struts2/dispatcher/filter/StrutsPrepareFilter.java @@ -78,7 +78,9 @@ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) HttpServletRequest request = (HttpServletRequest) req; HttpServletResponse response = (HttpServletResponse) res; + boolean didWrap = false; try { + prepare.trackRecursion(request); if (excludedPatterns != null && prepare.isUrlExcluded(request, excludedPatterns)) { request.setAttribute(REQUEST_EXCLUDED_FROM_ACTION_MAPPING, true); } else { @@ -87,10 +89,14 @@ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) prepare.createActionContext(request, response); prepare.assignDispatcherToThread(); request = prepare.wrapRequest(request); + didWrap = true; prepare.findActionMapping(request, response, true); } chain.doFilter(request, response); } finally { + if (didWrap) { + prepare.cleanupWrappedRequest(request); + } prepare.cleanupRequest(request); } } diff --git a/core/src/test/java/org/apache/struts2/dispatcher/PrepareOperationsTest.java b/core/src/test/java/org/apache/struts2/dispatcher/PrepareOperationsTest.java index 5a5e0bf304..000f970134 100644 --- a/core/src/test/java/org/apache/struts2/dispatcher/PrepareOperationsTest.java +++ b/core/src/test/java/org/apache/struts2/dispatcher/PrepareOperationsTest.java @@ -24,7 +24,9 @@ import org.apache.struts2.StrutsInternalTestCase; import org.springframework.mock.web.MockHttpServletRequest; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; +import java.util.stream.IntStream; public class PrepareOperationsTest extends StrutsInternalTestCase { public void testCreateActionContextWhenRequestHasOne() { @@ -39,4 +41,42 @@ public void testCreateActionContextWhenRequestHasOne() { assertEquals(stack.getContext(), actionContext.getContextMap()); } + + public void testRequestCleanup() { + HttpServletRequest req = new MockHttpServletRequest(); + PrepareOperations prepare = new PrepareOperations(dispatcher); + int mockedRecursions = 5; + IntStream.range(0, mockedRecursions).forEach(i -> prepare.trackRecursion(req)); + IntStream.range(0, mockedRecursions - 1).forEach(i -> prepare.cleanupRequest(req)); + + assertNotNull(ActionContext.getContext()); + assertNotNull(Dispatcher.getInstance()); + assertNotNull(ContainerHolder.get()); + + prepare.cleanupRequest(req); + + assertNull(ActionContext.getContext()); + assertNull(Dispatcher.getInstance()); + assertNull(ContainerHolder.get()); + } + + public void testWrappedRequestCleanup() { + HttpServletRequest req = new MockHttpServletRequest(); + PrepareOperations prepare = new PrepareOperations(dispatcher); + int mockedRecursions = 5; + IntStream.range(0, mockedRecursions).forEach(i -> { + try { + prepare.wrapRequest(req); + } catch (ServletException e) { + throw new RuntimeException(e); + } + }); + IntStream.range(0, mockedRecursions - 1).forEach(i -> prepare.cleanupWrappedRequest(req)); + + assertNotNull(ContainerHolder.get()); + + prepare.cleanupWrappedRequest(req); + + assertNull(ContainerHolder.get()); + } }