JpaChangesPerformer.java

/*
 * Copyright 2024 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.models.sessions.infinispan.changes;

import org.infinispan.util.function.TriConsumer;
import org.jboss.logging.Logger;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.delegate.ClientModelLazyDelegate;
import org.keycloak.models.session.PersistentAuthenticatedClientSessionAdapter;
import org.keycloak.models.session.PersistentUserSessionAdapter;
import org.keycloak.models.session.UserSessionPersisterProvider;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionStore;
import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import org.keycloak.models.utils.RealmModelDelegate;
import org.keycloak.models.utils.UserModelDelegate;
import org.keycloak.models.utils.UserSessionModelDelegate;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;

import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.CLIENT_SESSION_CACHE_NAME;
import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.OFFLINE_CLIENT_SESSION_CACHE_NAME;
import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.OFFLINE_USER_SESSION_CACHE_NAME;
import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.USER_SESSION_CACHE_NAME;

public class JpaChangesPerformer<K, V extends SessionEntity> implements SessionChangesPerformer<K, V> {
    private static final Logger LOG = Logger.getLogger(JpaChangesPerformer.class);

    private final String cacheName;
    private final List<PersistentUpdate> changes = new LinkedList<>();
    private final TriConsumer<KeycloakSession, Map.Entry<K, SessionUpdatesList<V>>, MergedUpdate<V>> processor;
    private final ArrayBlockingQueue<PersistentUpdate> batchingQueue;

    /**
     * Creates a performer that persists the session changes of a single cache.
     *
     * @param cacheName     name of the cache whose changes are handled; it selects the update processor
     * @param batchingQueue queue that receives the registered updates when {@link #applyChanges()} runs
     * @throws IllegalStateException if {@code cacheName} is none of the user/client session cache names
     */
    public JpaChangesPerformer(String cacheName, ArrayBlockingQueue<PersistentUpdate> batchingQueue) {
        this.batchingQueue = batchingQueue;
        this.cacheName = cacheName;
        // processor() switches on this.cacheName, so it must run after the assignment above.
        this.processor = processor();
    }

    /**
     * Remembers a change so that it can be persisted later by {@link #applyChanges()} or
     * {@link #applyChangesSynchronously(KeycloakSession)}. Nothing is written at registration time.
     *
     * @param entry  cache key together with the list of updates recorded for that key
     * @param merged the merged view of those updates
     */
    @Override
    public void registerChange(Map.Entry<K, SessionUpdatesList<V>> entry, MergedUpdate<V> merged) {
        // The processor runs only when the update is performed, with whatever session it is given then.
        PersistentUpdate pendingUpdate =
                new PersistentUpdate(keycloakSession -> processor.accept(keycloakSession, entry, merged));
        changes.add(pendingUpdate);
    }

    /**
     * Picks the method that persists updates for the cache this performer was created for.
     *
     * @return the user-session processor for the (offline) user session caches, the client-session processor
     *         for the (offline) client session caches
     * @throws IllegalStateException if the cache name matches none of the known session caches
     */
    private TriConsumer<KeycloakSession, Map.Entry<K, SessionUpdatesList<V>>, MergedUpdate<V>> processor() {
        final TriConsumer<KeycloakSession, Map.Entry<K, SessionUpdatesList<V>>, MergedUpdate<V>> handler =
                switch (cacheName) {
                    case USER_SESSION_CACHE_NAME, OFFLINE_USER_SESSION_CACHE_NAME ->
                            this::processUserSessionUpdate;
                    case CLIENT_SESSION_CACHE_NAME, OFFLINE_CLIENT_SESSION_CACHE_NAME ->
                            this::processClientSessionUpdate;
                    default -> throw new IllegalStateException("Unexpected value: " + cacheName);
                };
        return handler;
    }

    // Flips to true the first time the queue is found full, so this instance logs the warning only once.
    private boolean warningShown = false;

    /**
     * Hands an update over to the batching queue, waiting for a free slot when the queue is full.
     *
     * @param update the update to enqueue
     * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is interrupted while
     *                          waiting; the interrupt flag is restored before throwing
     */
    private void offer(PersistentUpdate update) {
        boolean accepted = batchingQueue.offer(update);
        if (accepted) {
            return;
        }

        if (!warningShown) {
            warningShown = true;
            LOG.warn("Queue is full, will block");
        }

        try {
            // put() waits until the queue has room for the element
            batchingQueue.put(update);
        } catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(interrupted);
        }
    }

    /**
     * Enqueues every registered change on the batching queue and waits until all of them have completed.
     * <p>
     * On success the list of registered changes is cleared. If at least one change failed, the list is left
     * untouched and a {@link RuntimeException} is thrown that carries every failure as a suppressed exception.
     *
     * @throws RuntimeException if any of the updates completed exceptionally, or if enqueueing was interrupted
     */
    @Override
    public void applyChanges() {
        if (changes.isEmpty()) {
            return;
        }

        changes.forEach(this::offer);

        // A failure callback runs on the thread that completes the future, or on this thread when the future has
        // already failed by the time the callback is attached. Adds can therefore happen concurrently, which a
        // plain ArrayList does not tolerate.
        List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<>());
        CompletableFuture.allOf(changes.stream().map(f -> f.future().exceptionally(throwable -> {
            exceptions.add(throwable);
            return null;
        })).toArray(CompletableFuture[]::new)).join();

        // If any of those futures has failed, add the exceptions as suppressed exceptions to our runtime exception
        if (!exceptions.isEmpty()) {
            RuntimeException ex = new RuntimeException("unable to complete the session updates");
            exceptions.forEach(ex::addSuppressed);
            throw ex;
        }
        changes.clear();
    }

    /**
     * Performs every registered change right away on the calling thread, bypassing the batching queue, and
     * clears the list of registered changes afterwards.
     *
     * @param session the session each update is performed with
     */
    public void applyChangesSynchronously(KeycloakSession session) {
        if (changes.isEmpty()) {
            return;
        }
        for (PersistentUpdate pending : changes) {
            pending.perform(session);
        }
        changes.clear();
    }

    /**
     * Persists one merged client-session change through the {@link UserSessionPersisterProvider} of the given
     * session.
     * <ul>
     *     <li>{@code REMOVE}: deletes the persisted client session.</li>
     *     <li>{@code ADD} / {@code ADD_IF_ABSENT}: creates the persisted client session from a read-only model
     *     view of the cached entity (its mutators throw {@link IllegalStateException}).</li>
     *     <li>any other operation: loads the persisted client session and, if it exists, replays the recorded
     *     update tasks against an entity facade whose accessors forward to the loaded model.</li>
     * </ul>
     *
     * @param innerSession session used to look up the persister and, lazily, the realm
     * @param entry        cache key together with the updates recorded for it
     * @param merged       merged view of the updates; its operation selects the branch above
     */
    private void processClientSessionUpdate(KeycloakSession innerSession, Map.Entry<K, SessionUpdatesList<V>> entry, MergedUpdate<V> merged) {
        SessionUpdatesList<V> sessionUpdates = entry.getValue();
        SessionEntityWrapper<V> sessionWrapper = sessionUpdates.getEntityWrapper();
        RealmModel realm = sessionUpdates.getRealm();
        UserSessionPersisterProvider userSessionPersister = innerSession.getProvider(UserSessionPersisterProvider.class);
        // Every branch below works on the cached entity, so resolve it once.
        AuthenticatedClientSessionEntity entity = (AuthenticatedClientSessionEntity) sessionWrapper.getEntity();

        if (merged.getOperation() == SessionUpdateTask.CacheOperation.REMOVE) {
            userSessionPersister.removeClientSession(entity.getUserSessionId(), entity.getClientId(), entity.isOffline());
        } else if (merged.getOperation() == SessionUpdateTask.CacheOperation.ADD || merged.getOperation() == SessionUpdateTask.CacheOperation.ADD_IF_ABSENT) {
            // Read-only adapter exposing the cached entity as a model for the persister.
            userSessionPersister.createClientSession(new AuthenticatedClientSessionModel() {
                @Override
                public int getStarted() {
                    return entity.getStarted();
                }

                @Override
                public int getUserSessionStarted() {
                    return entity.getUserSessionStarted();
                }

                @Override
                public boolean isUserSessionRememberMe() {
                    return entity.isUserSessionRememberMe();
                }

                @Override
                public String getId() {
                    return entity.getId().toString();
                }

                @Override
                public int getTimestamp() {
                    return entity.getTimestamp();
                }

                @Override
                public void setTimestamp(int timestamp) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public void detachFromUserSession() {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public UserSessionModel getUserSession() {
                    // Only the ID is available; there is no delegate behind the other accessors.
                    return new UserSessionModelDelegate(null) {
                        @Override
                        public String getId() {
                            return entity.getUserSessionId();
                        }
                    };
                }

                @Override
                public String getNote(String name) {
                    return entity.getNotes().get(name);
                }

                @Override
                public void setNote(String name, String value) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public void removeNote(String name) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public Map<String, String> getNotes() {
                    return entity.getNotes();
                }

                @Override
                public String getRedirectUri() {
                    return entity.getRedirectUri();
                }

                @Override
                public void setRedirectUri(String uri) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public RealmModel getRealm() {
                    return innerSession.realms().getRealm(entity.getRealmId());
                }

                @Override
                public ClientModel getClient() {
                    // Only the ID is available; the lazy delegate resolves to null.
                    return new ClientModelLazyDelegate(() -> null) {
                        @Override
                        public String getId() {
                            return entity.getClientId();
                        }
                    };
                }

                @Override
                public String getAction() {
                    return entity.getAction();
                }

                @Override
                public void setAction(String action) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public String getProtocol() {
                    return entity.getAuthMethod();
                }

                @Override
                public void setProtocol(String method) {
                    throw new IllegalStateException("not implemented");
                }
            }, entity.isOffline());
        } else {
            // ID-only stand-ins used to look up the persisted client session.
            ClientModel client = new ClientModelLazyDelegate(null) {
                @Override
                public String getId() {
                    return entity.getClientId();
                }
            };
            UserSessionModel userSession = new UserSessionModelDelegate(null) {
                @Override
                public String getId() {
                    return entity.getUserSessionId();
                }
            };
            PersistentAuthenticatedClientSessionAdapter clientSessionModel = (PersistentAuthenticatedClientSessionAdapter) userSessionPersister.loadClientSession(realm, client, userSession, entity.isOffline());
            if (clientSessionModel != null) {
                // Entity facade: update tasks are written against the cache entity type, so they are replayed on
                // this object, which forwards reads and writes to the loaded persistent model.
                AuthenticatedClientSessionEntity authenticatedClientSessionEntity = new AuthenticatedClientSessionEntity(entity.getId()) {
                    @Override
                    public Map<String, String> getNotes() {
                        // The HashMap itself stays empty; only the overridden operations reach the model.
                        return new HashMap<>() {
                            @Override
                            public String get(Object key) {
                                return clientSessionModel.getNotes().get(key);
                            }

                            @Override
                            public String put(String key, String value) {
                                String oldValue = clientSessionModel.getNotes().get(key);
                                clientSessionModel.setNote(key, value);
                                return oldValue;
                            }

                            // Without this override a removal would only touch the empty backing map and the
                            // note would stay in the persisted model (the user-session facade forwards it too).
                            @Override
                            public String remove(Object key) {
                                String oldValue = clientSessionModel.getNotes().get(key);
                                clientSessionModel.removeNote(key.toString());
                                return oldValue;
                            }

                            // Same reasoning as remove(): forward the operation to the model.
                            @Override
                            public void clear() {
                                List<String> noteNames = new ArrayList<>(clientSessionModel.getNotes().keySet());
                                noteNames.forEach(clientSessionModel::removeNote);
                            }
                        };
                    }

                    @Override
                    public void setRedirectUri(String redirectUri) {
                        clientSessionModel.setRedirectUri(redirectUri);
                    }

                    @Override
                    public void setTimestamp(int timestamp) {
                        clientSessionModel.setTimestamp(timestamp);
                    }

                    @Override
                    public void setAction(String action) {
                        clientSessionModel.setAction(action);
                    }

                    @Override
                    public void setAuthMethod(String authMethod) {
                        clientSessionModel.setProtocol(authMethod);
                    }

                    @Override
                    public String getAuthMethod() {
                        throw new IllegalStateException("not implemented");
                    }

                    @Override
                    public String getRedirectUri() {
                        return clientSessionModel.getRedirectUri();
                    }

                    @Override
                    public int getTimestamp() {
                        return clientSessionModel.getTimestamp();
                    }

                    @Override
                    public int getUserSessionStarted() {
                        return clientSessionModel.getUserSessionStarted();
                    }

                    @Override
                    public int getStarted() {
                        return clientSessionModel.getStarted();
                    }

                    @Override
                    public boolean isUserSessionRememberMe() {
                        return clientSessionModel.isUserSessionRememberMe();
                    }

                    @Override
                    public String getClientId() {
                        return clientSessionModel.getClient().getClientId();
                    }

                    @Override
                    public void setClientId(String clientId) {
                        throw new IllegalStateException("not implemented");
                    }

                    @Override
                    public String getAction() {
                        return clientSessionModel.getAction();
                    }

                    @Override
                    public void setNotes(Map<String, String> notes) {
                        // Remove via a snapshot of the names: if getNotes() hands out a live view, removing
                        // while iterating its key set could fail with ConcurrentModificationException.
                        List<String> noteNames = new ArrayList<>(clientSessionModel.getNotes().keySet());
                        noteNames.forEach(clientSessionModel::removeNote);
                        notes.forEach((k, v) -> clientSessionModel.setNote(k, v));
                    }

                    @Override
                    public UUID getId() {
                        return UUID.fromString(clientSessionModel.getId());
                    }

                    @Override
                    public SessionEntityWrapper mergeRemoteEntityWithLocalEntity(SessionEntityWrapper localEntityWrapper) {
                        throw new IllegalStateException("not implemented");
                    }

                    @Override
                    public String getUserSessionId() {
                        return clientSessionModel.getUserSession().getId();
                    }

                    @Override
                    public void setUserSessionId(String userSessionId) {
                        throw new IllegalStateException("not implemented");
                    }
                };
                sessionUpdates.getUpdateTasks().forEach(vSessionUpdateTask -> {
                    vSessionUpdateTask.runUpdate((V) authenticatedClientSessionEntity);
                    // A REMOVE recorded among the individual tasks still deletes the persisted client session.
                    if (vSessionUpdateTask.getOperation() == SessionUpdateTask.CacheOperation.REMOVE) {
                        userSessionPersister.removeClientSession(entity.getUserSessionId(), entity.getClientId(), entity.isOffline());
                    }
                });
                // NOTE(review): presumably this writes the modified data back to the underlying model — confirm
                // against PersistentAuthenticatedClientSessionAdapter.
                clientSessionModel.getUpdatedModel();
            }
        }

    }

    /**
     * Persists one merged user-session change through the {@link UserSessionPersisterProvider} of the given
     * session.
     * <ul>
     *     <li>{@code REMOVE}: deletes the persisted user session identified by the entry key.</li>
     *     <li>{@code ADD} / {@code ADD_IF_ABSENT}: creates the persisted user session from a read-only model view
     *     of the cached entity (its mutators throw {@link IllegalStateException}).</li>
     *     <li>any other operation: loads the persisted user session and, if it exists, replays the recorded
     *     update tasks against an entity facade whose accessors forward to the loaded model.</li>
     * </ul>
     *
     * @param innerSession session used to look up the persister, realms and users
     * @param entry        cache key together with the updates recorded for it
     * @param merged       merged view of the updates; its operation selects the branch above
     */
    private void processUserSessionUpdate(KeycloakSession innerSession, Map.Entry<K, SessionUpdatesList<V>> entry, MergedUpdate<V> merged) {
        SessionUpdatesList<V> sessionUpdates = entry.getValue();
        SessionEntityWrapper<V> sessionWrapper = sessionUpdates.getEntityWrapper();
        RealmModel realm = sessionUpdates.getRealm();
        UserSessionPersisterProvider userSessionPersister = innerSession.getProvider(UserSessionPersisterProvider.class);
        UserSessionEntity entity = (UserSessionEntity) sessionWrapper.getEntity();

        if (merged.getOperation() == SessionUpdateTask.CacheOperation.REMOVE) {
            userSessionPersister.removeUserSession(entry.getKey().toString(), entity.isOffline());
        } else if (merged.getOperation() == SessionUpdateTask.CacheOperation.ADD || merged.getOperation() == SessionUpdateTask.CacheOperation.ADD_IF_ABSENT) {
            // Read-only adapter exposing the cached entity as a model for the persister.
            userSessionPersister.createUserSession(new UserSessionModel() {
                @Override
                public String getId() {
                    return entity.getId();
                }

                @Override
                public RealmModel getRealm() {
                    // Only the ID is available; there is no delegate behind the other accessors.
                    return new RealmModelDelegate(null) {
                        @Override
                        public String getId() {
                            return entity.getRealmId();
                        }
                    };
                }

                @Override
                public String getBrokerSessionId() {
                    return entity.getBrokerSessionId();
                }

                @Override
                public String getBrokerUserId() {
                    return entity.getBrokerUserId();
                }

                @Override
                public UserModel getUser() {
                    // Only the ID is available; there is no delegate behind the other accessors.
                    return new UserModelDelegate(null) {
                        @Override
                        public String getId() {
                            return entity.getUser();
                        }
                    };
                }

                @Override
                public String getLoginUsername() {
                    return entity.getLoginUsername();
                }

                @Override
                public String getIpAddress() {
                    return entity.getIpAddress();
                }

                @Override
                public String getAuthMethod() {
                    return entity.getAuthMethod();
                }

                @Override
                public boolean isRememberMe() {
                    return entity.isRememberMe();
                }

                @Override
                public int getStarted() {
                    return entity.getStarted();
                }

                @Override
                public int getLastSessionRefresh() {
                    return entity.getLastSessionRefresh();
                }

                @Override
                public void setLastSessionRefresh(int seconds) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public boolean isOffline() {
                    return entity.isOffline();
                }

                @Override
                public Map<String, AuthenticatedClientSessionModel> getAuthenticatedClientSessions() {
                    // This is not used when saving this to the database.
                    return Collections.emptyMap();
                }

                @Override
                public void removeAuthenticatedClientSessions(Collection<String> removedClientUUIDS) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public String getNote(String name) {
                    return entity.getNotes().get(name);
                }

                @Override
                public void setNote(String name, String value) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public void removeNote(String name) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public Map<String, String> getNotes() {
                    return entity.getNotes();
                }

                @Override
                public State getState() {
                    return entity.getState();
                }

                @Override
                public void setState(State state) {
                    throw new IllegalStateException("not implemented");
                }

                @Override
                public void restartSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId) {
                    throw new IllegalStateException("not implemented");
                }
            }, entity.isOffline());
        } else {
            PersistentUserSessionAdapter userSessionModel = (PersistentUserSessionAdapter) userSessionPersister.loadUserSession(realm, entry.getKey().toString(), entity.isOffline());
            if (userSessionModel != null) {
                // Entity facade: update tasks are written against the cache entity type, so they are replayed on
                // this object, which forwards reads and writes to the loaded persistent model.
                UserSessionEntity userSessionEntity = new UserSessionEntity(userSessionModel.getId()) {
                    @Override
                    public Map<String, String> getNotes() {
                        // The HashMap itself stays empty; only the overridden operations reach the model.
                        return new HashMap<>() {

                            @Override
                            public String get(Object key) {
                                return userSessionModel.getNotes().get(key);
                            }

                            @Override
                            public String put(String key, String value) {
                                String oldValue = userSessionModel.getNotes().get(key);
                                userSessionModel.setNote(key, value);
                                return oldValue;
                            }

                            @Override
                            public String remove(Object key) {
                                String oldValue = userSessionModel.getNotes().get(key);
                                userSessionModel.removeNote(key.toString());
                                return oldValue;
                            }

                            @Override
                            public void clear() {
                                userSessionModel.getNotes().clear();
                            }
                        };
                    }

                    @Override
                    public void setLastSessionRefresh(int lastSessionRefresh) {
                        userSessionModel.setLastSessionRefresh(lastSessionRefresh);
                    }

                    @Override
                    public void setState(UserSessionModel.State state) {
                        userSessionModel.setState(state);
                    }

                    @Override
                    public AuthenticatedClientSessionStore getAuthenticatedClientSessions() {
                        // Only clear() is forwarded to the model; the returned store is otherwise a fresh instance.
                        return new AuthenticatedClientSessionStore() {
                            @Override
                            public void clear() {
                                userSessionModel.getAuthenticatedClientSessions().clear();
                            }
                        };
                    }

                    @Override
                    public String getRealmId() {
                        return userSessionModel.getRealm().getId();
                    }

                    @Override
                    public void setRealmId(String realmId) {
                        userSessionModel.setRealm(innerSession.realms().getRealm(realmId));
                    }

                    @Override
                    public String getUser() {
                        return userSessionModel.getUser().getId();
                    }

                    @Override
                    public void setUser(String userId) {
                        userSessionModel.setUser(innerSession.users().getUserById(realm, userId));
                    }

                    @Override
                    public String getLoginUsername() {
                        return userSessionModel.getLoginUsername();
                    }

                    @Override
                    public void setLoginUsername(String loginUsername) {
                        userSessionModel.setLoginUsername(loginUsername);
                    }

                    @Override
                    public String getIpAddress() {
                        return userSessionModel.getIpAddress();
                    }

                    @Override
                    public void setIpAddress(String ipAddress) {
                        userSessionModel.setIpAddress(ipAddress);
                    }

                    @Override
                    public String getAuthMethod() {
                        return userSessionModel.getAuthMethod();
                    }

                    @Override
                    public void setAuthMethod(String authMethod) {
                        userSessionModel.setAuthMethod(authMethod);
                    }

                    @Override
                    public boolean isRememberMe() {
                        return userSessionModel.isRememberMe();
                    }

                    @Override
                    public void setRememberMe(boolean rememberMe) {
                        userSessionModel.setRememberMe(rememberMe);
                    }

                    @Override
                    public int getStarted() {
                        return userSessionModel.getStarted();
                    }

                    @Override
                    public void setStarted(int started) {
                        userSessionModel.setStarted(started);
                    }

                    @Override
                    public int getLastSessionRefresh() {
                        return userSessionModel.getLastSessionRefresh();
                    }

                    @Override
                    public void setNotes(Map<String, String> notes) {
                        // Remove via a snapshot of the names: if getNotes() hands out a live view, removing
                        // while iterating its key set could fail with ConcurrentModificationException.
                        List<String> noteNames = new ArrayList<>(userSessionModel.getNotes().keySet());
                        noteNames.forEach(userSessionModel::removeNote);
                        notes.forEach((k, v) -> userSessionModel.setNote(k, v));
                    }

                    @Override
                    public void setAuthenticatedClientSessions(AuthenticatedClientSessionStore authenticatedClientSessions) {
                        throw new IllegalStateException("not supported");
                    }

                    @Override
                    public UserSessionModel.State getState() {
                        return userSessionModel.getState();
                    }

                    @Override
                    public String getBrokerSessionId() {
                        return userSessionModel.getBrokerSessionId();
                    }

                    @Override
                    public void setBrokerSessionId(String brokerSessionId) {
                        userSessionModel.setBrokerSessionId(brokerSessionId);
                    }

                    @Override
                    public String getBrokerUserId() {
                        return userSessionModel.getBrokerUserId();
                    }

                    @Override
                    public void setBrokerUserId(String brokerUserId) {
                        userSessionModel.setBrokerUserId(brokerUserId);
                    }

                    @Override
                    public SessionEntityWrapper mergeRemoteEntityWithLocalEntity(SessionEntityWrapper localEntityWrapper) {
                        throw new IllegalStateException("not supported");
                    }
                };
                sessionUpdates.getUpdateTasks().forEach(vSessionUpdateTask -> {
                    vSessionUpdateTask.runUpdate((V) userSessionEntity);
                    // A REMOVE recorded among the individual tasks still deletes the persisted user session.
                    if (vSessionUpdateTask.getOperation() == SessionUpdateTask.CacheOperation.REMOVE) {
                        userSessionPersister.removeUserSession(entry.getKey().toString(), entity.isOffline());
                    }
                });
                // NOTE(review): presumably this writes the modified data back to the underlying model — confirm
                // against PersistentUserSessionAdapter.
                userSessionModel.getUpdatedModel();
            }

        }
    }
}