From eb9b5cc1c1c72282e187ebdb64e153cc975995fc Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Tue, 14 Apr 2015 11:40:01 +0100 Subject: [PATCH] Add flag to auth code client to make state key mandatory by default User has to choose explicitly the old behaviour where a null state key was acceptable even when the wrong state key was an error. Fixes gh-440 --- .../AuthorizationCodeAccessTokenProvider.java | 15 +++++++++++++-- ...AuthorizationCodeAccessTokenProviderTests.java | 11 +++++++++++ ...odeAccessTokenProviderWithConversionTests.java | 4 ++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProvider.java b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProvider.java index 9e312f752..37743c624 100644 --- a/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProvider.java +++ b/spring-security-oauth2/src/main/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProvider.java @@ -77,6 +77,17 @@ public class AuthorizationCodeAccessTokenProvider extends OAuth2AccessTokenSuppo private RequestEnhancer authorizationRequestEnhancer = new DefaultRequestEnhancer(); + private boolean stateMandatory = true; + + /** + * Flag to say that the use of state parameter is mandatory. + * + * @param stateMandatory the flag value (default true) + */ + public void setStateMandatory(boolean stateMandatory) { + this.stateMandatory = stateMandatory; + } + /** * A custom enhancer for the authorization request * @param authorizationRequestEnhancer @@ -237,12 +248,12 @@ private MultiValueMap getParametersForTokenRequest(Authorization form.set("code", request.getAuthorizationCode()); Object preservedState = request.getPreservedState(); - if (request.getStateKey() != null) { + if (request.getStateKey() != null || stateMandatory) { // The token endpoint has no use for the state so we don't send it back, but we are using it // for CSRF detection client side... if (preservedState == null) { throw new InvalidRequestException( - "Possible CSRF detected - state parameter was present but no state could be found"); + "Possible CSRF detected - state parameter was required but no state could be found"); } } diff --git a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderTests.java b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderTests.java index 8c3c26ca4..deeacc458 100644 --- a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderTests.java +++ b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderTests.java @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest; import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; import org.springframework.security.oauth2.common.OAuth2AccessToken; +import org.springframework.security.oauth2.common.exceptions.InvalidRequestException; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -52,9 +53,19 @@ protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2Prot @Test public void testGetAccessToken() throws Exception { + AccessTokenRequest request = new DefaultAccessTokenRequest(); + request.setAuthorizationCode("foo"); + request.setPreservedState(new Object()); + resource.setAccessTokenUri("http://localhost/oauth/token"); + assertEquals("FOO", provider.obtainAccessToken(resource, request).getValue()); + } + + @Test + public void testGetAccessTokenFailsWithNoState() throws Exception { AccessTokenRequest request = new DefaultAccessTokenRequest(); request.setAuthorizationCode("foo"); resource.setAccessTokenUri("http://localhost/oauth/token"); + expected.expect(InvalidRequestException.class); assertEquals("FOO", provider.obtainAccessToken(resource, request).getValue()); } diff --git a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderWithConversionTests.java b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderWithConversionTests.java index 34fca6f85..568a2ebee 100644 --- a/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderWithConversionTests.java +++ b/spring-security-oauth2/src/test/java/org/springframework/security/oauth2/client/token/grant/code/AuthorizationCodeAccessTokenProviderWithConversionTests.java @@ -156,6 +156,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO AccessTokenRequest request = new DefaultAccessTokenRequest(); request.setAuthorizationCode("foo"); resource.setAccessTokenUri("http://localhost/oauth/token"); + request.setPreservedState(new Object()); setUpRestTemplate(); assertEquals(token, provider.obtainAccessToken(resource, request)); } @@ -171,6 +172,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO }; AccessTokenRequest request = new DefaultAccessTokenRequest(); request.setAuthorizationCode("foo"); + request.setPreservedState(new Object()); resource.setAccessTokenUri("http://localhost/oauth/token"); expected.expect(OAuth2AccessDeniedException.class); expected.expect(hasCause(instanceOf(InvalidClientException.class))); @@ -190,6 +192,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO }; AccessTokenRequest request = new DefaultAccessTokenRequest(); request.setAuthorizationCode("foo"); + request.setPreservedState(new Object()); resource.setAccessTokenUri("http://localhost/oauth/token"); setUpRestTemplate(); assertEquals(token, provider.obtainAccessToken(resource, request)); @@ -207,6 +210,7 @@ public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IO }; AccessTokenRequest request = new DefaultAccessTokenRequest(); request.setAuthorizationCode("foo"); + request.setPreservedState(new Object()); resource.setAccessTokenUri("http://localhost/oauth/token"); expected.expect(OAuth2AccessDeniedException.class); expected.expect(hasCause(instanceOf(InvalidClientException.class)));