Skip to content

Commit

Permalink
[fix][broker] TokenAuthenticationState: authenticate token only once (a…
Browse files Browse the repository at this point in the history
…pache#19314)

Co-authored-by: Lari Hotari <[email protected]>
  • Loading branch information
michaeljmarshall and lhotari authored Feb 1, 2023
1 parent 66a92c3 commit 0273f31
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,8 @@ public void testSaslServerAndClientAuth() throws Exception {

// prepare client and server side resource
AuthenticationDataProvider dataProvider = authSasl.getAuthData(hostName);
AuthenticationProviderList providerList = (AuthenticationProviderList)
(pulsar.getBrokerService().getAuthenticationService()
.getAuthenticationProvider(SaslConstants.AUTH_METHOD_NAME));
AuthenticationProviderSasl saslServer =
(AuthenticationProviderSasl) providerList.getProviders().get(0);
AuthenticationProviderSasl saslServer = (AuthenticationProviderSasl) pulsar.getBrokerService()
.getAuthenticationService().getAuthenticationProvider(SaslConstants.AUTH_METHOD_NAME);
AuthenticationState authState = saslServer.newAuthState(null, null, null);

// auth between server and client.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.io.IOException;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -155,10 +156,20 @@ default CompletableFuture<Boolean> authenticateHttpRequestAsync(HttpServletReque
*/
@Deprecated
default boolean authenticateHttpRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
AuthenticationState authenticationState = newHttpAuthState(request);
String role = authenticateAsync(authenticationState.getAuthDataSource()).get();
request.setAttribute(AuthenticatedRoleAttributeName, role);
request.setAttribute(AuthenticatedDataAttributeName, authenticationState.getAuthDataSource());
return true;
try {
AuthenticationState authenticationState = newHttpAuthState(request);
String role = authenticateAsync(authenticationState.getAuthDataSource()).get();
request.setAttribute(AuthenticatedRoleAttributeName, role);
request.setAttribute(AuthenticatedDataAttributeName, authenticationState.getAuthDataSource());
return true;
} catch (AuthenticationException e) {
throw e;
} catch (Exception e) {
if (e instanceof ExecutionException && e.getCause() instanceof AuthenticationException) {
throw (AuthenticationException) e.getCause();
} else {
throw new AuthenticationException("Failed to authentication http request");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,17 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
try {
applyAuthProcessor(
providers,
provider -> {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
return state;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
} catch (AuthenticationException ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
);
} catch (AuthenticationException ae) {
authenticationException = ae;
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
}
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new auth state from {}", remoteAddress, authenticationException);
Expand All @@ -203,17 +203,17 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
try {
applyAuthProcessor(
providers,
provider -> {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
return state;
}
);
} catch (AuthenticationException ae) {
authenticationException = ae;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
} catch (AuthenticationException ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
}
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new http auth state from {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ private static final class TokenAuthenticationState implements AuthenticationSta
SocketAddress remoteAddress,
SSLSession sslSession) throws AuthenticationException {
this.provider = provider;
String token = new String(authData.getBytes(), UTF_8);
this.authenticationDataSource = new AuthenticationDataCommand(token, remoteAddress, sslSession);
this.checkExpiration(token);
this.remoteAddress = remoteAddress;
this.sslSession = sslSession;
}
Expand All @@ -354,15 +351,9 @@ private static final class TokenAuthenticationState implements AuthenticationSta
AuthenticationProviderToken provider,
HttpServletRequest request) throws AuthenticationException {
this.provider = provider;
String httpHeaderValue = request.getHeader(HTTP_HEADER_NAME);
if (httpHeaderValue == null || !httpHeaderValue.startsWith(HTTP_HEADER_VALUE_PREFIX)) {
throw new AuthenticationException("Invalid HTTP Authorization header");
}

// Remove prefix
String token = httpHeaderValue.substring(HTTP_HEADER_VALUE_PREFIX.length());
// Set this for backwards compatibility with AuthenticationProvider#newHttpAuthState
this.authenticationDataSource = new AuthenticationDataHttps(request);
this.checkExpiration(token);

// These are not used when this constructor is invoked, set them to null.
this.sslSession = null;
Expand All @@ -371,6 +362,9 @@ private static final class TokenAuthenticationState implements AuthenticationSta

@Override
public String getAuthRole() throws AuthenticationException {
if (jwt == null) {
throw new AuthenticationException("Must authenticate before calling getAuthRole");
}
return provider.getPrincipal(jwt);
}

Expand Down Expand Up @@ -404,8 +398,8 @@ public AuthenticationDataSource getAuthDataSource() {

@Override
public boolean isComplete() {
// The authentication of tokens is always done in one single stage
return true;
// The authentication of tokens is always done in one single stage, so once jwt is set, it is "complete"
return jwt != null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ public AuthenticationService(ServiceConfiguration conf) throws PulsarServerExcep
}

for (Map.Entry<String, List<AuthenticationProvider>> entry : providerMap.entrySet()) {
AuthenticationProviderList provider = new AuthenticationProviderList(entry.getValue());
AuthenticationProvider provider;
if (entry.getValue().size() == 1) {
provider = entry.getValue().get(0);
} else {
provider = new AuthenticationProviderList(entry.getValue());
}
provider.initialize(conf);
providers.put(provider.getAuthMethodName(), provider);
LOG.info("[{}] has been loaded.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
package org.apache.pulsar.broker.authentication;

import static java.nio.charset.StandardCharsets.UTF_8;
import javax.servlet.http.HttpServletRequest;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
Expand All @@ -35,6 +39,7 @@
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
Expand Down Expand Up @@ -157,19 +162,13 @@ public void testAuthenticate() throws Exception {
}

private AuthenticationState newAuthState(String token, String expectedSubject) throws Exception {
// Must pass the token to the newAuthState for legacy reasons.
AuthenticationState authState = authProvider.newAuthState(
AuthData.of(token.getBytes(UTF_8)),
null,
null
);
assertEquals(authState.getAuthRole(), expectedSubject);
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());
return authState;
}

private AuthenticationState newHttpAuthState(HttpServletRequest request, String expectedSubject) throws Exception {
AuthenticationState authState = authProvider.newHttpAuthState(request);
authState.authenticateAsync(AuthData.of(token.getBytes(UTF_8))).get();
assertEquals(authState.getAuthRole(), expectedSubject);
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());
Expand Down Expand Up @@ -200,37 +199,42 @@ public void testNewAuthState() throws Exception {
}

@Test
public void testNewHttpAuthState() throws Exception {
public void testAuthenticateHttpRequest() throws Exception {
HttpServletRequest requestAA = mock(HttpServletRequest.class);
when(requestAA.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestAA.getRemotePort()).thenReturn(8080);
when(requestAA.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenAA);
AuthenticationState authStateAA = newHttpAuthState(requestAA, SUBJECT_A);
boolean doFilterAA = authProvider.authenticateHttpRequest(requestAA, null);
assertTrue(doFilterAA);
verify(requestAA).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_A));
verify(requestAA).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestAB = mock(HttpServletRequest.class);
when(requestAB.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestAB.getRemotePort()).thenReturn(8080);
when(requestAB.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenAB);
AuthenticationState authStateAB = newHttpAuthState(requestAB, SUBJECT_B);
boolean doFilterAB = authProvider.authenticateHttpRequest(requestAB, null);
assertTrue(doFilterAB);
verify(requestAB).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_B));
verify(requestAB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestBA = mock(HttpServletRequest.class);
when(requestBA.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestBA.getRemotePort()).thenReturn(8080);
when(requestBA.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenBA);
AuthenticationState authStateBA = newHttpAuthState(requestBA, SUBJECT_A);
boolean doFilterBA = authProvider.authenticateHttpRequest(requestBA, null);
assertTrue(doFilterBA);
verify(requestBA).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_A));
verify(requestBA).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestBB = mock(HttpServletRequest.class);
when(requestBB.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestBB.getRemotePort()).thenReturn(8080);
when(requestBB.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenBB);
AuthenticationState authStateBB = newHttpAuthState(requestBB, SUBJECT_B);

Thread.sleep(TimeUnit.SECONDS.toMillis(6));

verifyAuthStateExpired(authStateAA, SUBJECT_A);
verifyAuthStateExpired(authStateAB, SUBJECT_B);
verifyAuthStateExpired(authStateBA, SUBJECT_A);
verifyAuthStateExpired(authStateBB, SUBJECT_B);
boolean doFilterBB = authProvider.authenticateHttpRequest(requestBB, null);
assertTrue(doFilterBB);
verify(requestBB).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_B));
verify(requestBB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.testng.Assert.assertNotEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -682,6 +683,7 @@ public void testExpiringToken() throws Exception {
Optional.of(new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(3))));

AuthenticationState authState = provider.newAuthState(AuthData.of(expiringToken.getBytes()), null, null);
authState.authenticate(AuthData.of(expiringToken.getBytes()));
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());

Expand All @@ -693,6 +695,34 @@ public void testExpiringToken() throws Exception {
assertEquals(brokerData, AuthData.REFRESH_AUTH_DATA);
}

@SuppressWarnings("deprecation")
@Test
public void testExpiredTokenFailsOnAuthenticate() throws Exception {
SecretKey secretKey = AuthTokenUtils.createSecretKey(SignatureAlgorithm.HS256);

@Cleanup
AuthenticationProviderToken provider = new AuthenticationProviderToken();

Properties properties = new Properties();
properties.setProperty(AuthenticationProviderToken.CONF_TOKEN_SECRET_KEY,
AuthTokenUtils.encodeKeyBase64(secretKey));

ServiceConfiguration conf = new ServiceConfiguration();
conf.setProperties(properties);
provider.initialize(conf);

// Create a token that is already expired
String expiringToken = AuthTokenUtils.createToken(secretKey, SUBJECT,
Optional.of(new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(3))));

AuthData expiredAuthData = AuthData.of(expiringToken.getBytes());

// It is important that this call doesn't fail because we no longer authenticate the auth data at construction
AuthenticationState authState = provider.newAuthState(expiredAuthData,null, null);
// The call to authenticate the token is the call that fails
assertThrows(AuthenticationException.class, () -> authState.authenticate(expiredAuthData));
}

