Skip to content

Commit

Permalink
Disable cache store and load only if a remote store is used
Browse files Browse the repository at this point in the history
Closes keycloak#10803
Closes keycloak#24766

Signed-off-by: Alexander Schwartz <[email protected]>
Co-authored-by: daviddelannoy <[email protected]>
  • Loading branch information
ahus1 and daviddelannoy committed Nov 20, 2023
1 parent 62d5eb0 commit a45934a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.infinispan.Cache;
import org.infinispan.context.Flag;

import static org.keycloak.connections.infinispan.InfinispanUtil.getRemoteStores;

/**
* @author <a href="mailto:[email protected]">Marek Posolda</a>
*/
Expand All @@ -40,17 +42,31 @@ public static <K, V> AdvancedCache<K, V> localCache(Cache<K, V> cache) {
* @param cache
* @return Cache with the flags applied.
*/
public static <K, V> AdvancedCache<K, V> skipCacheLoaders(Cache<K, V> cache) {
return cache.getAdvancedCache().withFlags(Flag.SKIP_CACHE_LOAD, Flag.SKIP_CACHE_STORE);
public static <K, V> AdvancedCache<K, V> skipCacheLoadersIfRemoteStoreIsEnabled(Cache<K, V> cache) {
if (!getRemoteStores(cache).isEmpty()) {
// Disabling of the cache load and cache store is only needed when a remote store is used and handled separately.
return cache.getAdvancedCache().withFlags(Flag.SKIP_CACHE_LOAD, Flag.SKIP_CACHE_STORE);
} else {
// If there is no remote store, use write through for all stores of the cache.
// Mixing remote and non-remote caches is not supported.
return cache.getAdvancedCache().withFlags(Flag.SKIP_CACHE_LOAD);
}
}

/**
* Adds {@link Flag#SKIP_CACHE_STORE} flag to the cache.
* @param cache
* @return Cache with the flags applied.
*/
public static <K, V> AdvancedCache<K, V> skipCacheStore(Cache<K, V> cache) {
return cache.getAdvancedCache().withFlags(Flag.SKIP_CACHE_STORE);
public static <K, V> AdvancedCache<K, V> skipCacheStoreIfRemoteCacheIsEnabled(Cache<K, V> cache) {
if (!getRemoteStores(cache).isEmpty()) {
// Disabling of the cache load and cache store is only needed when a remote store is used and handled separately.
return cache.getAdvancedCache().withFlags(Flag.SKIP_CACHE_STORE);
} else {
// If there is no remote store, use write through for all stores of the cache.
// Mixing remote and non-remote caches is not supported.
return cache.getAdvancedCache();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ protected void removeAllLocalUserLoginFailuresEvent(String realmId) {

Cache<LoginFailureKey, SessionEntityWrapper<LoginFailureEntity>> localCache = CacheDecorators.localCache(loginFailureCache);

Cache<LoginFailureKey, SessionEntityWrapper<LoginFailureEntity>> localCacheStoreIgnore = CacheDecorators.skipCacheLoaders(localCache);
Cache<LoginFailureKey, SessionEntityWrapper<LoginFailureEntity>> localCacheStoreIgnore = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(localCache);

localCacheStoreIgnore
.entrySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ protected Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, UserS
}

Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);
cache = CacheDecorators.skipCacheLoaders(cache);
cache = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(cache);

// return a stream that 'wraps' the infinispan cache stream so that the cache stream's elements are read one by one
// and then mapped locally to avoid serialization issues when trying to manipulate the cache stream directly.
Expand Down Expand Up @@ -564,7 +564,7 @@ public Map<String, Long> getActiveClientSessionStats(RealmModel realm, boolean o
}

Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);
cache = CacheDecorators.skipCacheLoaders(cache);
cache = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(cache);
return cache.entrySet().stream()
.filter(UserSessionPredicate.create(realm.getId()))
.map(Mappers.authClientSessionSetMapper())
Expand Down Expand Up @@ -603,7 +603,7 @@ public void removeUserSessions(RealmModel realm, UserModel user) {
protected void removeUserSessions(RealmModel realm, UserModel user, boolean offline) {
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);

cache = CacheDecorators.skipCacheLoaders(cache);
cache = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(cache);

Iterator<UserSessionEntity> itr = cache.entrySet().stream().filter(UserSessionPredicate.create(realm.getId()).user(user.getId())).map(Mappers.userSessionEntity()).iterator();

Expand Down Expand Up @@ -647,7 +647,7 @@ public void removeLocalUserSessions(String realmId, boolean offline) {
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessionCache = getClientSessionCache(offline);
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> localClientSessionCache = CacheDecorators.localCache(clientSessionCache);

Cache<String, SessionEntityWrapper<UserSessionEntity>> localCacheStoreIgnore = CacheDecorators.skipCacheLoaders(localCache);
Cache<String, SessionEntityWrapper<UserSessionEntity>> localCacheStoreIgnore = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(localCache);

final AtomicInteger userSessionsSize = new AtomicInteger();

Expand Down Expand Up @@ -906,7 +906,7 @@ public void importUserSessions(Collection<UserSessionModel> persistentUserSessio
.collect(Collectors.toMap(sessionEntityWrapper -> sessionEntityWrapper.getEntity().getId(), Function.identity()));

// Directly put all entities to the infinispan cache
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = CacheDecorators.skipCacheLoaders(getCache(offline));
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(getCache(offline));

boolean importWithExpiration = sessionsById.size() == 1;
if (importWithExpiration) {
Expand Down Expand Up @@ -951,7 +951,7 @@ public void importUserSessions(Collection<UserSessionModel> persistentUserSessio

// Import client sessions
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessCache =
CacheDecorators.skipCacheLoaders(offline ? offlineClientSessionCache : clientSessionCache);
CacheDecorators.skipCacheLoadersIfRemoteStoreIsEnabled(offline ? offlineClientSessionCache : clientSessionCache);

if (importWithExpiration) {
importSessionsWithExpiration(clientSessionsById, clientSessCache,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.infinispan.Cache;
import org.infinispan.context.Flag;
import org.jboss.logging.Logger;
import org.keycloak.models.ClientModel;
import org.keycloak.models.AbstractKeycloakTransaction;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
Expand Down Expand Up @@ -189,19 +188,19 @@ private void runOperationInCluster(K key, MergedUpdate<V> task, SessionEntityWr
switch (operation) {
case REMOVE:
// Just remove it
CacheDecorators.skipCacheStore(cache)
.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES)
CacheDecorators.skipCacheStoreIfRemoteCacheIsEnabled(cache)
.withFlags(Flag.IGNORE_RETURN_VALUES)
.remove(key);
break;
case ADD:
CacheDecorators.skipCacheStore(cache)
.getAdvancedCache().withFlags(Flag.IGNORE_RETURN_VALUES)
CacheDecorators.skipCacheStoreIfRemoteCacheIsEnabled(cache)
.withFlags(Flag.IGNORE_RETURN_VALUES)
.put(key, sessionWrapper, task.getLifespanMs(), TimeUnit.MILLISECONDS, task.getMaxIdleTimeMs(), TimeUnit.MILLISECONDS);

logger.tracef("Added entity '%s' to the cache '%s' . Lifespan: %d ms, MaxIdle: %d ms", key, cache.getName(), task.getLifespanMs(), task.getMaxIdleTimeMs());
break;
case ADD_IF_ABSENT:
SessionEntityWrapper<V> existing = CacheDecorators.skipCacheStore(cache).putIfAbsent(key, sessionWrapper, task.getLifespanMs(), TimeUnit.MILLISECONDS, task.getMaxIdleTimeMs(), TimeUnit.MILLISECONDS);
SessionEntityWrapper<V> existing = CacheDecorators.skipCacheStoreIfRemoteCacheIsEnabled(cache).putIfAbsent(key, sessionWrapper, task.getLifespanMs(), TimeUnit.MILLISECONDS, task.getMaxIdleTimeMs(), TimeUnit.MILLISECONDS);
if (existing != null) {
logger.debugf("Existing entity in cache for key: %s . Will update it", key);

Expand Down Expand Up @@ -234,7 +233,7 @@ private void replace(K key, MergedUpdate<V> task, SessionEntityWrapper<V> oldVer
SessionEntityWrapper<V> newVersionEntity = generateNewVersionAndWrapEntity(session, oldVersionEntity.getLocalMetadata());

// Atomic cluster-aware replace
replaced = CacheDecorators.skipCacheStore(cache).replace(key, oldVersionEntity, newVersionEntity, lifespanMs, TimeUnit.MILLISECONDS, maxIdleTimeMs, TimeUnit.MILLISECONDS);
replaced = CacheDecorators.skipCacheStoreIfRemoteCacheIsEnabled(cache).replace(key, oldVersionEntity, newVersionEntity, lifespanMs, TimeUnit.MILLISECONDS, maxIdleTimeMs, TimeUnit.MILLISECONDS);

// Replace fail. Need to load latest entity from cache, apply updates again and try to replace in cache again
if (!replaced) {
Expand Down

0 comments on commit a45934a

Please sign in to comment.