UserSessionLimitsAuthenticator.java

package org.keycloak.authentication.authenticators.sessionlimits;

import java.util.Collections;
import org.jboss.logging.Logger;
import org.keycloak.authentication.AuthenticationFlowException;
import org.keycloak.authentication.Authenticator;
import org.keycloak.authentication.AuthenticationFlowContext;
import org.keycloak.authentication.AuthenticationFlowError;
import org.keycloak.models.AuthenticatorConfigModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.UserSessionModel;
import org.keycloak.representations.idm.OAuth2ErrorRepresentation;
import org.keycloak.services.managers.AuthenticationManager;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import org.keycloak.events.Errors;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.utils.StringUtil;

import static org.keycloak.utils.LockObjectsForModification.lockUserSessionsForModification;

/**
 * Authenticator that limits how many user sessions a user may hold, either across the whole realm or for the
 * client currently being authenticated.
 * <p>
 * Limits are read from the authenticator configuration on every {@link #authenticate} call. A missing, blank,
 * zero or negative limit disables that particular check. When a limit is reached, the configured
 * {@link UserSessionLimitsAuthenticatorFactory#BEHAVIOR} decides whether the new session is denied or the
 * least recently refreshed existing sessions are logged out.
 */
public class UserSessionLimitsAuthenticator implements Authenticator {

    private static final Logger logger = Logger.getLogger(UserSessionLimitsAuthenticator.class);
    public static final String SESSION_LIMIT_EXCEEDED = "sessionLimitExceeded";
    private static final String REALM_EVENT_DETAILS_TEMPLATE = "Realm session limit exceeded. Realm: %s, Realm limit: %s. Session count: %s, User id: %s";
    private static final String CLIENT_EVENT_DETAILS_TEMPLATE = "Client session limit exceeded. Realm: %s, Client limit: %s. Session count: %s, User id: %s";
    protected KeycloakSession session;

    /** Limit-exceeded behavior; re-read from the authenticator config on each {@link #authenticate} call. */
    String behavior;

    public UserSessionLimitsAuthenticator(KeycloakSession session) {
        this.session = session;
    }

    /**
     * Counts the user's existing sessions in the realm (and, if a client limit is configured, for the current
     * client). The flow succeeds when no limit is reached; otherwise the configured behavior is applied.
     *
     * @throws AuthenticationFlowException if the authenticator has no configuration, or if a limit is reached
     *                                     and the configured behavior is missing or unsupported
     */
    @Override
    public void authenticate(AuthenticationFlowContext context) {
        AuthenticatorConfigModel authenticatorConfig = context.getAuthenticatorConfig();
        if (authenticatorConfig == null) {
            throw new AuthenticationFlowException("No configuration found of 'User Session Count Limiter' authenticator. Please make sure to configure this authenticator in your authentication flow in the realm '" + context.getRealm().getName() + "'!"
                    , AuthenticationFlowError.INTERNAL_ERROR);
        }
        Map<String, String> config = authenticatorConfig.getConfig();

        // Get the configuration for this authenticator
        behavior = config.get(UserSessionLimitsAuthenticatorFactory.BEHAVIOR);
        int userRealmLimit = getIntConfigProperty(UserSessionLimitsAuthenticatorFactory.USER_REALM_LIMIT, config);
        int userClientLimit = getIntConfigProperty(UserSessionLimitsAuthenticatorFactory.USER_CLIENT_LIMIT, config);

        RealmModel realm = context.getRealm();
        UserModel user = context.getUser();
        if (realm == null || user == null) {
            // Nothing to count against (requiresUser() is false, so the user may not be identified yet).
            context.success();
            return;
        }

        // Get the sessions in this realm for this specific user (locked for modification, since some of them may be logged out below)
        List<UserSessionModel> userSessionsForRealm = lockUserSessionsForModification(session,
                () -> session.sessions().getUserSessionsStream(realm, user).collect(Collectors.toList()));
        int userSessionCountForRealm = userSessionsForRealm.size();

        // Get the sessions related to the current client for this user
        ClientModel currentClient = context.getAuthenticationSession().getClient();
        logger.debugf("session-limiter's current keycloak clientId: %s", currentClient.getClientId());

        List<UserSessionModel> userSessionsForClient = getUserSessionsForClientIfEnabled(userSessionsForRealm, currentClient, userClientLimit);
        int userSessionCountForClient = userSessionsForClient.size();
        logger.debugf("session-limiter's configured realm session limit: %s", userRealmLimit);
        logger.debugf("session-limiter's configured client session limit: %s", userClientLimit);
        logger.debugf("session-limiter's count of total user sessions for the entire realm (could be apps other than web apps): %s", userSessionCountForRealm);
        logger.debugf("session-limiter's count of total user sessions for this keycloak client: %s", userSessionCountForClient);

        // First check if the user has too many sessions in this realm
        if (exceedsLimit(userSessionCountForRealm, userRealmLimit)) {
            logger.infof("Too many sessions in this realm for the current user. Session count: %s", userSessionCountForRealm);
            String eventDetails = String.format(REALM_EVENT_DETAILS_TEMPLATE, realm.getName(), userRealmLimit, userSessionCountForRealm, user.getId());
            handleLimitExceeded(context, userSessionsForRealm, eventDetails, userRealmLimit);
        } else if (exceedsLimit(userSessionCountForClient, userClientLimit)) {
            // The user is still allowed a new session in the realm, but not for this specific client.
            // Log the client count: that is the number which triggered this branch.
            logger.infof("Too many sessions related to the current client for this user. Session count: %s", userSessionCountForClient);
            String eventDetails = String.format(CLIENT_EVENT_DETAILS_TEMPLATE, realm.getName(), userClientLimit, userSessionCountForClient, user.getId());
            handleLimitExceeded(context, userSessionsForClient, eventDetails, userClientLimit);
        } else {
            context.success();
        }
    }

