ClientSessionPersistentChangelogBasedTransaction.java

/*
 * Copyright 2024 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.models.sessions.infinispan.changes;

import org.infinispan.Cache;
import org.jboss.logging.Logger;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.session.UserSessionPersisterProvider;
import org.keycloak.models.sessions.infinispan.PersistentUserSessionProvider;
import org.keycloak.models.sessions.infinispan.SessionFunction;
import org.keycloak.models.sessions.infinispan.UserSessionAdapter;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionStore;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import org.keycloak.models.sessions.infinispan.remotestore.RemoteCacheInvoker;
import org.keycloak.models.sessions.infinispan.util.SessionTimeouts;

import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;

import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.CLIENT_SESSION_CACHE_NAME;

/**
 * Changelog-based transaction for {@link AuthenticatedClientSessionEntity} instances keyed by {@link UUID}.
 * <p>
 * On top of the behaviour inherited from {@link PersistentSessionsChangelogBasedTransaction}, this class
 * can resolve a client session that is missing from the cache by loading it through the
 * {@link UserSessionPersisterProvider} and registering it with the owning user session via the
 * {@link UserSessionPersistentChangelogBasedTransaction} supplied at construction time.
 */
public class ClientSessionPersistentChangelogBasedTransaction extends PersistentSessionsChangelogBasedTransaction<UUID, AuthenticatedClientSessionEntity> {

    private static final Logger LOG = Logger.getLogger(ClientSessionPersistentChangelogBasedTransaction.class);

    /** Transaction of the owning user sessions; receives a {@link RegisterClientSessionTask} for every imported client session. */
    private final UserSessionPersistentChangelogBasedTransaction userSessionTx;

    /**
     * Creates the transaction. All arguments except {@code userSessionTx} are handed to the superclass
     * constructor unchanged, together with the client session cache name.
     *
     * @param userSessionTx transaction used to register imported client sessions with their user session
     */
    public ClientSessionPersistentChangelogBasedTransaction(KeycloakSession session,
                                                            Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> cache,
                                                            Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> offlineCache,
                                                            RemoteCacheInvoker remoteCacheInvoker,
                                                            SessionFunction<AuthenticatedClientSessionEntity> lifespanMsLoader,
                                                            SessionFunction<AuthenticatedClientSessionEntity> maxIdleTimeMsLoader,
                                                            SessionFunction<AuthenticatedClientSessionEntity> offlineLifespanMsLoader,
                                                            SessionFunction<AuthenticatedClientSessionEntity> offlineMaxIdleTimeMsLoader,
                                                            UserSessionPersistentChangelogBasedTransaction userSessionTx,
                                                            ArrayBlockingQueue<PersistentUpdate> batchingQueue,
                                                            SerializeExecutionsByKey<UUID> serializerOnline,
                                                            SerializeExecutionsByKey<UUID> serializerOffline) {
        super(session, CLIENT_SESSION_CACHE_NAME, cache, offlineCache, remoteCacheInvoker, lifespanMsLoader, maxIdleTimeMsLoader, offlineLifespanMsLoader, offlineMaxIdleTimeMsLoader, batchingQueue, serializerOnline, serializerOffline);
        this.userSessionTx = userSessionTx;
    }

