// MongoRepositoryContributor.java
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
import static org.springframework.data.mongodb.repository.aot.QueryBlocks.*;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Properties;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.data.mapping.model.SimpleTypeHolder;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
import org.springframework.data.mongodb.core.convert.MongoCustomConversions;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.repository.Query;
import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.config.MongoRepositoryConfigurationExtension;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
import org.springframework.data.repository.aot.generate.MethodContributor;
import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.config.PropertiesBasedNamedQueriesFactoryBean;
import org.springframework.data.repository.config.RepositoryConfigurationSource;
import org.springframework.data.repository.core.NamedQueries;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.PropertiesBasedNamedQueries;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
/**
* MongoDB specific {@link RepositoryContributor}.
*
* @author Christoph Strobl
* @author Mark Paluch
* @since 5.0
*/
public class MongoRepositoryContributor extends RepositoryContributor {

	private static final Log logger = LogFactory.getLog(MongoRepositoryContributor.class);

	private final AotRepositoryContext repositoryContext;
	private final AotQueryCreator queryCreator;
	private final SimpleTypeHolder simpleTypeHolder;
	private final MongoMappingContext mappingContext;
	private final NamedQueries namedQueries;

	/**
	 * Create a new {@link MongoRepositoryContributor} for the given {@link AotRepositoryContext}.
	 * <p>
	 * Loads the named queries for the repository configuration and sets up a dedicated {@link MongoMappingContext}
	 * (auto index creation disabled) that is used to inspect query methods during AOT processing.
	 *
	 * @param repositoryContext the AOT repository context to contribute to.
	 */
	public MongoRepositoryContributor(AotRepositoryContext repositoryContext) {

		super(repositoryContext);

		// only ask the context for its class loader when a bean factory is present, otherwise use our own
		ClassLoader classLoader = repositoryContext.getBeanFactory() != null ? repositoryContext.getClassLoader() : null;
		if (classLoader == null) {
			classLoader = getClass().getClassLoader();
		}

		this.repositoryContext = repositoryContext;
		this.namedQueries = getNamedQueries(repositoryContext.getConfigurationSource(), classLoader);

		// avoid Java Time (JSR-310) Type introspection
		MongoCustomConversions mongoCustomConversions = MongoCustomConversions
				.create(MongoCustomConversions.MongoConverterConfigurationAdapter::useNativeDriverJavaTimeCodecs);

		this.simpleTypeHolder = mongoCustomConversions.getSimpleTypeHolder();

		this.mappingContext = new MongoMappingContext();
		this.mappingContext.setSimpleTypeHolder(this.simpleTypeHolder);
		this.mappingContext.setAutoIndexCreation(false);
		this.mappingContext.afterPropertiesSet();

		this.queryCreator = new AotQueryCreator(this.mappingContext);
	}

	/**
	 * Load the {@link NamedQueries} from the configured named query location. Falls back to the default location of
	 * {@link MongoRepositoryConfigurationExtension} if the configuration source is absent or does not define one.
	 *
	 * @param configSource the repository configuration source, can be {@literal null}.
	 * @param classLoader the class loader used to resolve the named query resources.
	 * @return the named queries found at the location or an empty {@link PropertiesBasedNamedQueries} if there is no
	 *         location to read from.
	 * @throws RuntimeException wrapping the {@link IOException} raised while resolving or reading the resources.
	 */
	@SuppressWarnings("NullAway")
	private NamedQueries getNamedQueries(@Nullable RepositoryConfigurationSource configSource, ClassLoader classLoader) {

		String location = configSource != null ? configSource.getNamedQueryLocation().orElse(null) : null;

		if (location == null) {
			location = new MongoRepositoryConfigurationExtension().getDefaultNamedQueryLocation();
		}

		if (StringUtils.hasText(location)) {

			try {

				PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(classLoader);

				PropertiesBasedNamedQueriesFactoryBean factoryBean = new PropertiesBasedNamedQueriesFactoryBean();
				factoryBean.setLocations(resolver.getResources(location));
				factoryBean.afterPropertiesSet();

				return factoryBean.getObject();
			} catch (IOException e) {
				// keep the exception type but tell the user which location could not be read
				throw new RuntimeException("Cannot load named queries from [%s]".formatted(location), e);
			}
		}

		return new PropertiesBasedNamedQueries(new Properties());
	}

