RootAuthenticationSessionUpdater.java
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.sessions.infinispan.changes.remote.updater.authsession;
import org.keycloak.common.util.Base64Url;
import org.keycloak.common.util.SecretGenerator;
import org.keycloak.common.util.Time;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.sessions.infinispan.AuthenticationSessionAdapter;
import org.keycloak.models.sessions.infinispan.SessionEntityUpdater;
import org.keycloak.models.sessions.infinispan.changes.remote.updater.BaseUpdater;
import org.keycloak.models.sessions.infinispan.changes.remote.updater.Expiration;
import org.keycloak.models.sessions.infinispan.entities.AuthenticationSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.RootAuthenticationSessionEntity;
import org.keycloak.models.sessions.infinispan.util.SessionTimeouts;
import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.sessions.RootAuthenticationSessionModel;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
public class RootAuthenticationSessionUpdater extends BaseUpdater<String, RootAuthenticationSessionEntity> implements RootAuthenticationSessionModel {
private final static Comparator<Map.Entry<String, AuthenticationSessionEntity>> TIMESTAMP_COMPARATOR =
Comparator.comparingInt(e -> e.getValue().getTimestamp());
private final List<Consumer<RootAuthenticationSessionEntity>> changes;
private RealmModel realm;
private KeycloakSession session;
private int authSessionsLimit;
private RootAuthenticationSessionUpdater(String key, RootAuthenticationSessionEntity entity, long version, UpdaterState initialState) {
super(key, entity, version, initialState);
if (entity == null) {
assert initialState == UpdaterState.DELETED;
changes = List.of();
return;
}
changes = new ArrayList<>(4);
}
public synchronized void initialize(KeycloakSession session, RealmModel realm, int authSessionsLimit) {
this.session = session;
this.realm = realm;
this.authSessionsLimit = authSessionsLimit;
}
/**
* @return {@code true} if it is already initialized.
*/
public synchronized boolean isInitialized() {
return session != null;
}
@Override
protected boolean isUnchanged() {
return changes.isEmpty();
}
public static RootAuthenticationSessionUpdater create(String key, RootAuthenticationSessionEntity entity) {
return new RootAuthenticationSessionUpdater(key, Objects.requireNonNull(entity), NO_VERSION, UpdaterState.CREATED);
}
public static RootAuthenticationSessionUpdater wrap(String key, RootAuthenticationSessionEntity value, long version) {
return new RootAuthenticationSessionUpdater(key, Objects.requireNonNull(value), version, UpdaterState.READ);
}
public static RootAuthenticationSessionUpdater delete(String key) {
return new RootAuthenticationSessionUpdater(key, null, NO_VERSION, UpdaterState.DELETED);
}
@Override
public Expiration computeExpiration() {
return new Expiration(
SessionTimeouts.getAuthSessionMaxIdleMS(realm, null, getValue()),
SessionTimeouts.getAuthSessionLifespanMS(realm, null, getValue()));
}
@Override
public RootAuthenticationSessionEntity apply(String ignored, RootAuthenticationSessionEntity cachedEntity) {
assert !isDeleted();
assert !isReadOnly();
if (cachedEntity == null) {
//entity removed
return null;
}
changes.forEach(c -> c.accept(cachedEntity));
return cachedEntity.getAuthenticationSessions().isEmpty() ? null : cachedEntity;
}
@Override
public String getId() {
return getKey();
}
@Override
public RealmModel getRealm() {
return realm;
}
@Override
public int getTimestamp() {
return getValue().getTimestamp();
}
@Override
public void setTimestamp(int timestamp) {
addAndApplyChange(entity -> entity.setTimestamp(timestamp));
}
@Override
public Map<String, AuthenticationSessionModel> getAuthenticationSessions() {
Map<String, AuthenticationSessionModel> result = new HashMap<>();
for (Map.Entry<String, AuthenticationSessionEntity> entry : getValue().getAuthenticationSessions().entrySet()) {
String tabId = entry.getKey();
result.put(tabId, new AuthenticationSessionAdapter(session, this, new AuthenticationSessionUpdater(this, tabId, entry.getValue()), tabId));
}
return result;
}
@Override
public AuthenticationSessionModel getAuthenticationSession(ClientModel client, String tabId) {
if (client == null || tabId == null) {
return null;
}
AuthenticationSessionModel authSession = getAuthenticationSessions().get(tabId);
if (authSession != null && client.equals(authSession.getClient())) {
session.getContext().setAuthenticationSession(authSession);
return authSession;
} else {
return null;
}
}
@Override
public AuthenticationSessionModel createAuthenticationSession(ClientModel client) {
Objects.requireNonNull(client, "client");
AuthenticationSessionEntity authSessionEntity = new AuthenticationSessionEntity();
authSessionEntity.setClientUUID(client.getId());
String newTabId = Base64Url.encode(SecretGenerator.getInstance().randomBytes(8));
int timestamp = Time.currentTime();
addAndApplyChange(entity -> {
Map<String, AuthenticationSessionEntity> authenticationSessions = entity.getAuthenticationSessions();
if (authenticationSessions.size() >= authSessionsLimit && !authenticationSessions.containsKey(newTabId)) {
authenticationSessions.entrySet().stream()
.min(TIMESTAMP_COMPARATOR)
.map(Map.Entry::getKey)
.ifPresent(authenticationSessions::remove);
}
authSessionEntity.setTimestamp(timestamp);
authenticationSessions.put(newTabId, authSessionEntity);
// Update our timestamp when adding new authenticationSession
entity.setTimestamp(timestamp);
});
AuthenticationSessionAdapter authSession = new AuthenticationSessionAdapter(session, this, new AuthenticationSessionUpdater(this, newTabId, authSessionEntity), newTabId);
session.getContext().setAuthenticationSession(authSession);
return authSession;
}
@Override
public void removeAuthenticationSessionByTabId(String tabId) {
if (getValue().getAuthenticationSessions().remove(tabId) != null) {
if (getValue().getAuthenticationSessions().isEmpty()) {
markDeleted();
} else {
int currentTime = Time.currentTime();
addAndApplyChange(entity -> {
entity.getAuthenticationSessions().remove(tabId);
entity.setTimestamp(currentTime);
});
}
}
}
@Override
public void restartSession(RealmModel realm) {
addAndApplyChange(entity -> {
entity.getAuthenticationSessions().clear();
entity.setTimestamp(Time.currentTime());
});
}
private void addAndApplyChange(Consumer<RootAuthenticationSessionEntity> change) {
changes.add(change);
change.accept(getValue());
}
private record AuthenticationSessionUpdater(RootAuthenticationSessionUpdater updater, String tabId, AuthenticationSessionEntity authenticationSession) implements SessionEntityUpdater<AuthenticationSessionEntity> {
@Override
public AuthenticationSessionEntity getEntity() {
return authenticationSession;
}
@Override
public void onEntityUpdated() {
updater.addAndApplyChange(entity-> {
entity.getAuthenticationSessions().put(tabId, authenticationSession);
});
}
@Override
public void onEntityRemoved() {
}
}
}