DefaultComponentFactoryProviderFactory.java
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.services;
import org.keycloak.Config;
import org.keycloak.Config.Scope;
import org.keycloak.cluster.ClusterProvider;
import org.keycloak.common.util.StackUtil;
import org.keycloak.component.ComponentFactoryProviderFactory;
import org.keycloak.component.ComponentModel;
import org.keycloak.component.ComponentModelScope;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.provider.InvalidationHandler;
import org.keycloak.provider.InvalidationHandler.InvalidableObjectType;
import org.keycloak.provider.InvalidationHandler.ObjectType;
import org.keycloak.provider.Provider;
import org.keycloak.provider.ProviderFactory;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Stream;
import org.jboss.logging.Logger;
/**
 * Default {@link ComponentFactoryProviderFactory}: creates a separate {@link ProviderFactory} instance for each
 * component, initializes it with the component's configuration and, when component caching is available, keeps it
 * in a cache keyed by component ID.
 * <p>
 * Cached factories are closed and evicted when the component itself, the realm it was registered under, or its
 * {@link ProviderFactory} class is invalidated; invalidation of {@link ObjectType#_ALL_} drops the whole cache.
 *
 * @author hmlnarik
 */
public class DefaultComponentFactoryProviderFactory implements ComponentFactoryProviderFactory {

    private static final Logger LOG = Logger.getLogger(DefaultComponentFactoryProviderFactory.class);

    public static final String PROVIDER_ID = "default";

    /**
     * Cache of initialized component factories keyed by component ID. Held in an {@link AtomicReference} so that a
     * full invalidation (or {@link #close()}) can atomically swap in an empty map.
     */
    private final AtomicReference<ConcurrentMap<String, ProviderFactory>> componentsMap = new AtomicReference<>(new ConcurrentHashMap<>());

    /**
     * Should an ID in the key be invalidated, it would invalidate also all the IDs in the values.
     * Keys are realm IDs or {@link ProviderFactory} classes; values are component IDs.
     */
    private final ConcurrentMap<Object, Set<String>> dependentInvalidations = new ConcurrentHashMap<>();

    private KeycloakSessionFactory factory;

    /** Effective caching switch; computed in {@link #postInit(KeycloakSessionFactory)}. */
    private boolean componentCachingAvailable;

    /** Value of the {@code cachingEnabled} option (defaults to {@code true}). */
    private boolean componentCachingEnabled;

    /** Value of the {@code cachingForced} option (defaults to {@code false}). */
    private Boolean componentCachingForced;

    /**
     * Reads the {@code cachingEnabled} and {@code cachingForced} options.
     *
     * @param config configuration scope of this provider
     */
    @Override
    public void init(Scope config) {
        this.componentCachingEnabled = config.getBoolean("cachingEnabled", true);
        this.componentCachingForced = config.getBoolean("cachingForced", false);
    }

    /**
     * Decides whether component caching is used. Caching requires it to be enabled by configuration and either a
     * system-wide {@link ClusterProvider} factory to be present, or {@code cachingForced} to be set.
     *
     * @param factory the session factory; kept for creating and initializing component factories later
     */
    @Override
    public void postInit(KeycloakSessionFactory factory) {
        this.factory = factory;
        this.componentCachingAvailable = this.componentCachingEnabled && this.factory.getProviderFactory(ClusterProvider.class) != null;

        if (! componentCachingEnabled) {
            LOG.warn("Caching of components disabled by the configuration which may have performance impact.");
        } else if (! componentCachingAvailable) {
            if (Objects.equals(componentCachingForced, Boolean.TRUE)) {
                LOG.warn("Component caching forced even though no system-wide ClusterProviderFactory found. This would be only reliable in single-node deployment.");
                this.componentCachingAvailable = true;
            } else {
                LOG.warn("No system-wide ClusterProviderFactory found. Cannot send messages across cluster, thus disabling caching of components. Consider setting cachingForced option in single-node deployment.");
            }
        }
    }

