diff --git a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilter.java b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilter.java index 8bb345548..f89706865 100644 --- a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilter.java +++ b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilter.java @@ -2,7 +2,6 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; -import java.net.URLEncoder; import java.util.Map; import javax.servlet.Filter; @@ -24,6 +23,7 @@ import org.springframework.web.servlet.support.ServletUriComponentsBuilder; import org.springframework.web.util.NestedServletException; import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; /** * Security filter for an OAuth2 client. @@ -34,8 +34,9 @@ public class OAuth2ClientContextFilter implements Filter, InitializingBean { /** - * Key in request attributes for the current URI in case it is needed by rest client code that needs to send a - * redirect URI to an authorization server. + * Key in request attributes for the current URI in case it is needed by + * rest client code that needs to send a redirect URI to an authorization + * server. */ public static final String CURRENT_URI = "currentUri"; @@ -44,10 +45,12 @@ public class OAuth2ClientContextFilter implements Filter, InitializingBean { private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); public void afterPropertiesSet() throws Exception { - Assert.notNull(redirectStrategy, "A redirect strategy must be supplied."); + Assert.notNull(redirectStrategy, + "A redirect strategy must be supplied."); } - public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) + public void doFilter(ServletRequest servletRequest, + ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; @@ -55,19 +58,17 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo try { chain.doFilter(servletRequest, servletResponse); - } - catch (IOException ex) { + } catch (IOException ex) { throw ex; - } - catch (Exception ex) { + } catch (Exception ex) { // Try to extract a SpringSecurityException from the stacktrace Throwable[] causeChain = throwableAnalyzer.determineCauseChain(ex); UserRedirectRequiredException redirect = (UserRedirectRequiredException) throwableAnalyzer - .getFirstThrowableOfType(UserRedirectRequiredException.class, causeChain); + .getFirstThrowableOfType( + UserRedirectRequiredException.class, causeChain); if (redirect != null) { redirectUser(redirect, request, response); - } - else { + } else { if (ex instanceof ServletException) { throw (ServletException) ex; } @@ -83,44 +84,44 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo * Redirect the user according to the specified exception. * * @param resourceThatNeedsAuthorization - * @param e The user redirect exception. - * @param request The request. - * @param response The response. + * @param e + * The user redirect exception. + * @param request + * The request. + * @param response + * The response. */ - protected void redirectUser(UserRedirectRequiredException e, HttpServletRequest request, - HttpServletResponse response) throws IOException { + protected void redirectUser(UserRedirectRequiredException e, + HttpServletRequest request, HttpServletResponse response) + throws IOException { String redirectUri = e.getRedirectUri(); - StringBuilder builder = new StringBuilder(redirectUri); + UriComponentsBuilder builder = UriComponentsBuilder + .fromHttpUrl(redirectUri); Map requestParams = e.getRequestParams(); - char appendChar = redirectUri.indexOf('?') < 0 ? '?' : '&'; for (Map.Entry param : requestParams.entrySet()) { - try { - builder.append(appendChar).append(param.getKey()).append('=') - .append(URLEncoder.encode(param.getValue(), "UTF-8")); - } - catch (UnsupportedEncodingException uee) { - throw new IllegalStateException(uee); - } - appendChar = '&'; + builder.queryParam(param.getKey(), param.getValue()); } if (e.getStateKey() != null) { - builder.append(appendChar).append("state").append('=').append(e.getStateKey()); + builder.queryParam("state", e.getStateKey()); } - this.redirectStrategy.sendRedirect(request, response, builder.toString()); - + this.redirectStrategy.sendRedirect(request, response, builder.build() + .toUriString()); } /** * Calculate the current URI given the request. * - * @param request The request. + * @param request + * The request. * @return The current uri. */ - protected String calculateCurrentUri(HttpServletRequest request) throws UnsupportedEncodingException { - ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequest(request); + protected String calculateCurrentUri(HttpServletRequest request) + throws UnsupportedEncodingException { + ServletUriComponentsBuilder builder = ServletUriComponentsBuilder + .fromRequest(request); // Now work around SPR-10172... String queryString = request.getQueryString(); boolean legalSpaces = queryString != null && queryString.contains("+"); @@ -128,18 +129,19 @@ protected String calculateCurrentUri(HttpServletRequest request) throws Unsuppor builder.replaceQuery(queryString.replace("+", "%20")); } UriComponents uri = null; - try { - uri = builder.replaceQueryParam("code").build(true); - } catch (IllegalArgumentException ex) { - // ignore failures to parse the url (including query string). does't make sense - // for redirection purposes anyway. - return null; - } + try { + uri = builder.replaceQueryParam("code").build(true); + } catch (IllegalArgumentException ex) { + // ignore failures to parse the url (including query string). does't + // make sense for redirection purposes anyway. + return null; + } String query = uri.getQuery(); if (legalSpaces) { query = query.replace("%20", "+"); } - return ServletUriComponentsBuilder.fromUri(uri.toUri()).replaceQuery(query).build().toString(); + return ServletUriComponentsBuilder.fromUri(uri.toUri()) + .replaceQuery(query).build().toString(); } public void init(FilterConfig filterConfig) throws ServletException { diff --git a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilterTests.java b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilterTests.java index 1959a5dfa..2049a8592 100644 --- a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilterTests.java +++ b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilterTests.java @@ -2,20 +2,70 @@ import static org.junit.Assert.assertEquals; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + import org.junit.Test; +import org.mockito.Mockito; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; +import org.springframework.security.web.RedirectStrategy; /** * @author Ryan Heaton + * @author Dave Syer */ public class OAuth2ClientContextFilterTests { + @Test + public void testVanillaRedirectUri() throws Exception { + String redirect = "http://example.com/authorize"; + Map params = new LinkedHashMap(); + params.put("foo", "bar"); + params.put("scope", "spam"); + testRedirectUri(redirect, params, redirect + "?foo=bar&scope=spam"); + } + + @Test + public void testRedirectUriWithUrlInParams() throws Exception { + String redirect = "http://example.com/authorize"; + Map params = Collections.singletonMap("redirect", + "http://foo/bar"); + testRedirectUri(redirect, params, redirect + "?redirect=http://foo/bar"); + } + + @Test + public void testRedirectUriWithQuery() throws Exception { + String redirect = "http://example.com/authorize?foo=bar"; + Map params = Collections.singletonMap("spam", + "bucket"); + testRedirectUri(redirect, params, redirect + "&spam=bucket"); + } + + public void testRedirectUri(String redirect, Map params, + String result) throws Exception { + OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); + RedirectStrategy redirectStrategy = Mockito + .mock(RedirectStrategy.class); + filter.setRedirectStrategy(redirectStrategy); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + UserRedirectRequiredException exception = new UserRedirectRequiredException( + redirect, params); + filter.redirectUser(exception, request, response); + Mockito.verify(redirectStrategy) + .sendRedirect(request, response, result); + } + @Test public void testVanillaCurrentUri() throws Exception { OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("foo=bar"); - assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request)); + assertEquals("http://localhost?foo=bar", + filter.calculateCurrentUri(request)); } @Test @@ -23,7 +73,8 @@ public void testCurrentUriWithLegalSpaces() throws Exception { OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("foo=bar%20spam"); - assertEquals("http://localhost?foo=bar%20spam", filter.calculateCurrentUri(request)); + assertEquals("http://localhost?foo=bar%20spam", + filter.calculateCurrentUri(request)); } @Test @@ -38,7 +89,8 @@ public void testCurrentUriWithIllegalSpaces() throws Exception { OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("foo=bar+spam"); - assertEquals("http://localhost?foo=bar+spam", filter.calculateCurrentUri(request)); + assertEquals("http://localhost?foo=bar+spam", + filter.calculateCurrentUri(request)); } @Test @@ -46,7 +98,8 @@ public void testCurrentUriRemovingCode() throws Exception { OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("code=XXXX&foo=bar"); - assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request)); + assertEquals("http://localhost?foo=bar", + filter.calculateCurrentUri(request)); } @Test @@ -54,7 +107,8 @@ public void testCurrentUriRemovingCodeInSecond() throws Exception { OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("foo=bar&code=XXXX"); - assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request)); + assertEquals("http://localhost?foo=bar", + filter.calculateCurrentUri(request)); } @Test