    /**
     * Returns whether {@code count} existing sessions reach {@code limit}.
     * <p>
     * A non-positive limit means the check is disabled. Otherwise the limit counts as exceeded once
     * {@code count >= limit}: one slot is reserved, presumably for the session this login is about to create.
     */
    private boolean exceedsLimit(long count, long limit) {
        if (limit <= 0) { // if limit is zero or negative, consider the limit disabled
            return false;
        }
        return getNumberOfSessionsThatNeedToBeLoggedOut(count, limit) > 0;
    }

    /** Number of existing sessions to drop so that only {@code limit - 1} remain. */
    private long getNumberOfSessionsThatNeedToBeLoggedOut(long count, long limit) {
        return count - (limit - 1);
    }

    /**
     * Reads an integer limit from the authenticator config.
     *
     * @return the parsed value, or {@code -1} (limit disabled) when the entry is missing or blank
     * @throws NumberFormatException if the entry is present but is not a valid integer
     */
    private int getIntConfigProperty(String key, Map<String, String> config) {
        String value = config.get(key);
        if (StringUtil.isBlank(value)) {
            return -1;
        }
        // Tolerate surrounding whitespace in the configured value.
        return Integer.parseInt(value.trim());
    }

    /**
     * Filters the user's realm sessions down to those that have a client session for {@code currentClient}.
     * This is only done when a client limit is configured; otherwise an empty list is returned.
     */
    private List<UserSessionModel> getUserSessionsForClientIfEnabled(List<UserSessionModel> userSessionsForRealm, ClientModel currentClient, int userClientLimit) {
        // Only count this user's sessions for this client when a limit is configured; otherwise skip this costly operation.
        if (userClientLimit <= 0) {
            logger.debug("total user sessions for this keycloak client will not be counted. Will be logged as 0 (zero)");
            return Collections.emptyList();
        }
        String clientId = currentClient.getId();
        return userSessionsForRealm.stream()
                .filter(userSession -> userSession.getAuthenticatedClientSessionByClient(clientId) != null)
                .collect(Collectors.toList());
    }

    /** No user interaction is involved, so there is nothing to process here. */
    @Override
    public void action(AuthenticationFlowContext context) {

    }

    /** The user may be absent; {@link #authenticate} then simply succeeds. */
    @Override
    public boolean requiresUser() {
        return false;
    }

    @Override
    public boolean configuredFor(KeycloakSession session, RealmModel realm, UserModel user) {
        return true;
    }

    @Override
    public void setRequiredActions(KeycloakSession session, RealmModel realm, UserModel user) {

    }