    /**
     * Returns the {@link ProviderFactory} for the given component. A cached instance is returned if present.
     * Otherwise a new instance of the registered factory class is created and initialized with the component's
     * configuration. The new instance is cached only when component caching is available.
     *
     * @param clazz provider class whose SPI the component belongs to
     * @param realmId ID of the realm owning the component; when {@code null}, the component's parent ID is used
     *   for invalidation tracking
     * @param componentId ID of the component; used as the cache key
     * @param modelGetter optional function obtaining the component model; when {@code null} the model is looked up
     *   via {@link KeycloakModelUtils#getComponentModel}
     * @return the initialized factory, or {@code null} if the component model or the provider factory is not
     *   found, or the factory class cannot be instantiated
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends Provider> ProviderFactory<T> getProviderFactory(Class<T> clazz, String realmId, String componentId, Function<KeycloakSessionFactory, ComponentModel> modelGetter) {
        ProviderFactory res = componentsMap.get().get(componentId);
        if (res != null) {
            LOG.tracef("Found cached ProviderFactory for %s in (%s, %s)", clazz, realmId, componentId);
            return res;
        }

        // Perform the expensive model lookup and instantiation outside computeIfAbsent, before touching the cache
        final ComponentModel cm;
        if (modelGetter == null) {
            LOG.debugf("Getting component configuration for component (%s, %s) from realm configuration", realmId, componentId);
            cm = KeycloakModelUtils.getComponentModel(factory, realmId, componentId);
        } else {
            LOG.debugf("Getting component configuration for component (%s, %s) via provided method", realmId, componentId);
            cm = modelGetter.apply(factory);
        }

        if (cm == null) {
            return null;
        }

        final String provider = cm.getProviderId();
        ProviderFactory<T> pf = provider == null
          ? factory.getProviderFactory(clazz)
          : factory.getProviderFactory(clazz, provider);

        if (pf == null) {   // Either not found or not enabled
            LOG.debugf("ProviderFactory for %s in (%s, %s) not found", clazz, realmId, componentId);
            return null;
        }

        // Each component gets its own factory instance so it can be initialized with component-specific config
        final ProviderFactory newFactory;
        try {
            newFactory = pf.getClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException ex) {
            LOG.warn("Cannot instantiate factory", ex);
            return null;
        }

        Scope scope = Config.scope(factory.getSpi(clazz).getName(), provider);
        ComponentModelScope configScope = new ComponentModelScope(scope, cm);

        ProviderFactory<T> providerFactory;
        if (this.componentCachingAvailable) {
            // If another thread cached a factory meanwhile, that one wins and newFactory stays uninitialized
            providerFactory = componentsMap.get().computeIfAbsent(componentId, cId -> initializeFactory(clazz, realmId, componentId, newFactory, configScope));
        } else {
            providerFactory = initializeFactory(clazz, realmId, componentId, newFactory, configScope);
        }
        return providerFactory;
    }

    /**
     * Runs {@code init} and {@code postInit} on a freshly created component factory. Then it registers the
     * component as dependent on its realm (or on the component's parent when {@code realmId} is {@code null}) and
     * on the factory class, so that invalidating either of them evicts the component.
     *
     * @param clazz provider class (used for logging)
     * @param realmId ID of the owning realm, or {@code null}
     * @param componentId ID of the component
     * @param newFactory the uninitialized factory instance
     * @param configScope component-specific configuration passed to {@code init}
     * @return {@code newFactory}, now initialized
     */
    @SuppressWarnings("unchecked")
    protected <T extends Provider> ProviderFactory<T> initializeFactory(Class<T> clazz, String realmId, String componentId, final ProviderFactory newFactory, ComponentModelScope configScope) {
        LOG.debugf("Initializing ProviderFactory for %s in (%s, %s)", clazz, realmId, componentId);

        newFactory.init(configScope);
        newFactory.postInit(factory);

        if (realmId == null) {
            realmId = configScope.getComponentParentId();
        }

        if (realmId != null) {
            dependentInvalidations.computeIfAbsent(realmId, k -> ConcurrentHashMap.newKeySet()).add(componentId);
        }
        dependentInvalidations.computeIfAbsent(newFactory.getClass(), k -> ConcurrentHashMap.newKeySet()).add(componentId);

        return newFactory;
    }

