// LoadBalancerChildContextInitializer.java
/*
* Copyright 2012-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.loadbalancer.aot;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.lang.model.element.Modifier;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClientSpecification;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName;
import org.springframework.util.Assert;
/**
 * A {@link BeanRegistrationAotProcessor} that contributes AOT-generated initializers for
 * LoadBalancer child contexts.
 * <p>
 * When the {@link LoadBalancerClientFactory} bean registered in the main application
 * context is processed, one child context is built per known context id and passed to an
 * {@link ApplicationContextAotGenerator}. The generated initializer classes are then
 * handed to the factory instance through a generated instance post-processor method.
 *
 * @author Olga Maciaszek-Sharma
 */
public class LoadBalancerChildContextInitializer implements BeanRegistrationAotProcessor {

	private final ApplicationContext applicationContext;

	private final LoadBalancerClientFactory loadBalancerClientFactory;

	/**
	 * Creates the processor.
	 * @param loadBalancerClientFactory factory used to look up the configured context ids
	 * and to build the child contexts
	 * @param applicationContext the main application context; it must be a
	 * {@link ConfigurableApplicationContext} by the time
	 * {@link #processAheadOfTime(RegisteredBean)} is invoked
	 */
	public LoadBalancerChildContextInitializer(LoadBalancerClientFactory loadBalancerClientFactory,
			ApplicationContext applicationContext) {
		this.loadBalancerClientFactory = loadBalancerClientFactory;
		this.applicationContext = applicationContext;
	}

	/**
	 * Builds the child-context contribution for the {@link LoadBalancerClientFactory} bean.
	 * @param registeredBean the bean currently being processed
	 * @return the contribution covering every context id found in the factory
	 * configurations and in the eager-load property, or {@code null} when the bean is not
	 * the {@link LoadBalancerClientFactory} of the main application context's bean factory
	 * @throws IllegalArgumentException if the application context is not a
	 * {@link ConfigurableApplicationContext} (checked before the bean is inspected)
	 */
	@Override
	public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
		Assert.isInstanceOf(ConfigurableApplicationContext.class, applicationContext);
		BeanFactory mainBeanFactory = ((ConfigurableApplicationContext) applicationContext).getBeanFactory();

		// Guard clause: anything other than the main context's factory bean is ignored.
		if (!registeredBean.getBeanClass().equals(LoadBalancerClientFactory.class)
				|| !registeredBean.getBeanFactory().equals(mainBeanFactory)) {
			return null;
		}

		Set<String> contextIds = new HashSet<>();
		contextIds.addAll(getContextIdsFromConfig());
		contextIds.addAll(getEagerLoadContextIds());

		Map<String, GenericApplicationContext> contextsById = contextIds.stream()
			.collect(Collectors.toMap(contextId -> contextId, this::buildChildContext));
		return new AotContribution(contextsById);
	}

	/**
	 * Collects the configuration keys of the factory that do not start with
	 * {@code "default."}.
	 * @return the matching configuration keys, used as context ids
	 */
	private Set<String> getContextIdsFromConfig() {
		Map<String, LoadBalancerClientSpecification> specifications = loadBalancerClientFactory.getConfigurations();
		Set<String> contextIds = new HashSet<>();
		for (String key : specifications.keySet()) {
			if (!key.startsWith("default.")) {
				contextIds.add(key);
			}
		}
		return contextIds;
	}

	/**
	 * Reads the {@code spring.cloud.loadbalancer.eager-load.clients} property from the
	 * environment as a set of strings.
	 * @return the bound values, or an empty set when the property is not bound
	 */
	private Set<String> getEagerLoadContextIds() {
		Binder binder = Binder.get(applicationContext.getEnvironment());
		return binder.bind("spring.cloud.loadbalancer.eager-load.clients", Bindable.setOf(String.class))
			.orElse(Collections.emptySet());
	}

	/**
	 * Asks the factory to build the context for the given id and to register its beans in
	 * it.
	 * @param contextId id of the child context
	 * @return the context returned by the factory
	 */
	private GenericApplicationContext buildChildContext(String contextId) {
		GenericApplicationContext built = loadBalancerClientFactory.buildContext(contextId);
		loadBalancerClientFactory.registerBeans(contextId, built);
		return built;
	}

	/**
	 * Contribution that runs AOT processing for each child context and generates the
	 * {@code addChildContextInitializer} instance post-processor method.
	 */
	private static class AotContribution implements BeanRegistrationAotContribution {

		private final Map<String, GenericApplicationContext> childContexts;

		/**
		 * Keeps a private copy of the given contexts, leaving out entries whose context is
		 * {@code null}.
		 * @param childContexts child contexts keyed by context id
		 */
		AotContribution(Map<String, GenericApplicationContext> childContexts) {
			Map<String, GenericApplicationContext> retained = new HashMap<>();
			childContexts.forEach((contextId, childContext) -> {
				if (childContext != null) {
					retained.put(contextId, childContext);
				}
			});
			this.childContexts = retained;
		}

		/**
		 * Generates an initializer class per child context and adds an instance
		 * post-processor that passes a map of context id to a new initializer instance to
		 * {@code withApplicationContextInitializers} on the factory instance.
		 * @param generationContext the generation context of the main application context
		 * @param beanRegistrationCode the registration code receiving the generated method
		 */
		@Override
		public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) {
			Map<String, ClassName> initializerClassNames = new HashMap<>();
			for (Map.Entry<String, GenericApplicationContext> entry : childContexts.entrySet()) {
				initializerClassNames.put(entry.getKey(), generateInitializer(generationContext, entry.getValue()));
			}

			GeneratedMethod postProcessorMethod = beanRegistrationCode.getMethods()
				.add("addChildContextInitializer", method -> {
					method.addJavadoc("Use AOT child context management initialization");
					method.addModifiers(Modifier.PRIVATE, Modifier.STATIC);
					method.addParameter(RegisteredBean.class, "registeredBean");
					method.addParameter(LoadBalancerClientFactory.class, "instance");
					method.returns(LoadBalancerClientFactory.class);
					method.addStatement("$T<String, Object> initializers = new $T<>()", Map.class, HashMap.class);
					// One put statement per child context in the generated method body.
					initializerClassNames.forEach((contextId, initializerClassName) -> method
						.addStatement("initializers.put($S, new $L())", contextId, initializerClassName));
					method.addStatement("return instance.withApplicationContextInitializers(initializers)");
				});
			beanRegistrationCode.addInstancePostProcessor(postProcessorMethod.toMethodReference());
		}

		/**
		 * Runs an {@link ApplicationContextAotGenerator} for one child context, using a
		 * generation context named after the child's display name with every {@code '-'}
		 * replaced by {@code '_'}.
		 * @param generationContext the parent generation context
		 * @param childContext the child context to process
		 * @return the class name returned by the generator for the child context
		 */
		private static ClassName generateInitializer(GenerationContext generationContext,
				GenericApplicationContext childContext) {
			String featureName = childContext.getDisplayName().replaceAll("[-]", "_");
			GenerationContext childGenerationContext = generationContext.withName(featureName);
			return new ApplicationContextAotGenerator().processAheadOfTime(childContext, childGenerationContext);
		}

	}

}