diff --git a/web/src/main/java/org/apache/shiro/web/env/EnvironmentLoader.java b/web/src/main/java/org/apache/shiro/web/env/EnvironmentLoader.java index 4b6698d113..8ea47dc187 100644 --- a/web/src/main/java/org/apache/shiro/web/env/EnvironmentLoader.java +++ b/web/src/main/java/org/apache/shiro/web/env/EnvironmentLoader.java @@ -195,7 +195,7 @@ protected WebEnvironment createEnvironment(ServletContext sc) { Class clazz = determineWebEnvironmentClass(sc); if (!MutableWebEnvironment.class.isAssignableFrom(clazz)) { throw new ConfigurationException("Custom WebEnvironment class [" + clazz.getName() + - "] is not of required type [" + WebEnvironment.class.getName() + "]"); + "] is not of required type [" + MutableWebEnvironment.class.getName() + "]"); } String configLocations = sc.getInitParameter(CONFIG_LOCATIONS_PARAM); @@ -223,6 +223,11 @@ protected WebEnvironment createEnvironment(ServletContext sc) { return environment; } + /** + * Any additional customization of the Environment can be by overriding this method. For example setup shared + * resources, etc. By default this method does nothing. + * @param environment + */ protected void customizeEnvironment(WebEnvironment environment) { } @@ -235,9 +240,21 @@ public void destroyEnvironment(ServletContext servletContext) { servletContext.log("Cleaning up Shiro Environment"); try { Object environment = servletContext.getAttribute(ENVIRONMENT_ATTRIBUTE_KEY); + if (environment instanceof WebEnvironment) { + finalizeEnvironment((WebEnvironment) environment); + } LifecycleUtils.destroy(environment); } finally { servletContext.removeAttribute(ENVIRONMENT_ATTRIBUTE_KEY); } } + + /** + * Any additional cleanup of the Environment can be done by overriding this method. For example clean up shared + * resources, etc. By default this method does nothing. + * @param environment + * @since 1.3 + */ + protected void finalizeEnvironment(WebEnvironment environment) { + } } diff --git a/web/src/test/groovy/org/apache/shiro/web/env/EnvironmentLoaderTest.groovy b/web/src/test/groovy/org/apache/shiro/web/env/EnvironmentLoaderTest.groovy new file mode 100644 index 0000000000..c2b1c5c489 --- /dev/null +++ b/web/src/test/groovy/org/apache/shiro/web/env/EnvironmentLoaderTest.groovy @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.shiro.web.env + +import org.easymock.Capture +import org.easymock.IAnswer + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.easymock.EasyMock.*; +import static org.junit.Assert.*; +import org.junit.Test + +import javax.servlet.ServletContext + +/** + * Unit tests for the {@link EnvironmentLoaderTest} implementation. + * + * @since 1.3 + */ +class EnvironmentLoaderTest { + + @Test + void testCustomizeAndFinalizeEnvironment() { + + final AtomicInteger customizeEnvironmentCalledTimes = new AtomicInteger(0); + final AtomicInteger finalizeEnvironmentCalledTimes = new AtomicInteger(0); + + EnvironmentLoader environmentLoader = new EnvironmentLoader() { + + // EasyMock supports partial mocks, and this should not be necessary, but I could not get the .times() + // to work correctly. + @Override + protected void customizeEnvironment(WebEnvironment environment) { + customizeEnvironmentCalledTimes.getAndIncrement(); + } + + @Override + protected void finalizeEnvironment(WebEnvironment environment) { + finalizeEnvironmentCalledTimes.getAndIncrement(); + } + }; + + ServletContext servletContext = createNiceMock(ServletContext.class); + Capture environmentObjectCapture = new Capture(); + // This class is loaded via ClassUtils.newInstance() + expect(servletContext.getInitParameter(EnvironmentLoader.ENVIRONMENT_CLASS_PARAM)).andReturn(MockWebEnvironment.class.getName()); + servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), capture(environmentObjectCapture)); + expect(servletContext.getAttribute(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY)).andReturn(null); // the first time it will be null + // after that use what was passed to the setAttribute method + expect(servletContext.getAttribute(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY)).andAnswer(new IAnswer() { + @Override + Object answer() throws Throwable { + return environmentObjectCapture.getValue(); + } + }) + + replay(servletContext); + + // initEnvironment calls customizeEnvironment + environmentLoader.initEnvironment(servletContext); + assertEquals(1, customizeEnvironmentCalledTimes.get()) + assertEquals(0, finalizeEnvironmentCalledTimes.get()) + + // destroyEnvironment calls finalizeEnvironment + environmentLoader.destroyEnvironment(servletContext); + assertEquals(1, customizeEnvironmentCalledTimes.get()) + assertEquals(1, finalizeEnvironmentCalledTimes.get()) + + + } +} diff --git a/web/src/test/groovy/org/apache/shiro/web/env/MockWebEnvironment.groovy b/web/src/test/groovy/org/apache/shiro/web/env/MockWebEnvironment.groovy new file mode 100644 index 0000000000..4a2cf9e19e --- /dev/null +++ b/web/src/test/groovy/org/apache/shiro/web/env/MockWebEnvironment.groovy @@ -0,0 +1,48 @@ +package org.apache.shiro.web.env + +import org.apache.shiro.mgt.SecurityManager +import org.apache.shiro.web.filter.mgt.FilterChainResolver +import org.apache.shiro.web.mgt.WebSecurityManager + +import javax.servlet.ServletContext + +/** + * Mock WebEnvironment, replaces IniWebEnvironment in EnvironmentLoader tests, to avoid extra dependencies. + */ +class MockWebEnvironment implements MutableWebEnvironment { + + @Override + void setFilterChainResolver(FilterChainResolver filterChainResolver) { + + } + + @Override + void setServletContext(ServletContext servletContext) { + + } + + @Override + void setWebSecurityManager(WebSecurityManager webSecurityManager) { + + } + + @Override + FilterChainResolver getFilterChainResolver() { + return null + } + + @Override + ServletContext getServletContext() { + return null + } + + @Override + WebSecurityManager getWebSecurityManager() { + return null + } + + @Override + SecurityManager getSecurityManager() { + return null + } +}