    @Override
    public void close() {

    }

    /**
     * Applies the configured behavior once a session limit has been reached.
     *
     * @param userSessions the sessions counted against the exceeded limit (realm-wide or client-specific)
     * @param eventDetails details recorded with the authentication failure
     * @param limit        the limit that was reached
     * @throws AuthenticationFlowException if no behavior, or an unsupported one, is configured
     */
    private void handleLimitExceeded(AuthenticationFlowContext context, List<UserSessionModel> userSessions, String eventDetails, long limit) {
        if (behavior == null) {
            // A String switch on null would throw a bare NullPointerException; report the misconfiguration instead.
            throw unsupportedBehavior(context);
        }
        switch (behavior) {
            case UserSessionLimitsAuthenticatorFactory.DENY_NEW_SESSION:
                logger.info("Denying new session");
                denyNewSession(context, eventDetails);
                break;

            case UserSessionLimitsAuthenticatorFactory.TERMINATE_OLDEST_SESSION:
                logger.info("Terminating oldest session");
                logoutOldestSessions(userSessions, limit);
                context.success();
                break;

            default:
                // Without this, the flow would be left without any success/failure status.
                throw unsupportedBehavior(context);
        }
    }

    /**
     * Fails the flow. Without a flow path the challenge is a JSON 401 {@link OAuth2ErrorRepresentation};
     * otherwise it is a 403 error page. The configured error message is used, falling back to
     * {@link #SESSION_LIMIT_EXCEEDED} when none (or a blank one) is configured.
     */
    private void denyNewSession(AuthenticationFlowContext context, String eventDetails) {
        String errorMessage = Optional.ofNullable(context.getAuthenticatorConfig())
                .map(AuthenticatorConfigModel::getConfig)
                .map(config -> config.get(UserSessionLimitsAuthenticatorFactory.ERROR_MESSAGE))
                .filter(message -> !StringUtil.isBlank(message))
                .orElse(SESSION_LIMIT_EXCEEDED);

        context.getEvent().error(Errors.GENERIC_AUTHENTICATION_ERROR);
        Response challenge;
        if (context.getFlowPath() == null) {
            OAuth2ErrorRepresentation errorRep = new OAuth2ErrorRepresentation(Errors.GENERIC_AUTHENTICATION_ERROR, errorMessage);
            challenge = Response.status(Response.Status.UNAUTHORIZED.getStatusCode()).entity(errorRep).type(MediaType.APPLICATION_JSON_TYPE).build();
        } else {
            challenge = context.form().setError(errorMessage).createErrorPage(Response.Status.FORBIDDEN);
        }
        context.failure(AuthenticationFlowError.GENERIC_AUTHENTICATION_ERROR, challenge, eventDetails, errorMessage);
    }

    /** Builds the error raised when the configured limit-exceeded behavior is missing or not recognized. */
    private AuthenticationFlowException unsupportedBehavior(AuthenticationFlowContext context) {
        return new AuthenticationFlowException("Unsupported or missing behavior '" + behavior + "' configured for 'User Session Count Limiter' authenticator in the realm '" + context.getRealm().getName() + "'!"
                , AuthenticationFlowError.INTERNAL_ERROR);
    }

    /**
     * Back-channel logs out the least recently refreshed sessions (ascending {@code getLastSessionRefresh()})
     * until only {@code limit - 1} sessions of {@code userSessions} remain.
     */
    private void logoutOldestSessions(List<UserSessionModel> userSessions, long limit) {
        long numberOfSessionsThatNeedToBeLoggedOut = getNumberOfSessionsThatNeedToBeLoggedOut(userSessions.size(), limit);
        if (numberOfSessionsThatNeedToBeLoggedOut == 1) {
            logger.info("Logging out oldest session");
        } else {
            logger.infof("Logging out oldest %s sessions", numberOfSessionsThatNeedToBeLoggedOut);
        }

        userSessions
            .stream()
            .sorted(Comparator.comparingInt(UserSessionModel::getLastSessionRefresh))
            .limit(numberOfSessionsThatNeedToBeLoggedOut)
            .forEach(userSession -> AuthenticationManager.backchannelLogout(session, userSession, true));
    }
}