AbstractStorageManager.java

/*
 * Copyright 2020 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.keycloak.storage;

import org.jboss.logging.Logger;
import org.keycloak.Config;
import org.keycloak.common.util.reflections.Types;
import org.keycloak.component.ComponentFactory;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderFactory;
import org.keycloak.utils.ServicesUtils;

import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;

/**
 *
 * @param <ProviderType> This type will be used for looking for factories that produce instances of desired providers
 * @param <StorageProviderModelType> Type of model used for creating provider, it needs to extend 
 *                                  CacheableStorageProviderModel as it has {@code isEnabled()} method and also extend
 *                                  PrioritizedComponentModel which is required for sorting providers based on its
 *                                  priorities
 */
public abstract class AbstractStorageManager<ProviderType extends Provider,
        StorageProviderModelType extends CacheableStorageProviderModel> {

    private static final Logger LOG = Logger.getLogger(AbstractStorageManager.class);

    /**
     * Timeouts are used as time boundary for obtaining models from an external storage. Default value is set
     * to 3000 milliseconds and it's configurable.
     */
    private static final Long STORAGE_PROVIDER_DEFAULT_TIMEOUT = 3000L;
    protected final KeycloakSession session;
    private final Class<ProviderType> providerTypeClass;
    private final Class<? extends ProviderFactory> factoryTypeClass;
    private final Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction;
    private final String configScope;
    private Long storageProviderTimeout;

    public AbstractStorageManager(KeycloakSession session, Class<? extends ProviderFactory> factoryTypeClass, Class<ProviderType> providerTypeClass, Function<ComponentModel, StorageProviderModelType> toStorageProviderModelTypeFunction, String configScope) {
        this.session = session;
        this.providerTypeClass = providerTypeClass;
        this.factoryTypeClass = factoryTypeClass;
        this.toStorageProviderModelTypeFunction = toStorageProviderModelTypeFunction;
        this.configScope = configScope;
    }

    protected Long getStorageProviderTimeout() {
        if (storageProviderTimeout == null) {
            storageProviderTimeout = Config.scope(configScope).getLong("storageProviderTimeout", STORAGE_PROVIDER_DEFAULT_TIMEOUT);
        }
        return storageProviderTimeout;
    }

    /**
     * Returns a factory with the providerId, which produce instances of type CreatedProviderType
     * @param providerId id of factory that produce desired instances
     * @return A factory that implements {@code ComponentFactory<CreatedProviderType, ProviderType>}
     */
    protected <T extends ProviderType> ComponentFactory<T, ProviderType> getStorageProviderFactory(String providerId) {
        return (ComponentFactory<T, ProviderType>) session.getKeycloakSessionFactory()
                .getProviderFactory(providerTypeClass, providerId);
    }

    /**
     * Returns stream of all storageProviders within the realm that implements the capabilityInterface.
     *
     * @param realm realm
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @return enabled storage providers for realm and @{code getProviderTypeClass()}
     */
    protected <T> Stream<T> getEnabledStorageProviders(RealmModel realm, Class<T> capabilityInterface) {
        return getStorageProviderModels(realm, providerTypeClass)
                .map(toStorageProviderModelTypeFunction)
                .filter(StorageProviderModelType::isEnabled)
                .sorted(StorageProviderModelType.comparator)
                .map(storageProviderModelType -> getStorageProviderInstance(storageProviderModelType, capabilityInterface, false))
                .filter(Objects::nonNull);
    }

    /**
     * Gets all enabled StorageProviders that implements the capabilityInterface, applies applyFunction on each of
     * them and then join the results together.
     *
     * !! Each StorageProvider has a limited time to respond, if it fails to do it, empty stream is returned !!
     *
     * @param realm realm
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @param applyFunction function that is applied on StorageProviders
     * @param <R> result of applyFunction
     * @return a stream with all results from all StorageProviders
     */
    protected <R, T> Stream<R> flatMapEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, ? extends Stream<R>> applyFunction) {
        return getEnabledStorageProviders(realm, capabilityInterface)
                .flatMap(ServicesUtils.timeBound(session, getStorageProviderTimeout(), applyFunction));
    }

    /**
     * Gets all enabled StorageProviders that implements the capabilityInterface, applies applyFunction on each of
     * them and returns the stream.
     *
     * !! Each StorageProvider has a limited time to respond, if it fails to do it, null is returned !!
     *
     * @param realm realm
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @param applyFunction function that is applied on StorageProviders
     * @param <R> Result of applyFunction
     * @return First result from StorageProviders
     */
    protected <R, T> Stream<R> mapEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Function<T, R> applyFunction) {
        return getEnabledStorageProviders(realm, capabilityInterface)
                .map(ServicesUtils.timeBoundOne(session, getStorageProviderTimeout(), applyFunction))
                .filter(Objects::nonNull);
    }

    /**
     * Gets all enabled StorageProviders that implements the capabilityInterface and call applyFunction on each
     *
     * !! Each StorageProvider has a limited time for consuming !!
     *
     * @param realm realm
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @param consumer function that is applied on StorageProviders
     */
    protected <T> void consumeEnabledStorageProvidersWithTimeout(RealmModel realm, Class<T> capabilityInterface, Consumer<T> consumer) {
        getEnabledStorageProviders(realm, capabilityInterface)
                .forEachOrdered(ServicesUtils.consumeWithTimeBound(session, getStorageProviderTimeout(), consumer));
    }


    protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface) {
        return getStorageProviderInstance(realm, providerId, capabilityInterface, false);
    }

    /**
     * Returns an instance of provider with the providerId within the realm or null if storage provider with providerId
     * doesn't implement capabilityInterface.
     *
     * @param realm realm
     * @param providerId id of ComponentModel within database/storage
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @return an instance of type CreatedProviderType or null if storage provider with providerId doesn't implement capabilityInterface
     */
    protected <T> T getStorageProviderInstance(RealmModel realm, String providerId, Class<T> capabilityInterface, boolean includeDisabled) {
        if (providerId == null || capabilityInterface == null) return null;
        return getStorageProviderInstance(getStorageProviderModel(realm, providerId), capabilityInterface, includeDisabled);
    }

    /**
     * Returns an instance of StorageProvider model corresponding realm and providerId
     * @param realm Realm.
     * @param providerId Id of desired provider.
     * @return An instance of type StorageProviderModelType
     */
    protected StorageProviderModelType getStorageProviderModel(RealmModel realm, String providerId) {
        ComponentModel componentModel = realm.getComponent(providerId);
        if (componentModel == null) {
            return null;
        }

        return toStorageProviderModelTypeFunction.apply(componentModel);
    }

    /**
     * Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
     *
     * @param model StorageProviderModel obtained from database/storage
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @param <T> Required capability interface type
     * @return an instance of type T or null if storage provider based on the model doesn't exist or doesn't implement the capabilityInterface.
     */
    protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface) {
        return getStorageProviderInstance(model, capabilityInterface, false);
    }

    /**
     * Returns an instance of provider for the model or null if storage provider based on the model doesn't implement capabilityInterface.
     *
     * @param model StorageProviderModel obtained from database/storage
     * @param capabilityInterface class of desired capabilityInterface.
     *                            For example, {@code GroupLookupProvider} or {@code UserQueryProvider}
     * @param includeDisabled If set to true, the method will return also disabled providers.
     * @return an instance of type T or null if storage provider based on the model doesn't exist or doesn't implement the capabilityInterface.
     */
    protected <T> T getStorageProviderInstance(StorageProviderModelType model, Class<T> capabilityInterface, boolean includeDisabled) {
        if (model == null || (!model.isEnabled() && !includeDisabled) || capabilityInterface == null) {
            return null;
        }

        @SuppressWarnings("unchecked")
        ProviderType instance = (ProviderType) session.getAttribute(model.getId());
        if (instance != null && capabilityInterface.isAssignableFrom(instance.getClass())) return capabilityInterface.cast(instance);

        ComponentFactory<? extends ProviderType, ProviderType> factory = getStorageProviderFactory(model.getProviderId());
        if (factory == null) {
            LOG.warnv("Configured StorageProvider {0} of provider id {1} does not exist", model.getName(), model.getProviderId());
            return null;
        }
        if (!Types.supports(capabilityInterface, factory, factoryTypeClass)) {
            return null;
        }

        instance = factory.create(session, model);
        if (instance == null) {
            throw new IllegalStateException("StorageProviderFactory (of type " + factory.getClass().getName() + ") produced a null instance");
        }
        session.enlistForClose(instance);
        session.setAttribute(model.getId(), instance);
        return capabilityInterface.cast(instance);
    }

    /**
     * Stream of ComponentModels of storageType.
     * @param realm Realm.
     * @param storageType Type.
     * @return Stream of ComponentModels
     */
    public static Stream<ComponentModel> getStorageProviderModels(RealmModel realm, Class<? extends Provider> storageType) {
        return realm.getStorageProviders(storageType);
    }
}