// VectorSearchBlocks.java
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.data.core.TypeInformation;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.aot.Snippet.BuilderStyleBuilder;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.StringUtils;
/**
* Code blocks for building vector search operations in AOT processing for MongoDB repositories.
*
* @author Christoph Strobl
* @since 5.0
*/
class VectorSearchBlocks {

	/**
	 * Builder assembling the {@link CodeBlock} that is rendered into a generated repository method to run a vector
	 * search. The rendered code declares a {@link VectorSearchOperation}, optionally applies a filter, appends a sort
	 * stage, wraps both in an {@link AggregationPipeline} and returns the result of a {@link VectorSearchExecution}.
	 * <p>
	 * {@link #withFilter(AotStringQuery)} and {@link #usingVariableName(String)} are expected to be called before
	 * {@link #build()}; the filter is dereferenced without a {@literal null} check.
	 */
	@NullUnmarked
	static class VectorSearchQueryCodeBlockBuilder {

		/** Field name the search score is written to; also used as the (secondary) sort key. */
		private static final String SCORE_FIELD = "__score__";

		private final AotQueryMethodGenerationContext context;
		private final MongoQueryMethod queryMethod;
		private final VectorSearch vectorSearchAnnotation;
		private String searchQueryVariableName;
		private AotStringQuery filter;
		private final String searchPath;

		/**
		 * Create a new builder. The {@link VectorSearch} annotation is read eagerly from the given query method.
		 *
		 * @param context generation context used to resolve parameter, field and local variable names.
		 * @param queryMethod the repository query method carrying the {@link VectorSearch} annotation.
		 * @param searchPath value rendered into the {@code path(...)} call of the vector search operation.
		 */
		VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod,
				String searchPath) {
			this.context = context;
			this.queryMethod = queryMethod;
			this.vectorSearchAnnotation = queryMethod.getRequiredVectorSearchAnnotation();
			this.searchPath = searchPath;
		}

		/**
		 * Set the query providing filter, sort and limit information.
		 *
		 * @param filter the query to use.
		 * @return this builder.
		 */
		VectorSearchQueryCodeBlockBuilder withFilter(AotStringQuery filter) {
			this.filter = filter;
			return this;
		}

