Skip to content

Commit

Permalink
Add flag to auth code client to make state key mandatory by default
Browse files Browse the repository at this point in the history
User has to choose explicitly the old behaviour where a null state
key was acceptable even when the wrong state key was an error.

Fixes spring-atticgh-440
  • Loading branch information
Dave Syer committed Apr 14, 2015
1 parent bbd7b70 commit eb9b5cc
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ public class AuthorizationCodeAccessTokenProvider extends OAuth2AccessTokenSuppo

private RequestEnhancer authorizationRequestEnhancer = new DefaultRequestEnhancer();

private boolean stateMandatory = true;

/**
* Flag to say that the use of state parameter is mandatory.
*
* @param stateMandatory the flag value (default true)
*/
public void setStateMandatory(boolean stateMandatory) {
this.stateMandatory = stateMandatory;
}

/**
* A custom enhancer for the authorization request
* @param authorizationRequestEnhancer
Expand Down Expand Up @@ -237,12 +248,12 @@ private MultiValueMap<String, String> getParametersForTokenRequest(Authorization
form.set("code", request.getAuthorizationCode());

Object preservedState = request.getPreservedState();
if (request.getStateKey() != null) {
if (request.getStateKey() != null || stateMandatory) {
// The token endpoint has no use for the state so we don't send it back, but we are using it
// for CSRF detection client side...
if (preservedState == null) {
throw new InvalidRequestException(
"Possible CSRF detected - state parameter was present but no state could be found");
"Possible CSRF detected - state parameter was required but no state could be found");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

Expand Down Expand Up @@ -52,9 +53,19 @@ protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2Prot

@Test
public void testGetAccessToken() throws Exception {
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
request.setPreservedState(new Object());
resource.setAccessTokenUri("http://localhost/oauth/token");
assertEquals("FOO", provider.obtainAccessToken(resource, request).getValue());
}

@Test
public void testGetAccessTokenFailsWithNoState() throws Exception {
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
resource.setAccessTokenUri("http://localhost/oauth/token");
expected.expect(InvalidRequestException.class);
assertEquals("FOO", provider.obtainAccessToken(resource, request).getValue());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
resource.setAccessTokenUri("http://localhost/oauth/token");
request.setPreservedState(new Object());
setUpRestTemplate();
assertEquals(token, provider.obtainAccessToken(resource, request));
}
Expand All @@ -171,6 +172,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO
};
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
request.setPreservedState(new Object());
resource.setAccessTokenUri("http://localhost/oauth/token");
expected.expect(OAuth2AccessDeniedException.class);
expected.expect(hasCause(instanceOf(InvalidClientException.class)));
Expand All @@ -190,6 +192,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO
};
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
request.setPreservedState(new Object());
resource.setAccessTokenUri("http://localhost/oauth/token");
setUpRestTemplate();
assertEquals(token, provider.obtainAccessToken(resource, request));
Expand All @@ -207,6 +210,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO
};
AccessTokenRequest request = new DefaultAccessTokenRequest();
request.setAuthorizationCode("foo");
request.setPreservedState(new Object());
resource.setAccessTokenUri("http://localhost/oauth/token");
expected.expect(OAuth2AccessDeniedException.class);
expected.expect(hasCause(instanceOf(InvalidClientException.class)));
Expand Down

0 comments on commit eb9b5cc

Please sign in to comment.