VectorSearchDelegate.java

/*
 * Copyright 2025-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.repository.query;

import java.util.ArrayList;
import java.util.List;

import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Vector;
import org.springframework.data.expression.ValueExpression;
import org.springframework.data.mapping.PersistentPropertyPath;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.util.NumberUtils;
import org.springframework.util.StringUtils;

/**
 * Delegate to assemble information about Vector Search queries necessary to run a MongoDB {@code $vectorSearch}.
 *
 * @author Mark Paluch
 */
class VectorSearchDelegate {

	private final VectorSearchQueryFactory queryFactory;
	private final VectorSearchOperation.SearchType searchType;
	private final String indexName;
	private final @Nullable Integer numCandidates;
	private final @Nullable String numCandidatesExpression;
	private final Limit limit;
	private final @Nullable String limitExpression;
	private final MongoConverter converter;

	VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) {

		VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow();

		this.searchType = vectorSearch.searchType();
		this.indexName = method.getAnnotatedHint();

		if (StringUtils.hasText(vectorSearch.numCandidates())) {

			ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates());

			if (expression.isLiteral()) {
				this.numCandidates = Integer.parseInt(vectorSearch.numCandidates());
				this.numCandidatesExpression = null;
			} else {
				this.numCandidates = null;
				this.numCandidatesExpression = vectorSearch.numCandidates();
			}

		} else {
			this.numCandidates = null;
			this.numCandidatesExpression = null;
		}

		if (StringUtils.hasText(vectorSearch.limit())) {

			ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit());

			if (expression.isLiteral()) {
				this.limit = Limit.of(Integer.parseInt(vectorSearch.limit()));
				this.limitExpression = null;
			} else {
				this.limit = Limit.unlimited();
				this.limitExpression = vectorSearch.limit();
			}

		} else {
			this.limit = Limit.unlimited();
			this.limitExpression = null;
		}

		this.converter = converter;

		if (StringUtils.hasText(vectorSearch.filter())) {
			this.queryFactory = StringUtils.hasText(vectorSearch.path())
					? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path())
					: new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity());
		} else {
			this.queryFactory = new PartTreeQueryFactory(
					new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()),
					converter.getMappingContext());
		}
	}

	/**
	 * Create Query Metadata for {@code $vectorSearch}.
	 */
	QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor,
			MongoParameterAccessor accessor, @Nullable Class<?> typeToRead, ParameterBindingDocumentCodec codec,
			ParameterBindingContext context) {

		String scoreField = "__score__";
		Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType();
		VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context);
		AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor,
				evaluator);

		return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType,
				outputType, getSimilarityFunction(accessor), indexName);
	}

	@SuppressWarnings("NullAway")
	AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
			MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {

		Vector vector = accessor.getVector();
		Score score = accessor.getScore();
		Range<Score> distance = accessor.getScoreRange();
		Limit limit = Limit.of(input.query().getLimit());

		List<AggregationOperation> stages = new ArrayList<>();
		VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
				.limit(limit);

		Integer candidates = null;
		if (this.numCandidatesExpression != null) {
			candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
		} else if (this.numCandidates != null) {
			candidates = this.numCandidates;
		} else if (input.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN
				|| searchType == VectorSearchOperation.SearchType.DEFAULT)) {

			/*
			MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy.
			 */
			candidates = input.query().getLimit() * 20;
		}

		if (candidates != null) {
			$vectorSearch = $vectorSearch.numCandidates(candidates);
		}
		//
		$vectorSearch = $vectorSearch.filter(input.query.getQueryObject());
		$vectorSearch = $vectorSearch.searchType(this.searchType);
		$vectorSearch = $vectorSearch.withSearchScore(scoreField);

		if (score != null) {
			$vectorSearch = $vectorSearch.withFilterBySore(c -> {
				c.gt(score.getValue());
			});
		} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
			$vectorSearch = $vectorSearch.withFilterBySore(c -> {
				Range.Bound<Score> lower = distance.getLowerBound();
				if (lower.isBounded()) {
					double value = lower.getValue().get().getValue();
					if (lower.isInclusive()) {
						c.gte(value);
					} else {
						c.gt(value);
					}
				}

				Range.Bound<Score> upper = distance.getUpperBound();
				if (upper.isBounded()) {

					double value = upper.getValue().get().getValue();
					if (upper.isInclusive()) {
						c.lte(value);
					} else {
						c.lt(value);
					}
				}
			});
		}

		stages.add($vectorSearch);

		if (input.query().isSorted()) {

			stages.add(ctx -> {

				Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType);
				mappedSort.append(scoreField, -1);
				return ctx.getMappedObject(new Document("$sort", mappedSort));
			});
		} else {
			stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField));
		}

		return new AggregationPipeline(stages);
	}

	private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
			ParameterBindingDocumentCodec codec, ParameterBindingContext context) {

		VectorSearchInput input = queryFactory.createQuery(accessor, codec, context);
		Limit limit = getLimit(evaluator, accessor);
		if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) {
			input.query().limit(limit);
		}
		return input;
	}

	private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) {

		if (this.limitExpression != null) {

			Object value = evaluator.evaluate(this.limitExpression);
			if (value != null) {
				if (value instanceof Limit l) {
					return l;
				}
				if (value instanceof Number n) {
					return Limit.of(n.intValue());
				}
				if (value instanceof String s) {
					return Limit.of(NumberUtils.parseNumber(s, Integer.class));
				}
				throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number");
			}
		}

		if (this.limit.isLimited()) {
			return this.limit;
		}

		return accessor.getLimit();
	}

	public String getQueryString() {
		return queryFactory.getQueryString();
	}

	ScoringFunction getSimilarityFunction(MongoParameterAccessor accessor) {

		Score score = accessor.getScore();

		if (score != null) {
			return score.getFunction();
		}

		Range<Score> scoreRange = accessor.getScoreRange();

		if (scoreRange != null) {
			if (scoreRange.getUpperBound().isBounded()) {
				return scoreRange.getUpperBound().getValue().get().getFunction();
			}

			if (scoreRange.getLowerBound().isBounded()) {
				return scoreRange.getLowerBound().getValue().get().getFunction();
			}
		}

		return ScoringFunction.unspecified();
	}

	/**
	 * Metadata for a Vector Search Aggregation.
	 *
	 * @param path
	 * @param query
	 * @param searchType
	 * @param outputType
	 * @param scoringFunction
	 */
	record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline,
			VectorSearchOperation.SearchType searchType, Class<?> outputType, ScoringFunction scoringFunction, String index) {

	}

	/**
	 * Strategy interface to implement a query factory for the Vector Search pre-filter query.
	 */
	private interface VectorSearchQueryFactory {

		VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
				ParameterBindingContext context);

		/**
		 * @return the underlying query string to determine {@link ParameterBindingContext}.
		 */
		String getQueryString();
	}

	private static class AnnotatedQueryFactory implements VectorSearchQueryFactory {

		private final String query;
		private final String path;

		AnnotatedQueryFactory(String query, String path) {

			this.query = query;
			this.path = path;
		}

		AnnotatedQueryFactory(String query, MongoPersistentEntity<?> entity) {

			this.query = query;
			String path = null;
			for (MongoPersistentProperty property : entity) {
				if (Vector.class.isAssignableFrom(property.getType())) {
					path = property.getFieldName();
					break;
				}
			}

			if (path == null) {
				throw new InvalidMongoDbApiUsageException(
						"Cannot find Vector Search property in entity [%s]".formatted(entity.getName()));
			}

			this.path = path;
		}

		public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
				ParameterBindingContext context) {

			Document queryObject = codec.decode(this.query, context);
			Query query = new BasicQuery(queryObject);

			Sort sort = parameterAccessor.getSort();
			if (sort.isSorted()) {
				query = query.with(sort);
			}

			return new VectorSearchInput(path, query);
		}

		@Override
		public String getQueryString() {
			return this.query;
		}
	}

	private class PartTreeQueryFactory implements VectorSearchQueryFactory {

		private final String path;
		private final PartTree tree;

		@SuppressWarnings("NullableProblems")
		PartTreeQueryFactory(PartTree tree, MappingContext<?, MongoPersistentProperty> context) {

			String path = null;
			for (PartTree.OrPart part : tree) {
				for (Part p : part) {
					if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR
							|| p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) {
						PersistentPropertyPath<MongoPersistentProperty> ppp = context.getPersistentPropertyPath(p.getProperty());
						MongoPersistentProperty property = ppp.getLeafProperty();

						if (Vector.class.isAssignableFrom(property.getType())) {
							path = p.getProperty().toDotPath();
							break;
						}
					}
				}
			}

			if (path == null) {
				throw new InvalidMongoDbApiUsageException(
						"No Simple Property/Near/Within/Between part found for a Vector property");
			}

			this.path = path;
			this.tree = tree;
		}

		@SuppressWarnings("NullAway")
		public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
				ParameterBindingContext context) {

			MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false,
					true);

			Query query = creator.createQuery(parameterAccessor.getSort());

			if (tree.isLimiting()) {
				query.limit(tree.getMaxResults());
			}

			return new VectorSearchInput(path, query);
		}

		@Override
		public String getQueryString() {
			return "";
		}
	}

	private record VectorSearchInput(String path, Query query) {

	}

}