Skip to content

Commit

Permalink
GEODE-9792: synchronize multi-user authentication on different threads (
Browse files Browse the repository at this point in the history
  • Loading branch information
jinmeiliao authored Nov 9, 2021
1 parent beb898c commit 9f41e48
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ public Object execute(Op op, int retries) {
absOp.getMessage().setIsRetry();
}
try {
authenticateIfRequired(conn, op);
// for single user, it's already authenticated when opening the connection
authenticateIfMultiUser(conn, op);
return executeWithPossibleReAuthentication(conn, op);
} catch (MessageTooLargeException e) {
throw new GemFireIOException("unable to transmit message to server", e);
Expand Down Expand Up @@ -695,7 +696,7 @@ private static StringBuilder getExceptionMessage(final String exceptionName, fin
}

@VisibleForTesting
void authenticateIfRequired(final Connection connection, final Op op) {
void authenticateIfMultiUser(final Connection connection, final Op op) {
if (!connection.getServer().getRequiresCredentials()) {
return;
}
Expand All @@ -704,28 +705,30 @@ void authenticateIfRequired(final Connection connection, final Op op) {
return;
}

if (pool.getMultiuserAuthentication()) {
final UserAttributes ua = UserAttributes.userAttributes.get();
if (ua == null || ua.getServerToId().containsKey(connection.getServer())) {
return;
}
authenticateMultiuser(pool, connection, ua);
if (!pool.getMultiuserAuthentication()) {
return;
}

if (connection.getServer().getUserId() == -1) {
// This should not be reached, but keeping this code here in case it is reached.
final Connection wrappedConnection = connection.getWrappedConnection();
connection.getServer().setUserId(AuthenticateUserOp.executeOn(wrappedConnection, pool));
if (logger.isDebugEnabled()) {
logger.debug(
"OpExecutorImpl.execute() - single user mode - authenticated this user on {}",
connection);
final UserAttributes ua = getUserAttributesFromThreadLocal();
if (ua == null) {
return;
}

synchronized (ua) {
if (ua.getServerToId().containsKey(connection.getServer())) {
return;
}
authenticateMultiuser(pool, connection, ua);
}
}

private void authenticateMultiuser(final PoolImpl pool, final Connection conn,
@VisibleForTesting
UserAttributes getUserAttributesFromThreadLocal() {
return UserAttributes.userAttributes.get();
}

@VisibleForTesting
void authenticateMultiuser(final PoolImpl pool, final Connection conn,
final UserAttributes ua) {
try {
final Long userId = AuthenticateUserOp.executeOn(conn.getServer(), pool, ua.getCredentials());
Expand Down Expand Up @@ -764,7 +767,7 @@ private Object executeWithPossibleReAuthentication(final Connection conn, final
// 2nd exception-message above is from AbstractOp.sendMessage()

if (pool.getMultiuserAuthentication()) {
final UserAttributes ua = UserAttributes.userAttributes.get();
final UserAttributes ua = getUserAttributesFromThreadLocal();
if (ua != null) {
authenticateMultiuser(pool, conn, ua);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.geode.cache.client.Pool;
import org.apache.geode.distributed.internal.ServerLocation;

/**
* An instance of the class is created per ProxyCache/RegionService
*/
public class UserAttributes {

private Properties credentials;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
package org.apache.geode.cache.client.internal;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -63,15 +64,15 @@ public void before() throws Exception {
@Test
public void authenticateIfRequired_noOp_WhenNotRequireCredential() {
when(server.getRequiresCredentials()).thenReturn(false);
executor.authenticateIfRequired(connection, op);
executor.authenticateIfMultiUser(connection, op);
verify(pool, never()).executeOn(any(Connection.class), any(Op.class));
}

@Test
public void authenticateIfRequired_noOp_WhenOpNeedsNoUserId() {
when(server.getRequiresCredentials()).thenReturn(true);
when(op.needsUserId()).thenReturn(false);
executor.authenticateIfRequired(connection, op);
executor.authenticateIfMultiUser(connection, op);
verify(pool, never()).executeOn(any(Connection.class), any(Op.class));
}

Expand All @@ -81,7 +82,7 @@ public void authenticateIfRequired_noOp_singleUser_hasId() {
when(op.needsUserId()).thenReturn(true);
when(pool.getMultiuserAuthentication()).thenReturn(false);
when(server.getUserId()).thenReturn(123L);
executor.authenticateIfRequired(connection, op);
executor.authenticateIfMultiUser(connection, op);
verify(pool, never()).executeOn(any(Connection.class), any(Op.class));

}
Expand All @@ -94,9 +95,57 @@ public void authenticateIfRequired_setId_singleUser_hasNoId() {
when(server.getUserId()).thenReturn(-1L);
when(pool.executeOn(any(Connection.class), any(Op.class))).thenReturn(123L);
when(connection.getWrappedConnection()).thenReturn(connection);
executor.authenticateIfRequired(connection, op);
verify(pool).executeOn(any(Connection.class), any(Op.class));
verify(server).setUserId(eq(123L));
executor.authenticateIfMultiUser(connection, op);
verify(pool, never()).executeOn(any(Connection.class), any(Op.class));
}

@Test
public void execute_calls_authenticateIfMultiUser() throws Exception {
when(connection.execute(any())).thenReturn(123L);
when(connectionManager.borrowConnection(5)).thenReturn(connection);
OpExecutorImpl spy = spy(executor);

spy.execute(op, 1);
verify(spy).authenticateIfMultiUser(any(), any());
}

@Test
public void authenticateIfMultiUser_calls_authenticateMultiUser() {
OpExecutorImpl spy = spy(executor);
when(connection.getServer()).thenReturn(server);
when(pool.executeOn(any(ServerLocation.class), any())).thenReturn(123L);
UserAttributes userAttributes = new UserAttributes(null, null);

when(server.getRequiresCredentials()).thenReturn(false);
spy.authenticateIfMultiUser(connection, op);
verify(spy, never()).authenticateMultiuser(any(), any(), any());

when(server.getRequiresCredentials()).thenReturn(true);
when(op.needsUserId()).thenReturn(false);
spy.authenticateIfMultiUser(connection, op);
verify(spy, never()).authenticateMultiuser(any(), any(), any());

when(server.getRequiresCredentials()).thenReturn(true);
when(op.needsUserId()).thenReturn(true);
when(pool.getMultiuserAuthentication()).thenReturn(false);
spy.authenticateIfMultiUser(connection, op);
verify(spy, never()).authenticateMultiuser(any(), any(), any());

when(server.getRequiresCredentials()).thenReturn(true);
when(op.needsUserId()).thenReturn(true);
when(pool.getMultiuserAuthentication()).thenReturn(true);
spy.authenticateIfMultiUser(connection, op);
verify(spy, never()).authenticateMultiuser(any(), any(), any());

when(server.getRequiresCredentials()).thenReturn(true);
when(op.needsUserId()).thenReturn(true);
when(pool.getMultiuserAuthentication()).thenReturn(true);
doReturn(userAttributes).when(spy).getUserAttributesFromThreadLocal();
spy.authenticateIfMultiUser(connection, op);
verify(spy).authenticateMultiuser(pool, connection, userAttributes);

// calling it again wont' increase the invocation time
spy.authenticateIfMultiUser(connection, op);
verify(spy).authenticateMultiuser(pool, connection, userAttributes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.Properties;
import java.util.concurrent.Future;

import org.junit.After;
import org.junit.BeforeClass;
Expand All @@ -31,6 +32,7 @@
import org.apache.geode.cache.RegionService;
import org.apache.geode.cache.RegionShortcut;
import org.apache.geode.cache.client.ClientCache;
import org.apache.geode.cache.client.ClientRegionShortcut;
import org.apache.geode.cache.client.Pool;
import org.apache.geode.cache.execute.FunctionService;
import org.apache.geode.cache.query.CqAttributesFactory;
Expand All @@ -39,11 +41,12 @@
import org.apache.geode.examples.SimpleSecurityManager;
import org.apache.geode.pdx.JSONFormatter;
import org.apache.geode.pdx.PdxInstance;
import org.apache.geode.security.templates.TrackableUserPasswordAuthInit;
import org.apache.geode.security.templates.CountableUserPasswordAuthInit;
import org.apache.geode.security.templates.UserPasswordAuthInit;
import org.apache.geode.test.dunit.rules.ClusterStartupRule;
import org.apache.geode.test.dunit.rules.MemberVM;
import org.apache.geode.test.junit.rules.ClientCacheRule;
import org.apache.geode.test.junit.rules.ExecutorServiceRule;

public class MultiUserAPIDUnitTest {
@ClassRule
Expand All @@ -54,6 +57,9 @@ public class MultiUserAPIDUnitTest {
@Rule
public ClientCacheRule client = new ClientCacheRule();

@Rule
public ExecutorServiceRule executor = new ExecutorServiceRule();

@BeforeClass
public static void setUp() throws Exception {
MemberVM locator =
Expand Down Expand Up @@ -213,7 +219,7 @@ public void noCredentialCanCreateCacheWithMultiUser() throws Exception {

@Test
public void jsonFormatterOnTheClientWithSingleUser() throws Exception {
client.withProperty(SECURITY_CLIENT_AUTH_INIT, TrackableUserPasswordAuthInit.class.getName())
client.withProperty(SECURITY_CLIENT_AUTH_INIT, CountableUserPasswordAuthInit.class.getName())
.withProperty(UserPasswordAuthInit.USER_NAME, "data")
.withProperty(UserPasswordAuthInit.PASSWORD, "data")
.withMultiUser(false)
Expand All @@ -226,12 +232,35 @@ public void jsonFormatterOnTheClientWithSingleUser() throws Exception {
region.put("key", value);

// make sure the client only needs to authenticate once
assertThat(TrackableUserPasswordAuthInit.timeInitialized.get()).isEqualTo(1);
assertThat(CountableUserPasswordAuthInit.count.get()).isEqualTo(1);
}

@Test
public void multiUser_OneUserShouldOnlyAuthenticateOnceByDifferentThread() throws Exception {
ClientCache cache = client.withServerConnection(server.getPort())
.withProperty(SECURITY_CLIENT_AUTH_INIT, CountableUserPasswordAuthInit.class.getName())
.withMultiUser(true)
.createCache();
Properties properties = new Properties();
properties.setProperty(UserPasswordAuthInit.USER_NAME, "data");
properties.setProperty(UserPasswordAuthInit.PASSWORD, "data");
RegionService regionService = cache.createAuthenticatedView(properties);

cache.createClientRegionFactory(ClientRegionShortcut.PROXY).create("region");
Region region = regionService.getRegion(SEPARATOR + "region");

Future<Object> put1 = executor.submit(() -> region.put("key", "value"));
Future<Object> put2 = executor.submit(() -> region.put("key", "value"));

put1.get();
put2.get();
assertThat(CountableUserPasswordAuthInit.count.get()).isEqualTo(1);
}


@After
public void after() throws Exception {
TrackableUserPasswordAuthInit.reset();
CountableUserPasswordAuthInit.reset();
}

@Test
Expand All @@ -252,4 +281,24 @@ public void jsonFormatterOnTheClientWithMultiUser() throws Exception {
PdxInstance value = regionService.getJsonFormatter().toPdxInstance(json);
regionView.put("key", value);
}

@Test
public void multiUserWithCQ_Should_Authentiate() throws Exception {
ClientCache cache = client.withServerConnection(server.getPort())
.withPoolSubscription(true)
.withProperty(SECURITY_CLIENT_AUTH_INIT, CountableUserPasswordAuthInit.class.getName())
.withMultiUser(true)
.createCache();
Properties properties = new Properties();
properties.setProperty(UserPasswordAuthInit.USER_NAME, "data");
properties.setProperty(UserPasswordAuthInit.PASSWORD, "wrongPassword");
RegionService regionService = cache.createAuthenticatedView(properties);

cache.createClientRegionFactory(ClientRegionShortcut.PROXY).create("region");
CqQuery cqQuery =
regionService.getQueryService()
.newCq("select * from /region", new CqAttributesFactory().create());
assertThatThrownBy(() -> cqQuery.execute())
.hasCauseInstanceOf(AuthenticationFailedException.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@
import org.apache.geode.LogWriter;
import org.apache.geode.security.AuthenticationFailedException;

public class TrackableUserPasswordAuthInit extends UserPasswordAuthInit {
public static AtomicInteger timeInitialized = new AtomicInteger(0);
public class CountableUserPasswordAuthInit extends UserPasswordAuthInit {
public static AtomicInteger count = new AtomicInteger(0);

public static void reset() {
timeInitialized.set(0);
count.set(0);
}

@Override
public void init(final LogWriter systemLogWriter, final LogWriter securityLogWriter)
throws AuthenticationFailedException {
super.init(systemLogWriter, securityLogWriter);
timeInitialized.incrementAndGet();
count.incrementAndGet();
}
}

0 comments on commit 9f41e48

Please sign in to comment.