PersistentSessionsChangelogBasedTransaction.java
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.sessions.infinispan.changes;
import org.infinispan.Cache;
import org.jboss.logging.Logger;
import org.keycloak.models.AbstractKeycloakTransaction;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.infinispan.SessionFunction;
import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
import org.keycloak.models.sessions.infinispan.remotestore.RemoteCacheInvoker;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.stream.Stream;
abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends SessionEntity> extends AbstractKeycloakTransaction implements SessionsChangelogBasedTransaction<K, V> {
private static final Logger LOG = Logger.getLogger(PersistentSessionsChangelogBasedTransaction.class);
protected final KeycloakSession kcSession;
protected final Map<K, SessionUpdatesList<V>> updates = new HashMap<>();
protected final Map<K, SessionUpdatesList<V>> offlineUpdates = new HashMap<>();
private final String cacheName;
private final Cache<K, SessionEntityWrapper<V>> cache;
private final Cache<K, SessionEntityWrapper<V>> offlineCache;
private final RemoteCacheInvoker remoteCacheInvoker;
private final SessionFunction<V> lifespanMsLoader;
private final SessionFunction<V> maxIdleTimeMsLoader;
private final SessionFunction<V> offlineLifespanMsLoader;
private final SessionFunction<V> offlineMaxIdleTimeMsLoader;
private final ArrayBlockingQueue<PersistentUpdate> batchingQueue;
private final SerializeExecutionsByKey<K> serializerOnline;
private final SerializeExecutionsByKey<K> serializerOffline;
public PersistentSessionsChangelogBasedTransaction(KeycloakSession session,
String cacheName,
Cache<K, SessionEntityWrapper<V>> cache,
Cache<K, SessionEntityWrapper<V>> offlineCache,
RemoteCacheInvoker remoteCacheInvoker,
SessionFunction<V> lifespanMsLoader,
SessionFunction<V> maxIdleTimeMsLoader,
SessionFunction<V> offlineLifespanMsLoader,
SessionFunction<V> offlineMaxIdleTimeMsLoader,
ArrayBlockingQueue<PersistentUpdate> batchingQueue,
SerializeExecutionsByKey<K> serializerOnline,
SerializeExecutionsByKey<K> serializerOffline) {
kcSession = session;
this.cacheName = cacheName;
this.cache = cache;
this.offlineCache = offlineCache;
this.remoteCacheInvoker = remoteCacheInvoker;
this.lifespanMsLoader = lifespanMsLoader;
this.maxIdleTimeMsLoader = maxIdleTimeMsLoader;
this.offlineLifespanMsLoader = offlineLifespanMsLoader;
this.offlineMaxIdleTimeMsLoader = offlineMaxIdleTimeMsLoader;
this.batchingQueue = batchingQueue;
this.serializerOnline = serializerOnline;
this.serializerOffline = serializerOffline;
}
protected Cache<K, SessionEntityWrapper<V>> getCache(boolean offline) {
if (offline) {
return offlineCache;
} else {
return cache;
}
}
protected SessionFunction<V> getLifespanMsLoader(boolean offline) {
if (offline) {
return offlineLifespanMsLoader;
} else {
return lifespanMsLoader;
}
}
protected SessionFunction<V> getMaxIdleMsLoader(boolean offline) {
if (offline) {
return offlineMaxIdleTimeMsLoader;
} else {
return maxIdleTimeMsLoader;
}
}
protected Map<K, SessionUpdatesList<V>> getUpdates(boolean offline) {
if (offline) {
return offlineUpdates;
} else {
return updates;
}
}
public SessionEntityWrapper<V> get(K key, boolean offline){
SessionUpdatesList<V> myUpdates = getUpdates(offline).get(key);
if (myUpdates == null) {
SessionEntityWrapper<V> wrappedEntity = getCache(offline).get(key);
if (wrappedEntity == null) {
return null;
}
wrappedEntity.getEntity().setOffline(offline);
RealmModel realm = kcSession.realms().getRealm(wrappedEntity.getEntity().getRealmId());
myUpdates = new SessionUpdatesList<>(realm, wrappedEntity);
getUpdates(offline).put(key, myUpdates);
return wrappedEntity;
} else {
V entity = myUpdates.getEntityWrapper().getEntity();
// If entity is scheduled for remove, we don't return it.
boolean scheduledForRemove = myUpdates.getUpdateTasks().stream()
.map(SessionUpdateTask::getOperation)
.anyMatch(SessionUpdateTask.CacheOperation.REMOVE::equals);
return scheduledForRemove ? null : myUpdates.getEntityWrapper();
}
}
List<SessionChangesPerformer<K, V>> prepareChangesPerformers() {
List<SessionChangesPerformer<K, V>> changesPerformers = new LinkedList<>();
if (batchingQueue != null) {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, batchingQueue));
} else {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, null) {
@Override
public void applyChanges() {
KeycloakModelUtils.runJobInTransaction(kcSession.getKeycloakSessionFactory(),
super::applyChangesSynchronously);
}
});
}
if (cache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(cache, serializerOnline) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(kcSession, cache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
}
if (offlineCache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(offlineCache, serializerOffline){
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(kcSession, offlineCache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
}
return changesPerformers;
}
@Override
protected void commitImpl() {
List<SessionChangesPerformer<K, V>> changesPerformers = null;
for (Map.Entry<K, SessionUpdatesList<V>> entry : Stream.concat(updates.entrySet().stream(), offlineUpdates.entrySet().stream()).toList()) {
SessionUpdatesList<V> sessionUpdates = entry.getValue();
SessionEntityWrapper<V> sessionWrapper = sessionUpdates.getEntityWrapper();
V entity = sessionWrapper.getEntity();
boolean isOffline = entity.isOffline();
// Don't save transient entities to infinispan. They are valid just for current transaction
if (sessionUpdates.getPersistenceState() == UserSessionModel.SessionPersistenceState.TRANSIENT) continue;
RealmModel realm = sessionUpdates.getRealm();
long lifespanMs = getLifespanMsLoader(isOffline).apply(realm, sessionUpdates.getClient(), entity);
long maxIdleTimeMs = getMaxIdleMsLoader(isOffline).apply(realm, sessionUpdates.getClient(), entity);
MergedUpdate<V> merged = MergedUpdate.computeUpdate(sessionUpdates.getUpdateTasks(), sessionWrapper, lifespanMs, maxIdleTimeMs);
if (merged != null) {
if (changesPerformers == null) {
changesPerformers = prepareChangesPerformers();
}
changesPerformers.stream()
.filter(performer -> performer.shouldConsumeChange(entity))
.forEach(p -> p.registerChange(entry, merged));
}
}
if (changesPerformers != null) {
changesPerformers.forEach(SessionChangesPerformer::applyChanges);
}
}
@Override
public void addTask(K key, SessionUpdateTask<V> originalTask) {
if (! (originalTask instanceof PersistentSessionUpdateTask)) {
throw new IllegalArgumentException("Task must be instance of PersistentSessionUpdateTask");
}
PersistentSessionUpdateTask<V> task = (PersistentSessionUpdateTask<V>) originalTask;
SessionUpdatesList<V> myUpdates = getUpdates(task.isOffline()).get(key);
if (myUpdates == null) {
// Lookup entity from cache
SessionEntityWrapper<V> wrappedEntity = getCache(task.isOffline()).get(key);
if (wrappedEntity == null) {
LOG.tracef("Not present cache item for key %s", key);
return;
}
// Cache does not contain the offline flag value so adding it
wrappedEntity.getEntity().setOffline(task.isOffline());
RealmModel realm = kcSession.realms().getRealm(wrappedEntity.getEntity().getRealmId());
myUpdates = new SessionUpdatesList<>(realm, wrappedEntity);
getUpdates(task.isOffline()).put(key, myUpdates);
}
// Run the update now, so reader in same transaction can see it (TODO: Rollback may not work correctly. See if it's an issue..)
task.runUpdate(myUpdates.getEntityWrapper().getEntity());
myUpdates.add(task);
}
public void addTask(K key, SessionUpdateTask<V> task, V entity, UserSessionModel.SessionPersistenceState persistenceState) {
if (entity == null) {
throw new IllegalArgumentException("Null entity not allowed");
}
RealmModel realm = kcSession.realms().getRealm(entity.getRealmId());
SessionEntityWrapper<V> wrappedEntity = new SessionEntityWrapper<>(entity);
SessionUpdatesList<V> myUpdates = new SessionUpdatesList<>(realm, wrappedEntity, persistenceState);
getUpdates(entity.isOffline()).put(key, myUpdates);
if (task != null) {
// Run the update now, so reader in same transaction can see it
task.runUpdate(entity);
myUpdates.add(task);
}
}
public void reloadEntityInCurrentTransaction(RealmModel realm, K key, SessionEntityWrapper<V> entity) {
if (entity == null) {
throw new IllegalArgumentException("Null entity not allowed");
}
boolean offline = entity.getEntity().isOffline();
SessionEntityWrapper<V> latestEntity = getCache(offline).get(key);
if (latestEntity == null) {
return;
}
SessionUpdatesList<V> newUpdates = new SessionUpdatesList<>(realm, latestEntity);
SessionUpdatesList<V> existingUpdates = getUpdates(entity.getEntity().isOffline()).get(key);
if (existingUpdates != null) {
newUpdates.setUpdateTasks(existingUpdates.getUpdateTasks());
}
getUpdates(entity.getEntity().isOffline()).put(key, newUpdates);
}
@Override
protected void rollbackImpl() {
}
}