InfinispanIdentityProviderStorageProvider.java

/*
 * Copyright 2024 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.keycloak.models.cache.infinispan.idp;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.keycloak.common.Profile;
import org.keycloak.models.IdentityProviderMapperModel;
import org.keycloak.models.IdentityProviderStorageProvider;
import org.keycloak.models.IdentityProviderModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.ModelException;
import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.models.cache.infinispan.CachedCount;
import org.keycloak.models.cache.infinispan.RealmCacheManager;
import org.keycloak.models.cache.infinispan.RealmCacheSession;
import org.keycloak.organization.OrganizationProvider;

import static org.keycloak.models.IdentityProviderStorageProvider.LoginFilter.getLoginPredicate;

/**
 * {@link IdentityProviderStorageProvider} that wraps the {@code "jpa"} provider and caches identity provider and
 * identity provider mapper lookups in the realm cache ({@link RealmCacheSession}).
 * <p>
 * Write operations register invalidations for the affected cache keys before delegating. Read operations bypass the
 * cache for keys that were invalidated in the current session.
 */
public class InfinispanIdentityProviderStorageProvider implements IdentityProviderStorageProvider {

    private static final String IDP_COUNT_KEY_SUFFIX = ".idp.count";
    private static final String IDP_ALIAS_KEY_SUFFIX = ".idp.alias";
    private static final String IDP_ORG_ID_KEY_SUFFIX = ".idp.orgId";
    private static final String IDP_LOGIN_SUFFIX = ".idp.login";

    private final KeycloakSession session;
    private final IdentityProviderStorageProvider idpDelegate;
    private final RealmCacheSession realmCache;
    private final long startupRevision;

    public InfinispanIdentityProviderStorageProvider(KeycloakSession session) {
        this.session = session;
        this.idpDelegate = session.getProvider(IdentityProviderStorageProvider.class, "jpa");
        this.realmCache = (RealmCacheSession) session.getProvider(CacheRealmProvider.class);
        this.startupRevision = realmCache.getCache().getCurrentCounter();
    }

    /** Cache key of the number of identity providers in {@code realm}. */
    private static String cacheKeyIdpCount(RealmModel realm) {
        return realm.getId() + IDP_COUNT_KEY_SUFFIX;
    }

    /** Cache key of the identity provider with the given {@code alias} in {@code realm}. */
    private static String cacheKeyIdpAlias(RealmModel realm, String alias) {
        return realm.getId() + "." + alias + IDP_ALIAS_KEY_SUFFIX;
    }

    /** Cache key of the mapper named {@code name} of the identity provider {@code alias} in {@code realm}. */
    private static String cacheKeyIdpMapperAliasName(RealmModel realm, String alias, String name) {
        return realm.getId() + "." + alias + IDP_ALIAS_KEY_SUFFIX + "." + name;
    }

    /** Cache key of the list query holding the identity providers linked to the organization {@code orgId}. */
    public static String cacheKeyOrgId(RealmModel realm, String orgId) {
        return realm.getId() + "." + orgId + IDP_ORG_ID_KEY_SUFFIX;
    }

    /** Cache key of the list query holding the identity providers available for login with the given fetch mode. */
    public static String cacheKeyForLogin(RealmModel realm, FetchMode fetchMode) {
        return realm.getId() + IDP_LOGIN_SUFFIX + "." + fetchMode;
    }

    @Override
    public IdentityProviderModel create(IdentityProviderModel model) {
        registerCountInvalidation();
        registerIDPLoginInvalidation(model);
        return idpDelegate.create(model);
    }

    @Override
    public void update(IdentityProviderModel model) {
        // the alias may be part of the update, so the idp looked up by id supplies the original alias to invalidate
        IdentityProviderModel idpById = getById(model.getInternalId());
        if (idpById != null) {
            registerIDPInvalidation(idpById);
            registerIDPLoginInvalidationOnUpdate(idpById, model);
        }
        idpDelegate.update(model);
    }