    /**
     * Looks up a client session, consulting in order: the updates already tracked by this transaction,
     * the (online or offline) cache, and finally the persister.
     *
     * @param realm       realm the caller expects the session to belong to
     * @param client      client of the session; used when loading from the persister
     * @param userSession owning user session; used when loading from the persister
     * @param key         id of the client session
     * @param offline     whether the offline or the online session is requested
     * @return the wrapped entity, or {@code null} when the session cannot be found, belongs to a different
     *         (or no longer resolvable) realm, or has a {@code REMOVE} task scheduled in this transaction
     */
    public SessionEntityWrapper<AuthenticatedClientSessionEntity> get(RealmModel realm, ClientModel client, UserSessionModel userSession, UUID key, boolean offline) {
        SessionUpdatesList<AuthenticatedClientSessionEntity> myUpdates = getUpdates(offline).get(key);
        if (myUpdates != null) {
            // If entity is scheduled for remove, we don't return it.
            return isScheduledForRemove(myUpdates) ? null : myUpdates.getEntityWrapper();
        }

        SessionEntityWrapper<AuthenticatedClientSessionEntity> wrappedEntity = null;
        Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> cache = getCache(offline);
        if (cache != null) {
            wrappedEntity = cache.get(key);
        }

        if (wrappedEntity == null) {
            LOG.debugf("client-session not found in cache for sessionId=%s, offline=%s, loading from persister", key, offline);
            wrappedEntity = getSessionEntityFromPersister(realm, client, userSession, offline);
        } else {
            LOG.debugf("client-session found in cache for sessionId=%s, offline=%s", key, offline);
        }

        if (wrappedEntity == null) {
            LOG.debugf("client-session not found in persister for sessionId=%s, offline=%s", key, offline);
            return null;
        }

        // Cache does not contain the offline flag value so adding it
        wrappedEntity.getEntity().setOffline(offline);

        // The realm lookup may yield null (e.g. the realm was removed in the meantime); treat that as a
        // mismatch instead of failing with a NullPointerException.
        RealmModel realmFromSession = kcSession.realms().getRealm(wrappedEntity.getEntity().getRealmId());
        String realmIdFromSession = realmFromSession == null ? null : realmFromSession.getId();
        if (!realm.getId().equals(realmIdFromSession)) {
            LOG.warnf("Realm mismatch for session %s. Expected realm %s, but found realm %s", wrappedEntity.getEntity(), realm.getId(), realmIdFromSession);
            return null;
        }

        // NOTE(review): when the wrapper was just imported from the persister, importClientSession has
        // already called addTask for the entity id. Confirm that replacing the entry here does not
        // discard that queued task.
        myUpdates = new SessionUpdatesList<>(realm, wrappedEntity);
        getUpdates(offline).put(key, myUpdates);

        return wrappedEntity;
    }

    /**
     * Tells whether any task queued for the given session is a {@code REMOVE} operation.
     */
    private static boolean isScheduledForRemove(SessionUpdatesList<AuthenticatedClientSessionEntity> updates) {
        return updates.getUpdateTasks().stream()
                .anyMatch(task -> task.getOperation() == SessionUpdateTask.CacheOperation.REMOVE);
    }

    /**
     * Loads the client session through the {@link UserSessionPersisterProvider} and imports it into this
     * transaction. If the persister has the session but the import is refused, the session is removed
     * from the persister.
     *
     * @return the imported wrapper, or {@code null} when the persister has no such session or the import
     *         was refused
     */
    private SessionEntityWrapper<AuthenticatedClientSessionEntity> getSessionEntityFromPersister(RealmModel realm, ClientModel client, UserSessionModel userSession, boolean offline) {
        UserSessionPersisterProvider persister = kcSession.getProvider(UserSessionPersisterProvider.class);
        AuthenticatedClientSessionModel clientSession = persister.loadClientSession(realm, client, userSession, offline);

        if (clientSession == null) {
            return null;
        }

        SessionEntityWrapper<AuthenticatedClientSessionEntity> imported = importClientSession(realm, client, userSession, clientSession);
        if (imported == null) {
            LOG.debugf("client-session not imported from persister for sessionId=%s, offline=%s, removing from persister.", clientSession.getId(), offline);
            persister.removeClientSession(userSession.getId(), client.getId(), offline);
        }

        return imported;
    }

    /**
     * Builds a new entity from a client session model. The entity id is derived from the user session id
     * and the client id via {@link PersistentUserSessionProvider#createClientSessionUUID}.
     */
    private AuthenticatedClientSessionEntity createAuthenticatedClientSessionInstance(String userSessionId, AuthenticatedClientSessionModel clientSession,
                                                                                      String realmId, String clientId) {
        UUID clientSessionId = PersistentUserSessionProvider.createClientSessionUUID(userSessionId, clientId);

        AuthenticatedClientSessionEntity entity = new AuthenticatedClientSessionEntity(clientSessionId);
        entity.setRealmId(realmId);

        entity.setAction(clientSession.getAction());
        entity.setAuthMethod(clientSession.getProtocol());

        entity.setNotes(clientSession.getNotes() == null ? new ConcurrentHashMap<>() : clientSession.getNotes());
        entity.setClientId(clientId);
        entity.setRedirectUri(clientSession.getRedirectUri());
        entity.setTimestamp(clientSession.getTimestamp());
        entity.setOffline(clientSession.getUserSession().isOffline());

        return entity;
    }