    /**
     * Evicts and closes the cached component factories affected by the invalidation.
     * <ul>
     *   <li>{@link ObjectType#_ALL_}: the cache is replaced by an empty one, dependency tracking is cleared and
     *       all previously cached factories are closed. The invalidation is not propagated.</li>
     *   <li>{@link ObjectType#COMPONENT}: factories cached under the given component IDs are evicted and
     *       closed.</li>
     *   <li>{@link ObjectType#REALM} / {@link ObjectType#PROVIDER_FACTORY}: factories of all components registered
     *       as dependent on the given realm IDs / factory classes are evicted and closed.</li>
     * </ul>
     * For every type except {@code _ALL_}, the invalidation is then forwarded to the remaining cached factories
     * that implement {@link InvalidationHandler}.
     *
     * @param session current session, passed on to the handlers
     * @param type type of the invalidated objects
     * @param ids identifiers of the invalidated objects
     */
    @Override
    public void invalidate(KeycloakSession session, InvalidableObjectType type, Object... ids) {
        if (LOG.isDebugEnabled()) {
            LOG.debugf("Invalidating %s: %s", type, Arrays.asList(ids));
        }
        LOG.tracef("invalidate(%s)%s", type, StackUtil.getShortStackTrace());

        if (type == ObjectType._ALL_) {
            final ConcurrentMap<String, ProviderFactory> previous = componentsMap.getAndSet(new ConcurrentHashMap<>());
            dependentInvalidations.clear();
            previous.values().forEach(ProviderFactory::close);
        } else if (type == ObjectType.COMPONENT) {
            Stream.of(ids)
              .map(componentsMap.get()::remove).filter(Objects::nonNull)
              .forEach(ProviderFactory::close);
            propagateInvalidation(session, componentsMap.get(), type, ids);
        } else if (type == ObjectType.REALM || type == ObjectType.PROVIDER_FACTORY) {
            Stream.of(ids)
              .map(dependentInvalidations::get).filter(Objects::nonNull).flatMap(Collection::stream)
              .map(componentsMap.get()::remove).filter(Objects::nonNull)
              .forEach(ProviderFactory::close);
            Stream.of(ids).forEach(dependentInvalidations::remove);
            propagateInvalidation(session, componentsMap.get(), type, ids);
        } else {
            propagateInvalidation(session, componentsMap.get(), type, ids);
        }
    }

    /**
     * Forwards an invalidation to every factory in {@code factories} that implements {@link InvalidationHandler}.
     */
    private void propagateInvalidation(KeycloakSession session, ConcurrentMap<String, ProviderFactory> factories, InvalidableObjectType type, Object[] ids) {
        factories.values()
          .stream()
          .filter(InvalidationHandler.class::isInstance)
          .map(InvalidationHandler.class::cast)
          .forEach(ih -> ih.invalidate(session, type, ids));
    }

    @Override
    public String getId() {
        return PROVIDER_ID;
    }

    /**
     * Closes all cached component factories. The cache is emptied first, as in the {@code _ALL_} invalidation,
     * so that closed factories are never returned by a later {@link #getProviderFactory} call.
     */
    @Override
    public void close() {
        final ConcurrentMap<String, ProviderFactory> previous = componentsMap.getAndSet(new ConcurrentHashMap<>());
        dependentInvalidations.clear();
        previous.values().forEach(ProviderFactory::close);
    }
}