    @Override
    public boolean remove(String alias) {
        String cacheKey = cacheKeyIdpAlias(getRealm(), alias);

        // a cached entry for the alias may point to an internal id whose entry must be invalidated as well
        if (!isInvalid(cacheKey)) {
            CachedIdentityProvider cached = realmCache.getCache().get(cacheKey, CachedIdentityProvider.class);
            if (cached != null) {
                registerIDPInvalidation(cached.getIdentityProvider());
            }
        }

        // always invalidate the stored idp's keys: it may be cached by internal id even if the alias key is not
        IdentityProviderModel storedIdp = idpDelegate.getByAlias(alias);
        if (storedIdp != null) {
            registerIDPInvalidation(storedIdp);
            registerIDPLoginInvalidation(storedIdp);
        }

        registerCountInvalidation();
        return idpDelegate.remove(alias);
    }

    @Override
    public void removeAll() {
        registerCountInvalidation();
        // no need to invalidate each entry in cache, removeAll() is (currently) called only in case the realm is being deleted
        idpDelegate.removeAll();
    }

    /**
     * Returns the identity provider with the given internal id, or {@code null} if it does not exist. Cached entries
     * that belong to a different realm are ignored, and ids invalidated in this session are read from the delegate.
     */
    @Override
    public IdentityProviderModel getById(String internalId) {
        if (internalId == null) return null;
        CachedIdentityProvider cached = realmCache.getCache().get(internalId, CachedIdentityProvider.class);
        String realmId = getRealm().getId();
        if (cached != null && !cached.getRealm().equals(realmId)) {
            cached = null;
        }

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(internalId);
            IdentityProviderModel model = idpDelegate.getById(internalId);
            if (model == null) return null;
            if (isInvalid(internalId)) return createOrganizationAwareIdentityProviderModel(model);
            cached = new CachedIdentityProvider(loaded, getRealm(), internalId, model);
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());
        } else if (isInvalid(internalId)) {
            // the delegate may return null here (e.g. removed in this transaction); the wrapper handles that
            return createOrganizationAwareIdentityProviderModel(idpDelegate.getById(internalId));
        }
        return createOrganizationAwareIdentityProviderModel(cached.getIdentityProvider());
    }

    /** Returns the identity provider with the given alias in the current realm, or {@code null} if it does not exist. */
    @Override
    public IdentityProviderModel getByAlias(String alias) {
        String cacheKey = cacheKeyIdpAlias(getRealm(), alias);

        if (isInvalid(cacheKey)) {
            return createOrganizationAwareIdentityProviderModel(idpDelegate.getByAlias(alias));
        }

        CachedIdentityProvider cached = realmCache.getCache().get(cacheKey, CachedIdentityProvider.class);

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
            IdentityProviderModel model = idpDelegate.getByAlias(alias);
            if (model == null) {
                return null;
            }
            cached = new CachedIdentityProvider(loaded, getRealm(), cacheKey, model);
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());
        }

        return createOrganizationAwareIdentityProviderModel(cached.getIdentityProvider());
    }

    @Override
    public Stream<IdentityProviderModel> getByOrganization(String orgId, Integer first, Integer max) {
        String cacheKey = cacheKeyOrgId(getRealm(), orgId);
        Supplier<Stream<IdentityProviderModel>> loader = () -> idpDelegate.getByOrganization(orgId, first, max);

        // check if there is invalidation for this key or the organization was invalidated
        if (isInvalid(cacheKey) || isInvalid(orgId)) {
            return loader.get().map(this::createOrganizationAwareIdentityProviderModel);
        }

        String searchKey = Optional.ofNullable(first).orElse(-1) + "." + Optional.ofNullable(max).orElse(-1);
        return getFromListQuery(cacheKey, searchKey, loader);
    }

    @Override
    public Stream<IdentityProviderModel> getForLogin(FetchMode mode, String organizationId) {
        String cacheKey = cacheKeyForLogin(getRealm(), mode);
        Supplier<Stream<IdentityProviderModel>> loader = () -> idpDelegate.getForLogin(mode, organizationId);

        if (isInvalid(cacheKey)) {
            return loader.get().map(this::createOrganizationAwareIdentityProviderModel);
        }

        String searchKey = organizationId != null ? organizationId : "";
        return getFromListQuery(cacheKey, searchKey, loader);
    }

    /**
     * Resolves a list of identity providers through an {@link IdentityProviderListQuery} cached under
     * {@code cacheKey}. The query entry stores, per {@code searchKey}, the internal ids returned by {@code loader}.
     * Each id is then resolved with {@code session.identityProviders().getById}. If a cached id no longer resolves,
     * the entry is invalidated and the result is loaded from {@code loader} instead.
     *
     * @param cacheKey key of the list query entry
     * @param searchKey key of the current search within the list query entry
     * @param loader supplies the uncached result from the delegate
     * @return the identity providers, in the order of the cached ids
     */
    private Stream<IdentityProviderModel> getFromListQuery(String cacheKey, String searchKey,
            Supplier<Stream<IdentityProviderModel>> loader) {
        RealmModel realm = getRealm();
        RealmCacheManager cache = realmCache.getCache();
        IdentityProviderListQuery query = cache.get(cacheKey, IdentityProviderListQuery.class);
        Set<String> cached;

        if (query == null) {
            // not cached yet
            Long loaded = cache.getCurrentRevision(cacheKey);
            cached = loadInternalIds(loader);
            query = new IdentityProviderListQuery(loaded, cacheKey, realm, searchKey, cached);
            cache.addRevisioned(query, startupRevision);
        } else {
            cached = query.getIDPs(searchKey);
            if (cached == null) {
                // there is a cache entry, but the current search is not yet cached: replace the entry with a new one
                // built from the previous query (presumably carrying over its searches - see IdentityProviderListQuery)
                cache.invalidateObject(cacheKey);
                Long loaded = cache.getCurrentRevision(cacheKey);
                cached = loadInternalIds(loader);
                query = new IdentityProviderListQuery(loaded, cacheKey, realm, searchKey, cached, query);
                cache.addRevisioned(query, cache.getCurrentCounter());
            }
        }

        List<IdentityProviderModel> identityProviders = new ArrayList<>(cached.size());
        for (String id : cached) {
            IdentityProviderModel idp = session.identityProviders().getById(id);
            if (idp == null) {
                // a cached id no longer resolves: the entry is stale
                realmCache.registerInvalidation(cacheKey);
                return loader.get().map(this::createOrganizationAwareIdentityProviderModel);
            }
            identityProviders.add(idp);
        }

        return identityProviders.stream();
    }

    /** Collects the internal ids produced by {@code loader}, keeping the delegate's order. */
    private static Set<String> loadInternalIds(Supplier<Stream<IdentityProviderModel>> loader) {
        return loader.get()
                .map(IdentityProviderModel::getInternalId)
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    @Override
    public Stream<String> getByFlow(String flowId, String search, Integer first, Integer max) {
        return idpDelegate.getByFlow(flowId, search, first, max);
    }

    @Override
    public Stream<IdentityProviderModel> getAllStream(Map<String, String> attrs, Integer first, Integer max) {
        return idpDelegate.getAllStream(attrs, first, max).map(this::createOrganizationAwareIdentityProviderModel);
    }

    /** Returns the number of identity providers in the current realm, cached unless invalidated in this session. */
    @Override
    public long count() {
        String cacheKey = cacheKeyIdpCount(getRealm());
        CachedCount cached = realmCache.getCache().get(cacheKey, CachedCount.class);

        // cached and not invalidated
        if (cached != null && !isInvalid(cacheKey)) {
            return cached.getCount();
        }

        Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
        long count = idpDelegate.count();
        cached = new CachedCount(loaded, getRealm(), cacheKey, count);
        realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());

        return count;
    }

    @Override
    public void close() {
        idpDelegate.close();
    }

    @Override
    public IdentityProviderMapperModel createMapper(IdentityProviderMapperModel model) {
        return idpDelegate.createMapper(model);
    }

    @Override
    public void updateMapper(IdentityProviderMapperModel model) {
        registerIDPMapperInvalidation(model);
        // the mapper name or idp alias may be part of the update, so also invalidate the key of the stored mapper
        IdentityProviderMapperModel stored = idpDelegate.getMapperById(model.getId());
        if (stored != null) {
            registerIDPMapperInvalidation(stored);
        }
        idpDelegate.updateMapper(model);
    }

    @Override
    public boolean removeMapper(IdentityProviderMapperModel model) {
        registerIDPMapperInvalidation(model);
        return idpDelegate.removeMapper(model);
    }

    @Override
    public void removeAllMappers() {
        // no need to invalidate each entry in cache, removeAllMappers() is (currently) called only in case the realm is being deleted
        idpDelegate.removeAllMappers();
    }

    /**
     * Returns the mapper with the given id, or {@code null} if {@code id} is {@code null} or no such mapper exists.
     * Cached entries that belong to a different realm are ignored.
     */
    @Override
    public IdentityProviderMapperModel getMapperById(String id) {
        if (id == null) return null;
        CachedIdentityProviderMapper cached = realmCache.getCache().get(id, CachedIdentityProviderMapper.class);
        String realmId = getRealm().getId();
        if (cached != null && !cached.getRealm().equals(realmId)) {
            cached = null;
        }

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(id);
            IdentityProviderMapperModel model = idpDelegate.getMapperById(id);
            if (model == null) return null;
            if (isInvalid(id)) return model;
            cached = new CachedIdentityProviderMapper(loaded, getRealm(), id, model);
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());
        } else if (isInvalid(id)) {
            return idpDelegate.getMapperById(id);
        }
        return cached.getIdentityProviderMapper();
    }

    /** Returns the mapper {@code name} of the identity provider {@code identityProviderAlias}, or {@code null}. */
    @Override
    public IdentityProviderMapperModel getMapperByName(String identityProviderAlias, String name) {
        String cacheKey = cacheKeyIdpMapperAliasName(getRealm(), identityProviderAlias, name);

        if (isInvalid(cacheKey)) {
            return idpDelegate.getMapperByName(identityProviderAlias, name);
        }

        CachedIdentityProviderMapper cached = realmCache.getCache().get(cacheKey, CachedIdentityProviderMapper.class);

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
            IdentityProviderMapperModel model = idpDelegate.getMapperByName(identityProviderAlias, name);
            if (model == null) return null;
            cached = new CachedIdentityProviderMapper(loaded, getRealm(), cacheKey, model);
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());
        }

        return cached.getIdentityProviderMapper();
    }

    @Override
    public Stream<IdentityProviderMapperModel> getMappersStream(Map<String, String> options, Integer first, Integer max) {
        return idpDelegate.getMappersStream(options, first, max);
    }

    @Override
    public Stream<IdentityProviderMapperModel> getMappersByAliasStream(String identityProviderAlias) {
        return idpDelegate.getMappersByAliasStream(identityProviderAlias);
    }

    /** Invalidates both cache keys of {@code idp}: its internal id and its alias key. {@code idp} must be non-null. */
    private void registerIDPInvalidation(IdentityProviderModel idp) {
        realmCache.registerInvalidation(idp.getInternalId());
        realmCache.registerInvalidation(cacheKeyIdpAlias(getRealm(), idp.getAlias()));
    }

    private void registerCountInvalidation() {
        realmCache.registerInvalidation(cacheKeyIdpCount(getRealm()));
    }

    /**
     * Invalidates the id key and the alias/name key of {@code mapper}.
     *
     * @throws ModelException if the mapper has no id
     */
    private void registerIDPMapperInvalidation(IdentityProviderMapperModel mapper) {
        if (mapper.getId() == null) {
            throw new ModelException("Identity Provider Mapper does not exist");
        }
        realmCache.registerInvalidation(mapper.getId());
        realmCache.registerInvalidation(cacheKeyIdpMapperAliasName(getRealm(), mapper.getIdentityProviderAlias(), mapper.getName()));
    }

    private void registerIDPLoginInvalidation(IdentityProviderModel idp) {
        // only invalidate login caches if the IDP qualifies as a login IDP.
        if (getLoginPredicate().test(idp)) {
            for (FetchMode mode : FetchMode.values()) {
                realmCache.registerInvalidation(cacheKeyForLogin(getRealm(), mode));
            }
        }
    }

    /**
     * Registers invalidations for the caches that hold the IDPs available for login when an IDP is updated. The caches
     * are <strong>NOT</strong> invalidated if:
     * <ul>
     *     <li>IDP is currently NOT a login IDP, and the update hasn't changed that (i.e. it continues to be unavailable for login);</li>
     *     <li>IDP is currently a login IDP, and the update hasn't changed that. This includes the organization link not being updated as well</li>
     * </ul>
     * In all other scenarios, the caches must be invalidated.
     *
     * @param original the identity provider's current model
     * @param updated the identity provider's updated model
     */
    private void registerIDPLoginInvalidationOnUpdate(IdentityProviderModel original, IdentityProviderModel updated) {
        // IDP isn't currently available for login and update preserves that - no need to invalidate.
        if (!getLoginPredicate().test(original) && !getLoginPredicate().test(updated)) {
            return;
        }
        // IDP is currently available for login and update preserves that, including organization link - no need to invalidate.
        if (getLoginPredicate().test(original) && getLoginPredicate().test(updated)
                && Objects.equals(original.getOrganizationId(), updated.getOrganizationId())) {
            return;
        }

        // all other scenarios should invalidate the login caches.
        for (FetchMode mode : FetchMode.values()) {
            realmCache.registerInvalidation(cacheKeyForLogin(getRealm(), mode));
        }
    }

    /**
     * Returns the realm bound to the session context.
     *
     * @throws IllegalArgumentException if the session is not bound to a realm
     */
    private RealmModel getRealm() {
        RealmModel realm = session.getContext().getRealm();
        if (realm == null) {
            throw new IllegalArgumentException("Session not bound to a realm");
        }
        return realm;
    }

    private boolean isInvalid(String cacheKey) {
        return realmCache.isInvalid(cacheKey);
    }

    /**
     * Wraps {@code idp} so that, when the ORGANIZATION feature is enabled, an IdP linked to an organization is only
     * reported as enabled if the organization provider and the organization are enabled too.
     *
     * @param idp the model to wrap; may be {@code null}
     * @return {@code null} if {@code idp} is {@code null}; {@code idp} itself if the feature is disabled; otherwise the wrapper
     */
    private IdentityProviderModel createOrganizationAwareIdentityProviderModel(IdentityProviderModel idp) {
        // delegate lookups on invalidated keys may return null (e.g. idp removed in this transaction)
        if (idp == null) return null;
        if (!Profile.isFeatureEnabled(Profile.Feature.ORGANIZATION)) return idp;
        return new IdentityProviderModel(idp) {
            @Override
            public boolean isEnabled() {
                // if IdP is bound to an org
                if (getOrganizationId() != null) {
                    OrganizationProvider provider = session.getProvider(OrganizationProvider.class);
                    OrganizationModel org = provider == null ? null : provider.getById(getOrganizationId());
                    return org != null && provider.isEnabled() && org.isEnabled() && super.isEnabled();
                }
                return super.isEnabled();
            }
        };
    }
}