Skip to content

Commit

Permalink
Us UriComponentsBuilder instead of UrlEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Syer committed Mar 16, 2015
1 parent 847e6ff commit 10e2c91
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.Map;

import javax.servlet.Filter;
Expand All @@ -24,6 +23,7 @@
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
import org.springframework.web.util.NestedServletException;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

/**
* Security filter for an OAuth2 client.
Expand All @@ -34,8 +34,9 @@
public class OAuth2ClientContextFilter implements Filter, InitializingBean {

/**
* Key in request attributes for the current URI in case it is needed by rest client code that needs to send a
* redirect URI to an authorization server.
* Key in request attributes for the current URI in case it is needed by
* rest client code that needs to send a redirect URI to an authorization
* server.
*/
public static final String CURRENT_URI = "currentUri";

Expand All @@ -44,30 +45,30 @@ public class OAuth2ClientContextFilter implements Filter, InitializingBean {
private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();

public void afterPropertiesSet() throws Exception {
Assert.notNull(redirectStrategy, "A redirect strategy must be supplied.");
Assert.notNull(redirectStrategy,
"A redirect strategy must be supplied.");
}

public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
public void doFilter(ServletRequest servletRequest,
ServletResponse servletResponse, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
request.setAttribute(CURRENT_URI, calculateCurrentUri(request));

try {
chain.doFilter(servletRequest, servletResponse);
}
catch (IOException ex) {
} catch (IOException ex) {
throw ex;
}
catch (Exception ex) {
} catch (Exception ex) {
// Try to extract a SpringSecurityException from the stacktrace
Throwable[] causeChain = throwableAnalyzer.determineCauseChain(ex);
UserRedirectRequiredException redirect = (UserRedirectRequiredException) throwableAnalyzer
.getFirstThrowableOfType(UserRedirectRequiredException.class, causeChain);
.getFirstThrowableOfType(
UserRedirectRequiredException.class, causeChain);
if (redirect != null) {
redirectUser(redirect, request, response);
}
else {
} else {
if (ex instanceof ServletException) {
throw (ServletException) ex;
}
Expand All @@ -83,63 +84,64 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
* Redirect the user according to the specified exception.
*
* @param resourceThatNeedsAuthorization
* @param e The user redirect exception.
* @param request The request.
* @param response The response.
* @param e
* The user redirect exception.
* @param request
* The request.
* @param response
* The response.
*/
protected void redirectUser(UserRedirectRequiredException e, HttpServletRequest request,
HttpServletResponse response) throws IOException {
protected void redirectUser(UserRedirectRequiredException e,
HttpServletRequest request, HttpServletResponse response)
throws IOException {

String redirectUri = e.getRedirectUri();
StringBuilder builder = new StringBuilder(redirectUri);
UriComponentsBuilder builder = UriComponentsBuilder
.fromHttpUrl(redirectUri);
Map<String, String> requestParams = e.getRequestParams();
char appendChar = redirectUri.indexOf('?') < 0 ? '?' : '&';
for (Map.Entry<String, String> param : requestParams.entrySet()) {
try {
builder.append(appendChar).append(param.getKey()).append('=')
.append(URLEncoder.encode(param.getValue(), "UTF-8"));
}
catch (UnsupportedEncodingException uee) {
throw new IllegalStateException(uee);
}
appendChar = '&';
builder.queryParam(param.getKey(), param.getValue());
}

if (e.getStateKey() != null) {
builder.append(appendChar).append("state").append('=').append(e.getStateKey());
builder.queryParam("state", e.getStateKey());
}

this.redirectStrategy.sendRedirect(request, response, builder.toString());

this.redirectStrategy.sendRedirect(request, response, builder.build()
.toUriString());
}

/**
* Calculate the current URI given the request.
*
* @param request The request.
* @param request
* The request.
* @return The current uri.
*/
protected String calculateCurrentUri(HttpServletRequest request) throws UnsupportedEncodingException {
ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequest(request);
protected String calculateCurrentUri(HttpServletRequest request)
throws UnsupportedEncodingException {
ServletUriComponentsBuilder builder = ServletUriComponentsBuilder
.fromRequest(request);
// Now work around SPR-10172...
String queryString = request.getQueryString();
boolean legalSpaces = queryString != null && queryString.contains("+");
if (legalSpaces) {
builder.replaceQuery(queryString.replace("+", "%20"));
}
UriComponents uri = null;
try {
uri = builder.replaceQueryParam("code").build(true);
} catch (IllegalArgumentException ex) {
// ignore failures to parse the url (including query string). does't make sense
// for redirection purposes anyway.
return null;
}
try {
uri = builder.replaceQueryParam("code").build(true);
} catch (IllegalArgumentException ex) {
// ignore failures to parse the url (including query string). does't
// make sense for redirection purposes anyway.
return null;
}
String query = uri.getQuery();
if (legalSpaces) {
query = query.replace("%20", "+");
}
return ServletUriComponentsBuilder.fromUri(uri.toUri()).replaceQuery(query).build().toString();
return ServletUriComponentsBuilder.fromUri(uri.toUri())
.replaceQuery(query).build().toString();
}

public void init(FilterConfig filterConfig) throws ServletException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,79 @@

import static org.junit.Assert.assertEquals;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
import org.springframework.security.web.RedirectStrategy;

/**
* @author Ryan Heaton
* @author Dave Syer
*/
public class OAuth2ClientContextFilterTests {

@Test
public void testVanillaRedirectUri() throws Exception {
String redirect = "http://example.com/authorize";
Map<String, String> params = new LinkedHashMap<String, String>();
params.put("foo", "bar");
params.put("scope", "spam");
testRedirectUri(redirect, params, redirect + "?foo=bar&scope=spam");
}

@Test
public void testRedirectUriWithUrlInParams() throws Exception {
String redirect = "http://example.com/authorize";
Map<String, String> params = Collections.singletonMap("redirect",
"http://foo/bar");
testRedirectUri(redirect, params, redirect + "?redirect=http://foo/bar");
}

@Test
public void testRedirectUriWithQuery() throws Exception {
String redirect = "http://example.com/authorize?foo=bar";
Map<String, String> params = Collections.singletonMap("spam",
"bucket");
testRedirectUri(redirect, params, redirect + "&spam=bucket");
}

public void testRedirectUri(String redirect, Map<String, String> params,
String result) throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
RedirectStrategy redirectStrategy = Mockito
.mock(RedirectStrategy.class);
filter.setRedirectStrategy(redirectStrategy);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
UserRedirectRequiredException exception = new UserRedirectRequiredException(
redirect, params);
filter.redirectUser(exception, request, response);
Mockito.verify(redirectStrategy)
.sendRedirect(request, response, result);
}

@Test
public void testVanillaCurrentUri() throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar");
assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request));
assertEquals("http://localhost?foo=bar",
filter.calculateCurrentUri(request));
}

@Test
public void testCurrentUriWithLegalSpaces() throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar%20spam");
assertEquals("http://localhost?foo=bar%20spam", filter.calculateCurrentUri(request));
assertEquals("http://localhost?foo=bar%20spam",
filter.calculateCurrentUri(request));
}

@Test
Expand All @@ -38,23 +89,26 @@ public void testCurrentUriWithIllegalSpaces() throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar+spam");
assertEquals("http://localhost?foo=bar+spam", filter.calculateCurrentUri(request));
assertEquals("http://localhost?foo=bar+spam",
filter.calculateCurrentUri(request));
}

@Test
public void testCurrentUriRemovingCode() throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("code=XXXX&foo=bar");
assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request));
assertEquals("http://localhost?foo=bar",
filter.calculateCurrentUri(request));
}

@Test
public void testCurrentUriRemovingCodeInSecond() throws Exception {
OAuth2ClientContextFilter filter = new OAuth2ClientContextFilter();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar&code=XXXX");
assertEquals("http://localhost?foo=bar", filter.calculateCurrentUri(request));
assertEquals("http://localhost?foo=bar",
filter.calculateCurrentUri(request));
}

@Test
Expand Down

0 comments on commit 10e2c91

Please sign in to comment.