	/**
	 * Let the generated repository class extend {@link MongoAotRepositoryFragmentSupport}.
	 *
	 * @param classBuilder the builder of the generated class.
	 */
	@Override
	protected void customizeClass(AotRepositoryClassBuilder classBuilder) {
		classBuilder.customize(builder -> builder.superclass(TypeName.get(MongoAotRepositoryFragmentSupport.class)));
	}

	/**
	 * Add the {@code operations} and {@code context} constructor parameters and forward both to the superclass
	 * constructor. The {@link MongoOperations} parameter is bound to the bean named by the {@code mongoTemplateRef}
	 * attribute if one is configured, otherwise it is resolved by type.
	 *
	 * @param constructorBuilder the builder of the generated constructor.
	 */
	@Override
	protected void customizeConstructor(AotRepositoryConstructorBuilder constructorBuilder) {

		constructorBuilder.addParameter("operations", MongoOperations.class, customizer -> {

			String mongoOperationsRef = getMongoTemplateRef();

			customizer.bindToField()
					.origin(StringUtils.hasText(mongoOperationsRef)
							? new RuntimeBeanReference(mongoOperationsRef, MongoOperations.class)
							: new RuntimeBeanReference(MongoOperations.class));
		});

		constructorBuilder.addParameter("context", RepositoryFactoryBeanSupport.FragmentCreationContext.class, false);

		constructorBuilder.customize((builder) -> {
			builder.addStatement("super(operations, context)");
		});
	}

	/**
	 * Obtain the explicitly configured {@code mongoTemplateRef}.
	 *
	 * @return the configured bean name or {@literal null} if there is no configuration source, the attribute is not
	 *         set, or it equals {@code mongoTemplate}.
	 */
	private @Nullable String getMongoTemplateRef() {

		// the configuration source is handled as optional, the same way getNamedQueries(...) treats it
		RepositoryConfigurationSource configSource = repositoryContext.getConfigurationSource();
		if (configSource == null) {
			return null;
		}

		return configSource.getAttribute("mongoTemplateRef") //
				.filter(it -> !"mongoTemplate".equals(it)) //
				.orElse(null);
	}

