SearchInteraction.java
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.bson.BinaryVector;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.query.MongoParameters;
import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.util.StringUtils;
/**
* A vector search interaction for MongoDB repositories.
*
* @author Christoph Strobl
* @since 5.0
*/
class SearchInteraction extends MongoInteraction implements QueryMetadata {
private final Class<?> domainType;
private final AotStringQuery filter;
private final VectorSearch vectorSearch;
private final MongoParameters parameters;
private final MongoMappingContext mappingContext;
public SearchInteraction(Class<?> domainType, VectorSearch vectorSearch, AotStringQuery filter,
MongoParameters parameters, MongoMappingContext mappingContext) {
this.domainType = domainType;
this.vectorSearch = vectorSearch;
this.filter = filter;
this.parameters = parameters;
this.mappingContext = mappingContext;
}
public AotStringQuery getFilter() {
return filter;
}
@Nullable
String getIndexName() {
return vectorSearch.indexName();
}
public MongoParameters getParameters() {
return parameters;
}
@Override
InteractionType getExecutionType() {
return InteractionType.AGGREGATION;
}
@Override
public Map<String, Object> serialize() {
Map<String, Object> serialized = new LinkedHashMap<>();
if (StringUtils.hasText(vectorSearch.indexName())) {
serialized.put("index", vectorSearch.indexName());
}
serialized.put("path", getSearchPath());
if (vectorSearch.searchType().equals(SearchType.ENN)) {
serialized.put("exact", true);
}
if (StringUtils.hasText(filter.getQueryString())) {
serialized.put("filter", filter.getQueryString());
}
String limit = limitParameter();
if (StringUtils.hasText(limit)) {
serialized.put("limit", limit);
}
if (StringUtils.hasText(vectorSearch.numCandidates())) {
serialized.put("numCandidates", vectorSearch.numCandidates());
} else if (StringUtils.hasText(limit)) {
serialized.put("numCandidates", limit + " * 20");
}
serialized.put("queryVector", "?" + parameters.getVectorIndex());
String $vectorSearch = DocumentSerializer.toJson(new Document("$vectorSearch", serialized));
return Map.of("pipeline", List.of($vectorSearch));
}
private @Nullable String limitParameter() {
if (parameters.hasLimitParameter()) {
return "?" + parameters.getLimitIndex();
} else if (StringUtils.hasText(vectorSearch.limit())) {
return vectorSearch.limit();
}
return null;
}
public String getSearchPath() {
if (StringUtils.hasText(vectorSearch.path())) {
return vectorSearch.path();
}
MongoPersistentEntity<?> entity = mappingContext.getRequiredPersistentEntity(domainType);
for (MongoPersistentProperty property : entity) {
if (Vector.class.isAssignableFrom(property.getActualType())
|| BinaryVector.class.isAssignableFrom(property.getActualType())) {
return property.getName();
}
}
throw new IllegalArgumentException("No vector search path found for type %s".formatted(domainType));
}
}