    /**
     * Imports a client session loaded from the persister: queues its creation in this transaction, links
     * it to the user session entity and queues a {@link RegisterClientSessionTask} on the user session
     * transaction.
     *
     * @return the wrapped entity, or {@code null} when either timeout loader reports the entry as expired
     * @throws IllegalStateException if the session is not expired and {@code userSession} is not a
     *                               {@link UserSessionAdapter}
     */
    private SessionEntityWrapper<AuthenticatedClientSessionEntity> importClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, AuthenticatedClientSessionModel persistentClientSession) {
        AuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionInstance(userSession.getId(), persistentClientSession,
                realm.getId(), client.getId());
        boolean offline = userSession.isOffline();

        entity.setUserSessionId(userSession.getId());

        // Update timestamp to same value as userSession. LastSessionRefresh of userSession from DB will have correct value
        entity.setTimestamp(userSession.getLastSessionRefresh());

        if (getMaxIdleMsLoader(offline).apply(realm, client, entity) == SessionTimeouts.ENTRY_EXPIRED_FLAG
                || getLifespanMsLoader(offline).apply(realm, client, entity) == SessionTimeouts.ENTRY_EXPIRED_FLAG) {
            return null;
        }

        // Validate before queueing anything so that a failure does not leave a half-registered client session behind.
        if (! (userSession instanceof UserSessionAdapter)) {
            throw new IllegalStateException("UserSessionModel must be instance of UserSessionAdapter");
        }
        UserSessionAdapter sessionToImportInto = (UserSessionAdapter) userSession;

        final UUID clientSessionId = entity.getId();

        SessionUpdateTask<AuthenticatedClientSessionEntity> createClientSessionTask = Tasks.addIfAbsentSync();
        this.addTask(clientSessionId, createClientSessionTask, entity, UserSessionModel.SessionPersistenceState.PERSISTENT);

        // Link the client session in the user session entity held by the adapter right away ...
        AuthenticatedClientSessionStore clientSessions = sessionToImportInto.getEntity().getAuthenticatedClientSessions();
        clientSessions.put(client.getId(), clientSessionId);

        // ... and queue the same change on the user session transaction.
        userSessionTx.addTask(sessionToImportInto.getId(), new RegisterClientSessionTask(client.getId(), clientSessionId, offline));

        return new SessionEntityWrapper<>(entity);
    }

    /**
     * Update task for a user session entity that records a client session id under the client's UUID in
     * the entity's {@link AuthenticatedClientSessionStore}.
     */
    public static class RegisterClientSessionTask implements PersistentSessionUpdateTask<UserSessionEntity> {

        private final String clientUuid;
        private final UUID clientSessionId;
        private final boolean offline;

        /**
         * @param clientUuid      key under which the client session is stored in the user session
         * @param clientSessionId id of the client session to register
         * @param offline         value reported by {@link #isOffline()}
         */
        public RegisterClientSessionTask(String clientUuid, UUID clientSessionId, boolean offline) {
            this.clientUuid = clientUuid;
            this.clientSessionId = clientSessionId;
            this.offline = offline;
        }

        /** Stores the client session id under the client UUID in the given user session entity. */
        @Override
        public void runUpdate(UserSessionEntity session) {
            AuthenticatedClientSessionStore clientSessions = session.getAuthenticatedClientSessions();
            clientSessions.put(clientUuid, clientSessionId);
        }

        /** Always {@code REPLACE}. */
        @Override
        public CacheOperation getOperation() {
            return CacheOperation.REPLACE;
        }

        /** Always {@code SYNC}, regardless of the wrapper passed in. */
        @Override
        public CrossDCMessageStatus getCrossDCMessageStatus(SessionEntityWrapper<UserSessionEntity> sessionWrapper) {
            return CrossDCMessageStatus.SYNC;
        }

        @Override
        public boolean isOffline() {
            return offline;
        }
    }
}