RemoteUserSessionProvider.java
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.sessions.infinispan.remote;
import java.lang.invoke.MethodHandles;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import io.reactivex.rxjava3.core.Flowable;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.commons.util.concurrent.CompletionStages;
import org.jboss.logging.Logger;
import org.keycloak.cluster.ClusterProvider;
import org.keycloak.common.Profile;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.light.LightweightUserAdapter;
import org.keycloak.models.session.UserSessionPersisterProvider;
import org.keycloak.models.sessions.infinispan.changes.remote.updater.BaseUpdater;
import org.keycloak.models.sessions.infinispan.changes.remote.updater.client.AuthenticatedClientSessionUpdater;
import org.keycloak.models.sessions.infinispan.changes.remote.updater.user.UserSessionUpdater;
import org.keycloak.models.sessions.infinispan.entities.ClientSessionKey;
import org.keycloak.models.sessions.infinispan.entities.RemoteAuthenticatedClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.RemoteUserSessionEntity;
import org.keycloak.models.sessions.infinispan.query.ClientSessionQueries;
import org.keycloak.models.sessions.infinispan.query.QueryHelper;
import org.keycloak.models.sessions.infinispan.query.UserSessionQueries;
import org.keycloak.models.sessions.infinispan.remote.transaction.ClientSessionChangeLogTransaction;
import org.keycloak.models.sessions.infinispan.remote.transaction.UserSessionChangeLogTransaction;
import org.keycloak.models.sessions.infinispan.remote.transaction.UserSessionTransaction;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.utils.StreamsUtil;
import static org.keycloak.models.Constants.SESSION_NOTE_LIGHTWEIGHT_USER;
/**
* An {@link UserSessionProvider} implementation that uses only {@link RemoteCache} as storage.
*/
public class RemoteUserSessionProvider implements UserSessionProvider {
private static final Logger log = Logger.getLogger(MethodHandles.lookup().lookupClass());
private static final int MAX_CONCURRENT_REQUESTS = 16;
private final KeycloakSession session;
private final UserSessionTransaction transaction;
private final int batchSize;
public RemoteUserSessionProvider(KeycloakSession session, UserSessionTransaction transaction, int batchSize) {
this.session = session;
this.transaction = transaction;
this.batchSize = batchSize;
}
@Override
public AuthenticatedClientSessionModel createClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession) {
var clientTx = getClientSessionTransaction(false);
var key = new ClientSessionKey(userSession.getId(), client.getId());
var entity = RemoteAuthenticatedClientSessionEntity.create(key, realm.getId(), userSession);
var model = clientTx.create(key, entity);
if (!model.isInitialized()) {
model.initialize(userSession, client, clientTx);
}
return model;
}
@Override
public AuthenticatedClientSessionModel getClientSession(UserSessionModel userSession, ClientModel client, String clientSessionId, boolean offline) {
if (clientSessionId == null) {
return null;
}
var clientTx = getClientSessionTransaction(offline);
var updater = clientTx.get(new ClientSessionKey(userSession.getId(), client.getId()));
if (updater == null) {
return null;
}
if (!updater.isInitialized()) {
updater.initialize(userSession, client, clientTx);
}
return updater;
}
@Override
public UserSessionModel createUserSession(String id, RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe, String brokerSessionId, String brokerUserId, UserSessionModel.SessionPersistenceState persistenceState) {
if (id == null) {
id = KeycloakModelUtils.generateId();
}
var entity = RemoteUserSessionEntity.create(id, realm, user, loginUsername, ipAddress, authMethod, rememberMe, brokerSessionId, brokerUserId);
var updater = getUserSessionTransaction(false).create(id, entity);
return initUserSessionUpdater(updater, persistenceState, realm, user, false);
}
@Override
public UserSessionModel getUserSession(RealmModel realm, String id) {
return getUserSession(realm, id, false);
}
@Override
public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, UserModel user) {
return StreamsUtil.closing(streamUserSessionByUserId(realm, user, false));
}
@Override
public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client) {
return StreamsUtil.closing(streamUserSessionByClientId(realm, client.getId(), false, null, null));
}
@Override
public Stream<UserSessionModel> getUserSessionsStream(RealmModel realm, ClientModel client, Integer firstResult, Integer maxResults) {
return StreamsUtil.closing(streamUserSessionByClientId(realm, client.getId(), false, firstResult, maxResults));
}
@Override
public Stream<UserSessionModel> getUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId) {
return StreamsUtil.closing(streamUserSessionByBrokerUserId(realm, brokerUserId, false));
}
@Override
public UserSessionModel getUserSessionByBrokerSessionId(RealmModel realm, String brokerSessionId) {
var userTx = getUserSessionTransaction(false);
var query = UserSessionQueries.searchByBrokerSessionId(userTx.getCache(), realm.getId(), brokerSessionId);
return QueryHelper.fetchSingle(query, userTx::wrapFromProjection)
.map(session -> initUserSessionFromQuery(session, realm, null, false))
.orElse(null);
}
@Override
public UserSessionModel getUserSessionWithPredicate(RealmModel realm, String id, boolean offline, Predicate<UserSessionModel> predicate) {
var updater = getUserSession(realm, id, offline);
return updater != null && predicate.test(updater) ? updater : null;
}
@Override
public long getActiveUserSessions(RealmModel realm, ClientModel client) {
return computeUserSessionCount(realm, client, false);
}
@Override
public Map<String, Long> getActiveClientSessionStats(RealmModel realm, boolean offline) {
var query = ClientSessionQueries.activeClientCount(getClientSessionTransaction(offline).getCache());
return QueryHelper.streamAll(query, batchSize, QueryHelper.PROJECTION_TO_STRING_LONG_ENTRY)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
@Override
public void removeUserSession(RealmModel realm, UserSessionModel userSession) {
internalRemoveUserSession(userSession, false);
}
@Override
public void removeUserSessions(RealmModel realm, UserModel user) {
transaction.removeAllSessionByUserId(realm.getId(), user.getId());
}
@Override
public void removeAllExpired() {
//rely on Infinispan expiration
}
@Override
public void removeExpired(RealmModel realm) {
//rely on Infinispan expiration
}
@Override
public void removeUserSessions(RealmModel realm) {
transaction.removeOnlineSessionsByRealmId(realm.getId());
}
@Override
public void onRealmRemoved(RealmModel realm) {
transaction.removeAllSessionsByRealmId(realm.getId());
var database = session.getProvider(UserSessionPersisterProvider.class);
if (database != null) {
database.onRealmRemoved(realm);
}
}
@Override
public void onClientRemoved(RealmModel realm, ClientModel client) {
var database = session.getProvider(UserSessionPersisterProvider.class);
if (database != null) {
database.onClientRemoved(realm, client);
}
}
@Override
public UserSessionModel createOfflineUserSession(UserSessionModel userSession) {
var entity = RemoteUserSessionEntity.createFromModel(userSession);
var updater = getUserSessionTransaction(true).create(userSession.getId(), entity);
return initUserSessionUpdater(updater, userSession.getPersistenceState(), userSession.getRealm(), userSession.getUser(), true);
}
@Override
public UserSessionModel getOfflineUserSession(RealmModel realm, String userSessionId) {
return getUserSession(realm, userSessionId, true);
}
@Override
public void removeOfflineUserSession(RealmModel realm, UserSessionModel userSession) {
internalRemoveUserSession(userSession, true);
}
@Override
public AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession) {
var clientTx = getClientSessionTransaction(true);
var key = new ClientSessionKey(offlineUserSession.getId(), clientSession.getClient().getId());
var entity = RemoteAuthenticatedClientSessionEntity.createFromModel(key, clientSession);
var model = clientTx.create(key, entity);
if (!model.isInitialized()) {
model.initialize(offlineUserSession, clientSession.getClient(), clientTx);
}
return model;
}
@Override
public Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, UserModel user) {
return StreamsUtil.closing(streamUserSessionByUserId(realm, user, true));
}
@Override
public Stream<UserSessionModel> getOfflineUserSessionByBrokerUserIdStream(RealmModel realm, String brokerUserId) {
return StreamsUtil.closing(streamUserSessionByBrokerUserId(realm, brokerUserId, true));
}
@Override
public long getOfflineSessionsCount(RealmModel realm, ClientModel client) {
return computeUserSessionCount(realm, client, true);
}
@Override
public Stream<UserSessionModel> getOfflineUserSessionsStream(RealmModel realm, ClientModel client, Integer firstResult, Integer maxResults) {
return StreamsUtil.closing(streamUserSessionByClientId(realm, client.getId(), true, firstResult, maxResults));
}
@Override
public int getStartupTime(RealmModel realm) {
return session.getProvider(ClusterProvider.class).getClusterStartupTime();
}
@Override
public KeycloakSession getKeycloakSession() {
return session;
}
@Override
public void close() {
}
@Override
public void migrate(String modelVersion) {
if ("25.0.0".equals(modelVersion)) {
migrateUserSessions(true);
migrateUserSessions(false);
}
}
private void migrateUserSessions(boolean offline) {
log.info("Migrate user sessions from database to the remote cache");
List<String> userSessionIds = Collections.synchronizedList(new ArrayList<>(batchSize));
List<Map.Entry<String, String>> clientSessionIds = Collections.synchronizedList(new ArrayList<>(batchSize));
boolean hasSessions;
do {
hasSessions = migrateUserSessionBatch(session.getKeycloakSessionFactory(), offline, userSessionIds, clientSessionIds);
} while (hasSessions);
log.info("All sessions migrated.");
}
private boolean migrateUserSessionBatch(KeycloakSessionFactory factory, boolean offline, List<String> userSessionBuffer, List<Map.Entry<String, String>> clientSessionBuffer) {
var userSessionCache = getUserSessionTransaction(offline).getCache();
var clientSessionCache = getClientSessionTransaction(offline).getCache();
log.infof("Migrating %s user(s) session(s) from database.", batchSize);
return KeycloakModelUtils.runJobInTransactionWithResult(factory, kcSession -> {
var database = kcSession.getProvider(UserSessionPersisterProvider.class);
var stage = CompletionStages.aggregateCompletionStage();
database.loadUserSessionsStream(-1, batchSize, offline, "")
.forEach(userSessionModel -> {
var userSessionEntity = RemoteUserSessionEntity.createFromModel(userSessionModel);
stage.dependsOn(userSessionCache.putIfAbsentAsync(userSessionModel.getId(), userSessionEntity));
userSessionBuffer.add(userSessionModel.getId());
for (var clientSessionModel : userSessionModel.getAuthenticatedClientSessions().values()) {
var clientSessionKey = new ClientSessionKey(userSessionModel.getId(), clientSessionModel.getClient().getId());
clientSessionBuffer.add(Map.entry(userSessionModel.getId(), clientSessionModel.getId()));
var clientSessionEntity = RemoteAuthenticatedClientSessionEntity.createFromModel(clientSessionKey, clientSessionModel);
stage.dependsOn(clientSessionCache.putIfAbsentAsync(clientSessionKey, clientSessionEntity));
}
});
CompletionStages.join(stage.freeze());
if (userSessionBuffer.isEmpty() && clientSessionBuffer.isEmpty()) {
return false;
}
log.infof("%s user(s) session(s) stored in the remote cache. Removing them from database.", userSessionBuffer.size());
userSessionBuffer.forEach(s -> database.removeUserSession(s, offline));
userSessionBuffer.clear();
clientSessionBuffer.forEach(e -> database.removeClientSession(e.getKey(), e.getValue(), offline));
clientSessionBuffer.clear();
return true;
});
}
private UserSessionUpdater getUserSession(RealmModel realm, String id, boolean offline) {
if (id == null) {
return null;
}
var updater = getUserSessionTransaction(offline).get(id);
if (updater == null || !updater.getValue().getRealmId().equals(realm.getId())) {
return null;
}
if (updater.isInitialized()) {
return updater;
}
UserModel user = session.users().getUserById(realm, updater.getValue().getUserId());
return initUserSessionUpdater(updater, UserSessionModel.SessionPersistenceState.PERSISTENT, realm, user, offline);
}
private void internalRemoveUserSession(UserSessionModel userSession, boolean offline) {
transaction.removeUserSessionById(userSession.getId(), offline);
}
private UserSessionChangeLogTransaction getUserSessionTransaction(boolean offline) {
return transaction.getUserSessions(offline);
}
private ClientSessionChangeLogTransaction getClientSessionTransaction(boolean offline) {
return transaction.getClientSessions(offline);
}
private UserSessionUpdater initUserSessionFromQuery(UserSessionUpdater updater, RealmModel realm, UserModel user, boolean offline) {
assert updater != null;
assert realm != null;
if (updater.isDeleted()) {
return null;
}
if (updater.isInitialized()) {
return updater;
}
if (user == null) {
user = session.users().getUserById(realm, updater.getValue().getUserId());
}
return initUserSessionUpdater(updater, UserSessionModel.SessionPersistenceState.PERSISTENT, realm, user, offline);
}
private UserSessionUpdater initUserSessionUpdater(UserSessionUpdater updater, UserSessionModel.SessionPersistenceState persistenceState, RealmModel realm, UserModel user, boolean offline) {
if (user instanceof LightweightUserAdapter) {
updater.initialize(persistenceState, realm, user, new ClientSessionMapping(updater));
return checkExpiration(updater);
}
// copied from org.keycloak.models.sessions.infinispan.InfinispanUserSessionProvider
if (Profile.isFeatureEnabled(Profile.Feature.TRANSIENT_USERS) && updater.getNotes().containsKey(SESSION_NOTE_LIGHTWEIGHT_USER)) {
LightweightUserAdapter lua = LightweightUserAdapter.fromString(session, realm, updater.getNotes().get(SESSION_NOTE_LIGHTWEIGHT_USER));
updater.initialize(persistenceState, realm, lua, new ClientSessionMapping(updater));
lua.setUpdateHandler(lua1 -> {
if (lua == lua1) { // Ensure there is no conflicting user model, only the latest lightweight user can be used
updater.setNote(SESSION_NOTE_LIGHTWEIGHT_USER, lua1.serialize());
}
});
return checkExpiration(updater);
}
if (user == null) {
// remove orphaned user session from the cache
internalRemoveUserSession(updater, offline);
return null;
}
updater.initialize(persistenceState, realm, user, new ClientSessionMapping(updater));
return checkExpiration(updater);
}
private AuthenticatedClientSessionModel initClientSessionUpdater(AuthenticatedClientSessionUpdater updater, UserSessionUpdater userSession) {
if (updater == null || updater.isDeleted()) {
return null;
}
var client = userSession.getRealm().getClientById(updater.getKey().clientId());
if (client == null) {
updater.markDeleted();
return null;
}
if (updater.isInitialized()) {
return updater;
}
updater.initialize(userSession, client, getClientSessionTransaction(userSession.isOffline()));
return checkExpiration(updater);
}
private long computeUserSessionCount(RealmModel realm, ClientModel client, boolean offline) {
var query = ClientSessionQueries.countClientSessions(getClientSessionTransaction(offline).getCache(), realm.getId(), client.getId());
return QueryHelper.fetchSingle(query, QueryHelper.SINGLE_PROJECTION_TO_LONG).orElse(0L);
}
private Stream<UserSessionModel> streamUserSessionByUserId(RealmModel realm, UserModel user, boolean offline) {
var userTx = getUserSessionTransaction(offline);
var query = UserSessionQueries.searchByUserId(userTx.getCache(), realm.getId(), user.getId());
return QueryHelper.streamAll(query, batchSize, userTx::wrapFromProjection)
.map(session -> initUserSessionFromQuery(session, realm, user, offline))
.filter(Objects::nonNull)
.map(UserSessionModel.class::cast);
}
private Stream<UserSessionModel> streamUserSessionByBrokerUserId(RealmModel realm, String brokerUserId, boolean offline) {
var userTx = getUserSessionTransaction(offline);
var query = UserSessionQueries.searchByBrokerUserId(userTx.getCache(), realm.getId(), brokerUserId);
return QueryHelper.streamAll(query, batchSize, userTx::wrapFromProjection)
.map(session -> initUserSessionFromQuery(session, realm, null, offline))
.filter(Objects::nonNull)
.map(UserSessionModel.class::cast);
}
private Stream<UserSessionModel> streamUserSessionByClientId(RealmModel realm, String clientId, boolean offline, Integer offset, Integer maxResults) {
var userSessionIdQuery = ClientSessionQueries.fetchUserSessionIdForClientId(getClientSessionTransaction(offline).getCache(), realm.getId(), clientId);
if (offset != null) {
userSessionIdQuery.startOffset(offset);
}
userSessionIdQuery.maxResults(maxResults == null ? Integer.MAX_VALUE : maxResults);
var userSessionTx = getUserSessionTransaction(offline);
return Flowable.fromIterable(QueryHelper.toCollection(userSessionIdQuery, QueryHelper.SINGLE_PROJECTION_TO_STRING))
.flatMapMaybe(userSessionTx::maybeGet, false, MAX_CONCURRENT_REQUESTS)
.blockingStream(batchSize)
.map(session -> initUserSessionFromQuery(session, realm, null, offline))
.filter(Objects::nonNull)
.map(UserSessionModel.class::cast);
}
private static <K, V, T extends BaseUpdater<K, V>> T checkExpiration(T updater) {
var expiration = updater.computeExpiration();
if (expiration.isExpired()) {
updater.markDeleted();
return null;
}
return updater;
}
private class ClientSessionMapping extends AbstractMap<String, AuthenticatedClientSessionModel> implements Consumer<RemoteAuthenticatedClientSessionEntity> {
private final UserSessionUpdater userSession;
private boolean coldCache = true;
ClientSessionMapping(UserSessionUpdater userSession) {
this.userSession = userSession;
}
@Override
public void clear() {
getTransaction().removeByUserSessionId(getUserSessionId());
}
@Override
public AuthenticatedClientSessionModel get(Object key) {
var updater = getTransaction().get(keyForClientId(key));
return initClientSessionUpdater(updater, userSession);
}
@Override
public AuthenticatedClientSessionModel remove(Object key) {
getTransaction().remove(keyForClientId(key));
return null;
}
@Override
public boolean containsKey(Object key) {
return get(key) != null;
}
@SuppressWarnings("NullableProblems")
@Override
public Set<Entry<String, AuthenticatedClientSessionModel>> entrySet() {
if (coldCache) {
fetchAndCacheClientSessions();
coldCache = false;
}
// iterate from the locally cached data.
return getTransaction().getClientSessions()
.filter(this::isFromUserSession)
.map(this::initialize)
.filter(Objects::nonNull)
.map(RemoteUserSessionProvider::toMapEntry)
.collect(Collectors.toSet());
}
private ClientSessionKey keyForClientId(String clientId) {
return new ClientSessionKey(getUserSessionId(), clientId);
}
private ClientSessionKey keyForClientId(Object clientId) {
return keyForClientId(String.valueOf(clientId));
}
private void fetchAndCacheClientSessions() {
var query = ClientSessionQueries.fetchClientSessions(getTransaction().getCache(), getUserSessionId());
QueryHelper.streamAll(query, batchSize, Function.identity()).forEach(this);
}
@Override
public void accept(RemoteAuthenticatedClientSessionEntity entity) {
getTransaction().wrapFromProjection(entity);
}
private ClientSessionChangeLogTransaction getTransaction() {
return getClientSessionTransaction(userSession.isOffline());
}
private String getUserSessionId() {
return userSession.getKey();
}
private boolean isFromUserSession(AuthenticatedClientSessionUpdater updater) {
return Objects.equals(getUserSessionId(), updater.getValue().getUserSessionId());
}
private AuthenticatedClientSessionModel initialize(AuthenticatedClientSessionUpdater updater) {
return initClientSessionUpdater(updater, userSession);
}
}
private static Map.Entry<String, AuthenticatedClientSessionModel> toMapEntry(AuthenticatedClientSessionModel model) {
return Map.entry(model.getClient().getId(), model);
}
}