	/**
	 * Derive the {@link MethodContributor} for the given repository query method.
	 * <p>
	 * The method is classified as annotated aggregation, vector search, geo-near, delete, update (via update parameter,
	 * {@link Update#value()} or {@link Update#pipeline()}) or plain query, in that order.
	 *
	 * @param method the repository method to inspect.
	 * @return {@literal null} if no query metadata could be derived, a metadata-only contributor if code generation is
	 *         skipped for the method (see {@link #backoff(MongoQueryMethod)}) or a contributor generating the method
	 *         body otherwise.
	 */
	@Override
	@SuppressWarnings("NullAway")
	protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method) {

		MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(),
				mappingContext);

		MethodContributor.RepositoryMethodContribution contribution = null;
		QueryMetadata queryMetadata = null;

		if (queryMethod.hasAnnotatedAggregation()) {

			AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation());
			queryMetadata = aggregation;
			contribution = aggregationMethodContributor(queryMethod, simpleTypeHolder, aggregation);
		} else {

			QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod,
					AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method);

			if (queryMethod.hasAnnotatedVectorSearch()) {

				VectorSearch vectorSearch = queryMethod.getRequiredVectorSearchAnnotation();
				SearchInteraction interaction = new SearchInteraction(getRepositoryInformation().getDomainType(), vectorSearch,
						query.getQuery(), queryMethod.getParameters(), mappingContext);

				queryMetadata = interaction;
				contribution = searchMethodContributor(queryMethod, interaction);
			} else if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1
					&& queryMethod.getReturnType().isCollectionLike())) {

				NearQueryInteraction near = new NearQueryInteraction(query, queryMethod.getParameters());
				queryMetadata = near;
				contribution = nearQueryMethodContributor(queryMethod, near);
			} else if (query.isDelete()) {

				queryMetadata = query;
				contribution = deleteMethodContributor(queryMethod, query);
			} else if (queryMethod.isModifyingQuery()) {

				int updateIndex = queryMethod.getParameters().getUpdateIndex();
				if (updateIndex != -1) {

					// the update definition is passed in as method argument
					UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
					queryMetadata = update;
					contribution = updateMethodContributor(queryMethod, update);
				} else {

					Update updateSource = queryMethod.getUpdateSource();
					if (StringUtils.hasText(updateSource.value())) {

						UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
						queryMetadata = update;
						contribution = updateMethodContributor(queryMethod, update);
					}

					// evaluated after the string update, so a non-empty pipeline takes precedence over value()
					if (!ObjectUtils.isEmpty(updateSource.pipeline())) {

						AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
						queryMetadata = update;
						contribution = aggregationUpdateMethodContributor(queryMethod, simpleTypeHolder, update);
					}
				}
			} else {

				queryMetadata = query;
				contribution = queryMethodContributor(queryMethod, query);
			}
		}

		// e.g. a modifying query that neither has an update parameter nor an update value/pipeline
		if (queryMetadata == null) {
			return null;
		}

		if (backoff(queryMethod) || contribution == null) {
			return MethodContributor.forQueryMethod(queryMethod).metadataOnly(queryMetadata);
		}

		return MethodContributor.forQueryMethod(queryMethod).withMetadata(queryMetadata).contribute(contribution);
	}

	/**
	 * Create the {@link QueryInteraction} for the given query method. The query is taken from the {@link Query}
	 * annotation, a matching named query, or derived from the method name via {@link PartTree}, in that order.
	 * {@link Query#sort()} and {@link Query#fields()} are applied on top if they contain text.
	 *
	 * @param repositoryInformation repository metadata providing the domain type for query derivation.
	 * @param queryMethod the query method to create the query for.
	 * @param queryAnnotation the (merged) {@link Query} annotation of the method, can be {@literal null}.
	 * @param source the repository method itself.
	 * @return the query interaction describing the query and its count/delete/exists flags.
	 */
	@SuppressWarnings("NullAway")
	private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod,
			@Nullable Query queryAnnotation, Method source) {

		QueryInteraction query;

		if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) {

			query = new QueryInteraction(new AotStringQuery(queryMethod.getAnnotatedQuery()), queryAnnotation.count(),
					queryAnnotation.delete(), queryAnnotation.exists());
		} else if (namedQueries.hasQuery(queryMethod.getNamedQueryName())) {

			// the annotation may still be present (without a query value) and carry the execution flags
			query = new QueryInteraction(new AotStringQuery(namedQueries.getQuery(queryMethod.getNamedQueryName())),
					queryAnnotation != null && queryAnnotation.count(), queryAnnotation != null && queryAnnotation.delete(),
					queryAnnotation != null && queryAnnotation.exists());
		} else {

			PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType());
			AotStringQuery aotStringQuery = queryCreator.createQuery(partTree, queryMethod, source);
			query = new QueryInteraction(aotStringQuery, partTree.isCountProjection(), partTree.isDelete(),
					partTree.isExistsProjection());
		}

		if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) {
			query = query.withSort(queryAnnotation.sort());
		}

		if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.fields())) {
			query = query.withFields(queryAnnotation.fields());
		}

		return query;
	}

	/**
	 * Decide whether code generation needs to be skipped for the given method. Currently this is the case for methods
	 * returning an array.
	 *
	 * @param method the query method to check.
	 * @return {@literal true} if no method body should be generated.
	 */
	private static boolean backoff(MongoQueryMethod method) {

		// TODO: returning arrays.
		boolean skip = method.getReturnType().getType().isArray();

		if (skip && logger.isDebugEnabled()) {
			logger.debug("Skipping AOT generation for [%s]. Method is returning an array".formatted(method.getName()));
		}

		return skip;
	}

	/**
	 * Create the contribution for a geo-near query method. The filter query is only added if the method has bindable
	 * parameters.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param interaction the near query interaction holding the filter query.
	 * @return the contribution rendering the method body.
	 */
	private static MethodContributor.RepositoryMethodContribution nearQueryMethodContributor(MongoQueryMethod queryMethod,
			NearQueryInteraction interaction) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			String variableName = context.localVariable("nearQuery");
			builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName(variableName).build());

			if (!context.getBindableParameterNames().isEmpty()) {

				String filterQueryVariableName = context.localVariable("filterQuery");
				builder.add(queryBlockBuilder(context, queryMethod).usingQueryVariableName(filterQueryVariableName)
						.filter(interaction.getQuery()).build());
				builder.addStatement("$L.query($L)", variableName, filterQueryVariableName);
			}

			builder.add(geoNearExecutionBlockBuilder(context).referencing(variableName).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for a method using an annotated aggregation.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param simpleTypeHolder the simple types passed on to the aggregation block builders.
	 * @param aggregation the aggregation interaction providing the pipeline stages.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution aggregationMethodContributor(MongoQueryMethod queryMethod,
			SimpleTypeHolder simpleTypeHolder, AggregationInteraction aggregation) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			String variableName = context.localVariable("aggregation");

			builder.add(aggregationBlockBuilder(context, simpleTypeHolder, queryMethod).stages(aggregation)
					.usingAggregationVariableName(variableName).build());
			builder.add(
					aggregationExecutionBlockBuilder(context, simpleTypeHolder, queryMethod).referencing(variableName).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for a vector search method.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param interaction the search interaction providing the search path and filter.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution searchMethodContributor(MongoQueryMethod queryMethod,
			SearchInteraction interaction) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			String variableName = context.localVariable("search");

			builder.add(
					new VectorSearchBlocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod, interaction.getSearchPath())
							.usingVariableName(variableName).withFilter(interaction.getFilter()).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for an update method. The update definition is either taken from the method parameter
	 * or rendered from the given {@link UpdateInteraction}.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param update the update interaction providing the filter and update definition.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution updateMethodContributor(MongoQueryMethod queryMethod,
			UpdateInteraction update) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			// update filter
			String filterVariableName = context.localVariable(update.name());
			builder.add(queryBlockBuilder(context, queryMethod).filter(update.getFilter())
					.usingQueryVariableName(filterVariableName).build());

			// update definition
			String updateVariableName;

			if (update.hasUpdateDefinitionParameter()) {
				updateVariableName = context.getParameterName(update.getRequiredUpdateDefinitionParameter());
			} else {
				updateVariableName = context.localVariable("updateDefinition");
				builder.add(updateBlockBuilder(context).update(update).usingUpdateVariableName(updateVariableName).build());
			}

			builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
					.referencingUpdate(updateVariableName).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for an update method that uses an aggregation pipeline as update definition.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param simpleTypeHolder the simple types passed on to the aggregation block builder.
	 * @param update the update interaction providing the filter and pipeline stages.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution aggregationUpdateMethodContributor(MongoQueryMethod queryMethod,
			SimpleTypeHolder simpleTypeHolder, AggregationUpdateInteraction update) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			// update filter
			String filterVariableName = context.localVariable(update.name());
			QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(update.getFilter());
			builder.add(queryCodeBlockBuilder.usingQueryVariableName(filterVariableName).build());

			// update definition
			String updateVariableName = context.localVariable("updateDefinition");
			builder.add(aggregationBlockBuilder(context, simpleTypeHolder, queryMethod).stages(update)
					.usingAggregationVariableName(updateVariableName).pipelineOnly(true).build());

			// resolve the variable name once so declaration and usage are guaranteed to refer to the same variable
			String aggregationUpdateVariableName = context.localVariable("aggregationUpdate");
			builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class,
					aggregationUpdateVariableName, AggregationUpdate.class, updateVariableName);

			builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
					.referencingUpdate(aggregationUpdateVariableName).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for a delete method.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param query the query interaction selecting the documents to delete.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution deleteMethodContributor(MongoQueryMethod queryMethod,
			QueryInteraction query) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);
			String queryVariableName = context.localVariable(query.name());

			builder.add(queryCodeBlockBuilder.usingQueryVariableName(queryVariableName).build());
			builder.add(deleteExecutionBlockBuilder(context, queryMethod).referencing(queryVariableName).build());

			return builder.build();
		};
	}

	/**
	 * Create the contribution for a plain query method.
	 *
	 * @param queryMethod the query method to generate code for.
	 * @param query the query interaction to render and execute.
	 * @return the contribution rendering the method body.
	 */
	static MethodContributor.RepositoryMethodContribution queryMethodContributor(MongoQueryMethod queryMethod,
			QueryInteraction query) {

		return context -> {

			CodeBlock.Builder builder = CodeBlock.builder();

			QueryCodeBlockBuilder queryCodeBlockBuilder = queryBlockBuilder(context, queryMethod).filter(query);

			builder.add(queryCodeBlockBuilder.usingQueryVariableName(context.localVariable(query.name())).build());
			builder.add(queryExecutionBlockBuilder(context, queryMethod).forQuery(query).build());

			return builder.build();
		};
	}
}