		/**
		 * Set the name of the variable the {@link AggregationPipeline} is assigned to in the rendered code.
		 *
		 * @param searchQueryVariableName the variable name to use.
		 * @return this builder.
		 */
		VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) {
			this.searchQueryVariableName = searchQueryVariableName;
			return this;
		}

		/**
		 * Render the code for the vector search: limit (if it needs a variable), vector search operation, filter, sort
		 * stage, aggregation pipeline, scoring function and the final {@code return} statement.
		 *
		 * @return the assembled code block.
		 */
		CodeBlock build() {

			Builder builder = CodeBlock.builder();
			String vectorParameterName = context.getVectorParameterName();
			String indexName = vectorSearchAnnotation.indexName();
			SearchType searchType = vectorSearchAnnotation.searchType();
			ExpressionSnippet limit = getLimitExpression();

			// An evaluated limit that is also needed to derive numCandidates is stored in a local variable so that
			// both the limit(...) call and the derived numCandidates expression reference that variable.
			if (limit.requiresEvaluation() && !StringUtils.hasText(vectorSearchAnnotation.numCandidates())
					&& isAnnOrDefault(searchType)) {
				VariableSnippet variableBlock = limit.as(VariableSnippet::create)
						.variableName(context.localVariable("limitToUse"));
				variableBlock.renderDeclaration(builder);
				limit = variableBlock;
			}

			BuilderStyleBuilder vectorSearchOperationBuilder = Snippet.declare(builder)
					.variableBuilder(VectorSearchOperation.class, context.localVariable("$vectorSearch"))
					.as("$T.vectorSearch($S).path($S).vector($L).limit($L)", Aggregation.class, indexName, searchPath,
							vectorParameterName, limit.code());

			// Only render an explicit search type when it differs from the default.
			if (searchType != SearchType.DEFAULT) {
				vectorSearchOperationBuilder.call("searchType").with("$T.$L", SearchType.class, searchType.name());
			}

			ExpressionSnippet numCandidates = getNumCandidatesExpression(searchType, limit);
			if (!numCandidates.isEmpty()) {
				vectorSearchOperationBuilder.call("numCandidates").with(numCandidates);
			}

			vectorSearchOperationBuilder.call("withSearchScore").with("$S", SCORE_FIELD);

			// A score parameter takes precedence over a score range parameter.
			if (StringUtils.hasText(context.getScoreParameterName())) {
				vectorSearchOperationBuilder.call("withFilterBySore").with("$1L -> { $1L.gt($2L.getValue()); }",
						context.localVariable("criteria"), context.getScoreParameterName());
			} else if (StringUtils.hasText(context.getScoreRangeParameterName())) {
				// NOTE(review): scoreBetween(...) is not defined here; presumably provided by the generated class'
				// supertype - confirm.
				vectorSearchOperationBuilder.call("withFilterBySore")
						.with("scoreBetween($1L.getLowerBound(), $1L.getUpperBound())", context.getScoreRangeParameterName());
			}

			VariableSnippet vectorSearchOperation = vectorSearchOperationBuilder.variable();
			getFilter(vectorSearchOperation.getVariableName()).appendTo(builder);

			VariableSnippet sortStage = getSort().as(VariableSnippet::create).variableName(context.localVariable("$sort"));
			sortStage.renderDeclaration(builder);
			builder.add("\n");

			VariableSnippet aggregationPipeline = Snippet.declare(builder)
					.variable(AggregationPipeline.class, searchQueryVariableName).as("new $T($T.of($L, $L))",
							AggregationPipeline.class, List.class, vectorSearchOperation.getVariableName(), sortStage.code());

			// Scoring function: taken from the score parameter, derived from the score range, or unspecified.
			String scoringFunctionVar = context.localVariable("scoringFunction");
			builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar);
			if (StringUtils.hasText(context.getScoreParameterName())) {
				builder.add("$L.getFunction();\n", context.getScoreParameterName());
			} else if (StringUtils.hasText(context.getScoreRangeParameterName())) {
				builder.add("scoringFunction($L);\n", context.getScoreRangeParameterName());
			} else {
				builder.add("$1T.unspecified();\n", ScoringFunction.class);
			}

			// The execution result is cast to the declared return type of the query method.
			builder.addStatement(
					"return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)",
					VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class),
					context.getRepositoryInformation().getDomainType(), TypeInformation.class,
					queryMethod.getReturnType().getType(), aggregationPipeline.getVariableName(), scoringFunctionVar);

			return builder.build();
		}

		/**
		 * Render the sort stage expression. Without a sort on the filter query this is a descending sort on the score
		 * field. Otherwise a lambda is rendered that maps the parsed sort document and appends the score field
		 * (descending) to it.
		 *
		 * @return the expression producing the sort stage.
		 */
		private ExpressionSnippet getSort() {

			if (!filter.isSorted()) {
				return new ExpressionSnippet(
						CodeBlock.of("$T.sort($T.Direction.DESC, $S)", Aggregation.class, Sort.class, SCORE_FIELD));
			}

			Builder builder = CodeBlock.builder();
			String ctx = context.localVariable("ctx");
			String mappedSort = context.localVariable("mappedSort");

			builder.add("($T) ($L) -> {\n", AggregationOperation.class, ctx);
			builder.indent();
			builder.add("$1T $4L = $5L.getMappedObject(parse($2S), $3T.class);\n", Document.class, filter.getSortString(),
					context.getMethodReturn().getActualClassName(), mappedSort, ctx);
			builder.add("return new $1T($2S, $3L.append($4S, -1));\n", Document.class, "$sort", mappedSort, SCORE_FIELD);
			builder.unindent();
			builder.add("}");

			return new ExpressionSnippet(builder.build());
		}

		/**
		 * Render the filter query declaration and re-assign the vector search variable with the filter applied. Renders
		 * nothing if the filter query has no query string.
		 *
		 * @param vectorSearchVar name of the variable holding the vector search operation.
		 * @return the snippet applying the filter, or an empty snippet.
		 */
		private Snippet getFilter(String vectorSearchVar) {

			if (!StringUtils.hasText(filter.getQueryString())) {
				return ExpressionSnippet.empty();
			}

			Builder builder = CodeBlock.builder();
			String filterVar = context.localVariable("filter");

			// Declare the query under the very same (clash free) name that is referenced in the statement below.
			builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName(filterVar)
					.filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery());
			builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar);
			builder.add("\n");

			return new ExpressionSnippet(builder.build());
		}

		/**
		 * Render the {@code numCandidates} expression. An explicit annotation value wins. Otherwise, for
		 * {@link SearchType#ANN} and {@link SearchType#DEFAULT}, the value is derived as 20 times the limit, using (in
		 * order) the limit parameter, the limit of the filter query, or the given limit expression.
		 *
		 * @param searchType the search type defined by the annotation.
		 * @param limit the limit expression used as fallback for deriving the value.
		 * @return the expression to render, or an empty one if no {@code numCandidates} should be rendered.
		 */
		private ExpressionSnippet getNumCandidatesExpression(SearchType searchType, ExpressionSnippet limit) {

			String numCandidates = vectorSearchAnnotation.numCandidates();
			if (StringUtils.hasText(numCandidates)) {
				return numberExpression(numCandidates);
			}

			if (!isAnnOrDefault(searchType)) {
				return ExpressionSnippet.empty();
			}

			if (StringUtils.hasText(context.getLimitParameterName())) {
				return new ExpressionSnippet(CodeBlock.of("$L.max() * 20", context.getLimitParameterName()));
			}

			if (filter.isLimited()) {
				// Known at generation time, hence rendered as the already multiplied literal.
				return new ExpressionSnippet(CodeBlock.of("$L", filter.getLimit() * 20));
			}

			return new ExpressionSnippet(CodeBlock.of("$L * 20", limit.code()));
		}

		/**
		 * Render the limit expression, using (in order) the limit parameter of the method, the limit of the filter
		 * query, the {@code limit} attribute of the annotation, or {@code Limit.unlimited()}.
		 *
		 * @return the expression to render into the {@code limit(...)} call.
		 */
		private ExpressionSnippet getLimitExpression() {

			if (StringUtils.hasText(context.getLimitParameterName())) {
				return new ExpressionSnippet(CodeBlock.of("$L", context.getLimitParameterName()));
			}

			if (filter.isLimited()) {
				return new ExpressionSnippet(CodeBlock.of("$L", filter.getLimit()));
			}

			String limit = vectorSearchAnnotation.limit();
			if (StringUtils.hasText(limit)) {
				return numberExpression(limit);
			}

			return new ExpressionSnippet(CodeBlock.of("$T.unlimited()", Limit.class));
		}

		/**
		 * Turn a textual annotation value into an expression. Values containing a placeholder or an expression are
		 * rendered via {@code MongoCodeBlocks.evaluateNumberPotentially} and flagged as requiring evaluation; anything
		 * else is rendered verbatim.
		 */
		private ExpressionSnippet numberExpression(String value) {

			if (MongoCodeBlocks.containsPlaceholder(value) || MongoCodeBlocks.containsExpression(value)) {
				return new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(value, Integer.class, context), true);
			}

			return new ExpressionSnippet(CodeBlock.of("$L", value));
		}

		/** Whether the given search type is {@link SearchType#ANN} or {@link SearchType#DEFAULT}. */
		private static boolean isAnnOrDefault(SearchType searchType) {
			return searchType == SearchType.ANN || searchType == SearchType.DEFAULT;
		}
	}
}