InfinispanOrganizationProvider.java

/*
 * Copyright 2024 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.keycloak.models.cache.infinispan.organization;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.keycloak.models.IdentityProviderModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.OrganizationDomainModel;
import org.keycloak.models.OrganizationModel;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.cache.CacheRealmProvider;
import org.keycloak.models.cache.UserCache;
import org.keycloak.models.cache.infinispan.CachedCount;
import org.keycloak.models.cache.infinispan.RealmCacheSession;
import org.keycloak.models.cache.infinispan.UserCacheSession;
import org.keycloak.organization.OrganizationProvider;

import static org.keycloak.models.cache.infinispan.idp.InfinispanIdentityProviderStorageProvider.cacheKeyOrgId;

public class InfinispanOrganizationProvider implements OrganizationProvider {

    private static final String ORG_COUNT_KEY_SUFFIX = ".org.count";
    private static final String ORG_MEMBERS_COUNT_KEY_SUFFIX = ".members.count";

    private final KeycloakSession session;
    private final UserCacheSession userCache;
    private OrganizationProvider orgDelegate;
    private final RealmCacheSession realmCache;
    private final Map<String, OrganizationAdapter> managedOrganizations = new HashMap<>();

    public InfinispanOrganizationProvider(KeycloakSession session) {
        this.session = session;
        this.realmCache = (RealmCacheSession) session.getProvider(CacheRealmProvider.class);
        this.userCache = (UserCacheSession) session.getProvider(UserCache.class);
    }

    private static String cacheKeyOrgCount(RealmModel realm) {
        return realm.getId() + ORG_COUNT_KEY_SUFFIX;
    }

    public static String cacheKeyOrgMemberCount(RealmModel realm, OrganizationModel organization) {
        return realm.getId() + ".org." + organization.getId() + ORG_MEMBERS_COUNT_KEY_SUFFIX;
    }

    @Override
    public OrganizationModel create(String id, String name, String alias) {
        registerCountInvalidation();
        return getDelegate().create(id, name, alias);
    }

    private OrganizationProvider getDelegate() {
        if (orgDelegate == null) {
            // use lazy initialization to avoid touching the entity manager
            orgDelegate = session.getProvider(OrganizationProvider.class, "jpa");
        }
        return orgDelegate;
    }

    @Override
    public boolean remove(OrganizationModel organization) {
        registerOrganizationInvalidation(organization);
        registerCountInvalidation();
        return getDelegate().remove(organization);
    }

    @Override
    public OrganizationModel getById(String id) {
        if (realmCache == null) {
            return getDelegate().getById(id);
        }

        CachedOrganization cached = realmCache.getCache().get(id, CachedOrganization.class);
        String realmId = getRealm().getId();
        if (cached != null && !cached.getRealm().equals(realmId)) {
            cached = null;
        }

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(id);
            OrganizationModel model = getDelegate().getById(id);
            if (model == null) return null;
            if (isRealmCacheKeyInvalid(id)) return model;
            cached = new CachedOrganization(loaded, getRealm(), model);
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());

        // no need to check for realm invalidation as IdP changes are handled by events within InfinispanOrganizationProviderFactory
        } else if (isRealmCacheKeyInvalid(id)) {
            return getDelegate().getById(id);
        } else if (managedOrganizations.containsKey(id)) {
            return managedOrganizations.get(id);
        }
        OrganizationAdapter adapter = new OrganizationAdapter(cached, () -> getDelegate(), this);
        managedOrganizations.put(id, adapter);
        return adapter;
    }

    @Override
    public OrganizationModel getByDomainName(String domainName) {
        if (realmCache == null) {
            return getDelegate().getByDomainName(domainName);
        }

        String cacheKey = cacheKeyByDomain(domainName);
        if (isRealmCacheKeyInvalid(cacheKey)) {
            return getDelegate().getByDomainName(domainName);
        }

        CachedOrganizationIds cached = realmCache.getCache().get(cacheKey, CachedOrganizationIds.class);

        if (cached == null) {
            Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
            OrganizationModel model = getDelegate().getByDomainName(domainName);
            if (model == null) {
                return null;
            }
            cached = new CachedOrganizationIds(loaded, cacheKey, getRealm(), Stream.ofNullable(model));
            realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());
        }

        return cached.getOrgIds().stream().map(this::getById).findAny().orElse(null);
    }

    @Override
    public Stream<OrganizationModel> getAllStream(String search, Boolean exact, Integer first, Integer max) {
        // Return cache delegates to ensure cache invalidation during write operations
        return getCacheDelegates(getDelegate().getAllStream(search, exact, first, max));
    }

    @Override
    public Stream<OrganizationModel> getAllStream(Map<String, String> attributes, Integer first, Integer max) {
        // Return cache delegates to ensure cache invalidation during write operations
        return getCacheDelegates(getDelegate().getAllStream(attributes, first, max));
    }

    @Override
    public void removeAll() {
        //TODO: won't scale, requires a better mechanism for bulk deleting organizations within a realm
        //this way, all organizations in the realm will be invalidated ... or should it be invalidated whole realm instead?
        getAllStream().forEach(this::remove);
    }

    @Override
    public boolean addManagedMember(OrganizationModel organization, UserModel user) {
        registerMemberInvalidation(organization, user);
        return getDelegate().addManagedMember(organization, user);
    }

    @Override
    public boolean addMember(OrganizationModel organization, UserModel user) {
        registerMemberInvalidation(organization, user);
        return getDelegate().addMember(organization, user);
    }

    @Override
    public boolean removeMember(OrganizationModel organization, UserModel member) {
        registerMemberInvalidation(organization, member);
        return getDelegate().removeMember(organization, member);
    }

    @Override
    public Stream<UserModel> getMembersStream(OrganizationModel organization, String search, Boolean exact, Integer first, Integer max) {
        Map<String, String> filters = Optional.ofNullable(search)
                .map(value -> Map.of(UserModel.SEARCH, value))
                .orElse(Map.of());
        return getMembersStream(organization, filters, exact, first, max);
    }

    @Override
    public Stream<UserModel> getMembersStream(OrganizationModel organization, Map<String, String> filters, Boolean exact, Integer first, Integer max) {
        return getDelegate().getMembersStream(organization, filters, exact, first, max);
    }

    @Override
    public long getMembersCount(OrganizationModel organization) {
        if (realmCache == null) {
            return getDelegate().getMembersCount(organization);
        }

        String cacheKey = cacheKeyOrgMemberCount(getRealm(), organization);
        CachedCount cached = realmCache.getCache().get(cacheKey, CachedCount.class);

        // cached and not invalidated
        if (cached != null && !isRealmCacheKeyInvalid(cacheKey)) {
            return cached.getCount();
        }

        Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
        long membersCount = getDelegate().getMembersCount(organization);
        cached = new CachedCount(loaded, getRealm(), cacheKey, membersCount);
        realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());

        return membersCount;
    }

    @Override
    public UserModel getMemberById(OrganizationModel organization, String id) {
        if (userCache == null) {
            return getDelegate().getMemberById(organization, id);
        }

        RealmModel realm = getRealm();
        UserModel user = session.users().getUserById(realm, id);

        if (user == null) {
            return null;
        }

        String cacheKey = cacheKeyMembership(realm, organization, user);

        if (isUserCacheKeyInvalid(cacheKey)) {
            return getDelegate().getMemberById(organization, user.getId());
        }

        CachedMembership cached = userCache.getCache().get(cacheKey, CachedMembership.class);

        if (cached == null) {
            boolean isManaged = getDelegate().isManagedMember(organization, user);
            Long loaded = userCache.getCache().getCurrentRevision(cacheKey);
            UserModel member = getDelegate().getMemberById(organization, user.getId());
            cached = new CachedMembership(loaded, cacheKey, realm, isManaged, member != null);
            userCache.getCache().addRevisioned(cached, userCache.getStartupRevision());
        }

        return cached.isMember() ? user : null;
    }

    @Override
    public Stream<OrganizationModel> getByMember(UserModel member) {
        if (userCache == null) {
            return getDelegate().getByMember(member);
        }

        String cacheKey = cacheKeyByMember(member);

        if (isUserCacheKeyInvalid(cacheKey)) {
            return getDelegate().getByMember(member);
        }

        CachedOrganizationIds cached = userCache.getCache().get(cacheKey, CachedOrganizationIds.class);

        if (cached == null) {
            Long loaded = userCache.getCache().getCurrentRevision(cacheKey);
            Stream<OrganizationModel> model = getDelegate().getByMember(member);
            cached = new CachedOrganizationIds(loaded, cacheKey, getRealm(), model);
            userCache.getCache().addRevisioned(cached, userCache.getStartupRevision());
        }

        return cached.getOrgIds().stream().map(this::getById);
    }

    @Override
    public boolean isManagedMember(OrganizationModel organization, UserModel user) {
        if (userCache == null) {
            return getDelegate().isManagedMember(organization, user);
        }

        if (user == null) {
            return false;
        }

        String cacheKey = cacheKeyMembership(getRealm(), organization, user);
        CachedMembership cached = userCache.getCache().get(cacheKey, CachedMembership.class);

        if (cached == null || isUserCacheKeyInvalid(cacheKey)) {
            // this will not cache the result as calling getMemberById() to have a full caching entry would lead to a recursion
            return getDelegate().isManagedMember(organization, user);
        }

        return cached.isManaged();

    }

    @Override
    public boolean addIdentityProvider(OrganizationModel organization, IdentityProviderModel identityProvider) {
        boolean added = getDelegate().addIdentityProvider(organization, identityProvider);
        if (added) {
            registerOrganizationInvalidation(organization);
        }
        return added;
    }

    @Override
    public Stream<IdentityProviderModel> getIdentityProviders(OrganizationModel organization) {
        return getDelegate().getIdentityProviders(organization);
    }

    @Override
    public boolean removeIdentityProvider(OrganizationModel organization, IdentityProviderModel identityProvider) {
        boolean removed = getDelegate().removeIdentityProvider(organization, identityProvider);
        if (removed) {
            registerOrganizationInvalidation(organization);
        }
        return removed;
    }

    @Override
    public boolean isEnabled() {
        return getRealm().isOrganizationsEnabled();
    }

    @Override
    public long count() {
        if (realmCache == null) {
            return getDelegate().count();
        }

        String cacheKey = cacheKeyOrgCount(getRealm());
        CachedCount cached = realmCache.getCache().get(cacheKey, CachedCount.class);

        // cached and not invalidated
        if (cached != null && !isRealmCacheKeyInvalid(cacheKey)) {
            return cached.getCount();
        }

        Long loaded = realmCache.getCache().getCurrentRevision(cacheKey);
        long count = getDelegate().count();
        cached = new CachedCount(loaded, getRealm(), cacheKey, count);
        realmCache.getCache().addRevisioned(cached, realmCache.getStartupRevision());

        return count;
    }

    @Override
    public void close() {
        if (orgDelegate != null) {
            getDelegate().close();
        }
    }

    void registerOrganizationInvalidation(OrganizationModel organization) {
        String id = organization.getId();

        if (realmCache != null) {
            realmCache.registerInvalidation(cacheKeyOrgId(getRealm(), id));
            realmCache.registerInvalidation(id);
            organization.getDomains()
                    .map(OrganizationDomainModel::getName)
                    .map(this::cacheKeyByDomain)
                    .forEach(realmCache::registerInvalidation);
        }

        OrganizationAdapter adapter = managedOrganizations.get(id);

        if (adapter != null) {
            adapter.invalidate();
        }
    }

    private void registerCountInvalidation() {
        if (realmCache != null) {
            realmCache.registerInvalidation(cacheKeyOrgCount(getRealm()));
        }
    }

    private RealmModel getRealm() {
        RealmModel realm = session.getContext().getRealm();
        if (realm == null) {
            throw new IllegalArgumentException("Session not bound to a realm");
        }
        return realm;
    }

    private Stream<OrganizationModel> getCacheDelegates(Stream<OrganizationModel> backendOrganizations) {
        return backendOrganizations.map(OrganizationModel::getId).map(this::getById);
    }

    private String cacheKeyByDomain(String domainName) {
        return getRealm().getId() + ".org.domain.name." + domainName;
    }

    private String cacheKeyByMember(UserModel user) {
        return getRealm().getId() + ".org.member." + user.getId() + ".orgs";
    }

    private String cacheKeyMembership(RealmModel realm, OrganizationModel organization, UserModel user) {
        return realm.getId() + ".org." + organization.getId() + ".member." + user.getId() + ".membership";
    }

    void registerMemberInvalidation(OrganizationModel organization, UserModel member) {
        if (userCache != null) {
            userCache.registerInvalidation(cacheKeyByMember(member));
            userCache.registerInvalidation(cacheKeyMembership(getRealm(), organization, member));
        }
        if (realmCache != null) {
            realmCache.registerInvalidation(cacheKeyOrgMemberCount(getRealm(), organization));
        }
    }

    private boolean isRealmCacheKeyInvalid(String cacheKey) {
        return realmCache.isInvalid(cacheKey);
    }

    private boolean isUserCacheKeyInvalid(String cacheKey) {
        return userCache.isInvalid(cacheKey);
    }
}