Skip to content

Commit

Permalink
Throw Saml2AuthenticationException
Browse files Browse the repository at this point in the history
  • Loading branch information
GitHanter authored and jzheaux committed Mar 2, 2021
1 parent 3e8ad4b commit 6e41246
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,12 +30,14 @@

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AuthnRequest;

Expand All @@ -62,10 +64,13 @@
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
Expand All @@ -78,6 +83,7 @@
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
Expand Down Expand Up @@ -210,6 +216,24 @@ public void authenticateWhenCustomAuthenticationConverterThenUses() throws Excep
verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class));
}

@Test
public void authenticateWithInvalidDeflatedSAMLResponseThenFailureHandlerUses() throws Exception {
this.spring.register(CustomAuthenticationFailureHandler.class).autowire();
byte[] invalidDeflated = "invalid".getBytes();
String encoded = Saml2Utils.samlEncode(invalidDeflated);
MockHttpServletRequestBuilder request = get("/login/saml2/sso/registration-id").queryParam("SAMLResponse",
encoded);
this.mvc.perform(request);
ArgumentCaptor<Saml2AuthenticationException> captor = ArgumentCaptor
.forClass(Saml2AuthenticationException.class);
verify(CustomAuthenticationFailureHandler.authenticationFailureHandler).onAuthenticationFailure(
any(HttpServletRequest.class), any(HttpServletResponse.class), captor.capture());
Saml2AuthenticationException exception = captor.getValue();
assertThat(exception.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE);
assertThat(exception.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string");
assertThat(exception.getCause()).isInstanceOf(IOException.class);
}

private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
Expand Down Expand Up @@ -314,6 +338,21 @@ public <O extends OpenSamlAuthenticationProvider> O postProcess(O provider) {

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationFailureHandler extends WebSecurityConfigurerAdapter {

static final AuthenticationFailureHandler authenticationFailureHandler = mock(
AuthenticationFailureHandler.class);

@Override
protected void configure(HttpSecurity http) throws Exception {
http.authorizeRequests((authz) -> authz.anyRequest().authenticated())
.saml2Login((saml2) -> saml2.failureHandler(authenticationFailureHandler));
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +28,9 @@

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand Down Expand Up @@ -83,7 +85,13 @@ private String inflateIfRequired(HttpServletRequest request, byte[] b) {
}

private byte[] samlDecode(String s) {
return BASE64.decode(s);
try {
return BASE64.decode(s);
}
catch (Exception ex) {
throw new Saml2AuthenticationException(
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
}
}

private String samlInflate(byte[] b) {
Expand All @@ -94,8 +102,9 @@ private String samlInflate(byte[] b) {
inflaterOutputStream.finish();
return new String(out.toByteArray(), StandardCharsets.UTF_8);
}
catch (IOException ex) {
throw new Saml2Exception("Unable to inflate string", ex);
catch (Exception ex) {
throw new Saml2AuthenticationException(
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,14 +29,17 @@
import org.springframework.core.convert.converter.Converter;
import org.springframework.core.io.ClassPathResource;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.util.StreamUtils;
import org.springframework.web.util.UriUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
Expand Down Expand Up @@ -64,6 +67,22 @@ public void convertWhenSamlResponseThenToken() {
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
}

@Test
public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("SAMLResponse", "invalid");
assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
.withCauseInstanceOf(IllegalArgumentException.class)
.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
.satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription())
.isEqualTo("Failed to decode SAMLResponse"));
}

@Test
public void convertWhenNoSamlResponseThenNull() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
Expand Down Expand Up @@ -100,6 +119,25 @@ public void convertWhenGetRequestThenInflates() {
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
}

@Test
public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() {
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
this.relyingPartyRegistrationResolver);
given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
.willReturn(this.relyingPartyRegistration);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET");
byte[] invalidDeflated = "invalid".getBytes();
String encoded = Saml2Utils.samlEncode(invalidDeflated);
request.setParameter("SAMLResponse", encoded);
assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
.withCauseInstanceOf(IOException.class)
.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
.satisfies(
(ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"));
}

@Test
public void constructorWhenResolverIsNullThenIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));
Expand Down

0 comments on commit 6e41246

Please sign in to comment.