// tests for Token Audience
@Test
public void testRightTokenAudienceClaim() throws Exception {
Expand Down Expand Up @@ -916,6 +946,7 @@ public void testTokenFromHttpHeaders() throws Exception {
assertTrue(doFilter, "Authentication should have passed");
}

@SuppressWarnings("deprecation")
@Test
public void testTokenStateUpdatesAuthenticationDataSource() throws Exception {
SecretKey secretKey = AuthTokenUtils.createSecretKey(SignatureAlgorithm.HS256);
Expand All @@ -931,20 +962,26 @@ public void testTokenStateUpdatesAuthenticationDataSource() throws Exception {
conf.setProperties(properties);
provider.initialize(conf);

String firstToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());
AuthenticationState authState = provider.newAuthState(null,null, null);

// Haven't authenticated yet, so cannot get role when using constructor with no auth data
assertThrows(AuthenticationException.class, authState::getAuthRole);
assertNull(authState.getAuthDataSource(), "Haven't created a source yet.");

AuthenticationState authState = provider.newAuthState(AuthData.of(firstToken.getBytes()),null, null);
String firstToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());

AuthData firstChallenge = authState.authenticate(AuthData.of(firstToken.getBytes()));
AuthenticationDataSource firstAuthDataSource = authState.getAuthDataSource();
assertNotNull(firstAuthDataSource, "Should be initialized.");

String secondToken = AuthTokenUtils.createToken(secretKey, SUBJECT,
Optional.of(new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(3))));
assertNull(firstChallenge, "TokenAuth doesn't respond with challenges");
assertNotNull(firstAuthDataSource, "Created authDataSource");

String secondToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());

AuthData challenge = authState.authenticate(AuthData.of(secondToken.getBytes()));
AuthData secondChallenge = authState.authenticate(AuthData.of(secondToken.getBytes()));
AuthenticationDataSource secondAuthDataSource = authState.getAuthDataSource();

assertNull(challenge, "TokenAuth doesn't respond with challenges");
assertNull(secondChallenge, "TokenAuth doesn't respond with challenges");
assertNotNull(secondAuthDataSource, "Created authDataSource");

assertNotEquals(firstAuthDataSource, secondAuthDataSource);
Expand Down

0 comments on commit 0273f31

Please sign in to comment.