AggregationBlocks.java
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Stream;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.mapping.model.SimpleTypeHolder;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.repository.Hint;
import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* Code blocks for building aggregation pipelines and execution statements for MongoDB repositories.
*
* @author Christoph Strobl
* @since 5.0
*/
class AggregationBlocks {
@NullUnmarked
static class AggregationExecutionCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final SimpleTypeHolder simpleTypeHolder;
private final MongoQueryMethod queryMethod;
private String aggregationVariableName;
AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, SimpleTypeHolder simpleTypeHolder,
MongoQueryMethod queryMethod) {
this.context = context;
this.simpleTypeHolder = simpleTypeHolder;
this.queryMethod = queryMethod;
}
AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) {
this.aggregationVariableName = aggregationVariableName;
return this;
}
CodeBlock build() {
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
builder.add("\n");
Class<?> outputType = getOutputType(simpleTypeHolder, queryMethod);
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) {
builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
return builder.build();
}
if (outputType == Document.class) {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (queryMethod.isStreamQuery()) {
VariableSnippet results = Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(Stream.class, Document.class),
context.localVariable("results"))
.as("$L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))",
results.getVariableName(), returnType);
} else {
VariableSnippet results = Snippet.declare(builder)
.variable(AggregationResults.class, context.localVariable("results"))
.as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) {
builder.addStatement(
"return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))",
CollectionUtils.class, returnType, results.getVariableName());
} else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
results.getVariableName());
}
}
} else {
if (queryMethod.isSliceQuery()) {
VariableSnippet results = Snippet.declare(builder)
.variable(AggregationResults.class, context.localVariable("results"))
.as("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
VariableSnippet hasNext = Snippet.declare(builder).variable("hasNext").as(
"$L.getMappedResults().size() > $L.getPageSize()", results.getVariableName(),
context.getPageableParameterName());
builder.addStatement(
"return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)",
SliceImpl.class, hasNext.getVariableName(), results.getVariableName(),
context.getPageableParameterName());
} else {
if (queryMethod.isStreamQuery()) {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
CodeBlock codeBlock = CodeBlock.of("$L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
builder.addStatement("return $L",
MongoCodeBlocks.potentiallyWrapStreamable(context.getMethodReturn(), codeBlock));
}
}
}
return builder.build();
}
}
private static Class<?> getOutputType(SimpleTypeHolder simpleTypeHolder, MongoQueryMethod queryMethod) {
Class<?> outputType = queryMethod.getReturnedObjectType();
if (simpleTypeHolder.isSimpleType(outputType)) {
return Document.class;
}
if (ClassUtils.isAssignable(AggregationResults.class, outputType)
&& queryMethod.getReturnType().getComponentType() != null) {
return queryMethod.getReturnType().getComponentType().getType();
}
return outputType;
}
@NullUnmarked
static class AggregationCodeBlockBuilder {
private final AotQueryMethodGenerationContext context;
private final SimpleTypeHolder simpleTypeHolder;
private final MongoQueryMethod queryMethod;
private final String parameterNames;
private AggregationInteraction source;
private String aggregationVariableName;
private boolean pipelineOnly;
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, SimpleTypeHolder simpleTypeHolder,
MongoQueryMethod queryMethod) {
this.context = context;
this.simpleTypeHolder = simpleTypeHolder;
this.queryMethod = queryMethod;
this.parameterNames = StringUtils.collectionToDelimitedString(context.getAllParameterNames(), ", ");
}
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
this.source = aggregation;
return this;
}
AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) {
this.aggregationVariableName = aggregationVariableName;
return this;
}
AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) {
this.pipelineOnly = pipelineOnly;
return this;
}
CodeBlock build() {
Builder builder = CodeBlock.builder();
builder.add("\n");
String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline"));
builder.add(pipeline(pipelineName));
if (!pipelineOnly) {
Class<?> domainType = context.getRepositoryInformation().getDomainType();
Snippet.declare(builder)
.variable(ResolvableType.forClassWithGenerics(TypedAggregation.class, domainType), aggregationVariableName)
.as("$T.newAggregation($T.class, $L.getOperations())", Aggregation.class, domainType, pipelineName);
builder.add(aggregationOptions(aggregationVariableName));
}
return builder.build();
}
private CodeBlock pipeline(String pipelineVariableName) {
String sortParameter = context.getSortParameterName();
String limitParameter = context.getLimitParameterName();
String pageableParameter = context.getPageableParameterName();
Builder builder = CodeBlock.builder();
builder.add(aggregationStages(context.localVariable("stages"), source.stages()));
if (StringUtils.hasText(sortParameter)) {
Class<?> outputType = getOutputType(simpleTypeHolder, queryMethod);
builder.add(sortingStage(sortParameter, outputType));
}
if (StringUtils.hasText(limitParameter)) {
builder.add(limitingStage(limitParameter));
}
if (StringUtils.hasText(pageableParameter)) {
builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery()));
}
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
context.localVariable("stages"));
return builder.build();
}
private CodeBlock aggregationOptions(String aggregationVariableName) {
Builder builder = CodeBlock.builder();
List<CodeBlock> options = new ArrayList<>(5);
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
options.add(CodeBlock.of(".skipOutput()"));
}
MergedAnnotation<Hint> hintAnnotation = context.getAnnotation(Hint.class);
String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null;
if (StringUtils.hasText(hint)) {
options.add(CodeBlock.of(".hint($S)", hint));
}
MergedAnnotation<ReadPreference> readPreferenceAnnotation = context.getAnnotation(ReadPreference.class);
String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null;
if (StringUtils.hasText(readPreference)) {
options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference));
}
if (queryMethod.hasAnnotatedCollation()) {
options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation()));
}
if (!options.isEmpty()) {
Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"));
optionsBuilder.indent();
for (CodeBlock optionBlock : options) {
optionsBuilder.add(optionBlock);
optionsBuilder.add("\n");
}
optionsBuilder.add(".build();\n");
optionsBuilder.unindent();
builder.add(optionsBuilder.build());
builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName,
context.localVariable("aggregationOptions"));
}
return builder.build();
}
private CodeBlock aggregationStages(String stageListVariableName, Collection<String> stages) {
Builder builder = CodeBlock.builder();
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
stages.size());
int stageCounter = 0;
for (String stage : stages) {
VariableSnippet stageSnippet = Snippet.declare(builder)
.variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter)))
.of(MongoCodeBlocks.asDocument(context.getExpressionMarker(), stage, parameterNames));
builder.addStatement("$L.add($L)", stageListVariableName, stageSnippet.getVariableName());
stageCounter++;
}
return builder.build();
}
private CodeBlock sortingStage(String sortProvider, Class<?> outputType) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isSorted())", sortProvider);
builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument"));
builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider);
builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);",
context.localVariable("sortDocument"), context.localVariable("order"));
builder.endControlFlow();
if (outputType == Document.class || simpleTypeHolder.isSimpleType(outputType)
|| ClassUtils.isAssignable(context.getRepositoryInformation().getDomainType(), outputType)) {
builder.addStatement("$L.add(new $T($S, $L))", context.localVariable("stages"), Document.class, "$sort",
context.localVariable("sortDocument"));
} else {
builder.addStatement("$L.add(($T) _ctx -> new $T($S, _ctx.getMappedObject($L, $T.class)))",
context.localVariable("stages"), AggregationOperation.class, Document.class, "$sort",
context.localVariable("sortDocument"), outputType);
}
builder.endControlFlow();
return builder.build();
}
private CodeBlock pagingStage(String pageableProvider, boolean slice) {
Builder builder = CodeBlock.builder();
builder.add(sortingStage(pageableProvider + ".getSort()", getOutputType(simpleTypeHolder, queryMethod)));
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);
builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
builder.endControlFlow();
if (slice) {
builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"),
Aggregation.class, pageableProvider);
} else {
builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class,
pageableProvider);
}
builder.endControlFlow();
return builder.build();
}
private CodeBlock limitingStage(String limitProvider) {
Builder builder = CodeBlock.builder();
builder.beginControlFlow("if ($L.isLimited())", limitProvider);
builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class,
limitProvider);
builder.endControlFlow();
return builder.build();
}
}
}