VectorSearchOperation.java

/*
 * Copyright 2024-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.core.aggregation;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.bson.BinaryVector;
import org.bson.Document;

import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Vector;
import org.springframework.data.mongodb.core.mapping.MongoVector;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.lang.Contract;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Performs a semantic search on data in your Atlas cluster. This stage is only available for Atlas Vector Search.
 * Vector data must be less than or equal to 4096 dimensions in width.
 * <h3>Limitations</h3> You cannot use this stage together with:
 * <ul>
 * <li>{@link org.springframework.data.mongodb.core.aggregation.LookupOperation Lookup} stages</li>
 * <li>{@link org.springframework.data.mongodb.core.aggregation.FacetOperation Facet} stage</li>
 * </ul>
 *
 * @author Christoph Strobl
 * @author Mark Paluch
 * @since 4.5
 */
public class VectorSearchOperation implements AggregationOperation {

	/** Search type; any value other than {@link SearchType#DEFAULT} causes the {@code exact} flag to be rendered. */
	private final SearchType searchType;
	/** Optional pre-filter, rendered as {@code filter} after being mapped through the operation context. */
	private final @Nullable CriteriaDefinition filter;
	/** Name of the index to search, rendered as {@code index}. */
	private final String indexName;
	/** Result limit, rendered as {@code limit} only if the {@link Limit} is limited. */
	private final Limit limit;
	/** Optional number of candidates, rendered as {@code numCandidates} when present. */
	private final @Nullable Integer numCandidates;
	/** Path(s) to search, rendered as {@code path}. */
	private final QueryPaths path;
	/** Query vector, rendered as {@code queryVector}. */
	private final Vector vector;
	/** Optional name of the score field; when it has text an additional {@link AddFieldsOperation} stage is rendered. */
	private final @Nullable String score;
	/** Optional callback configuring the score {@link Criteria}; when present a {@link MatchOperation} stage is rendered. */
	private final @Nullable Consumer<Criteria> scoreCriteria;

	/**
	 * Creates a fully configured {@link VectorSearchOperation}. Used by the fluent configuration methods to derive
	 * modified copies of an existing operation.
	 *
	 * @param searchType the search type to apply.
	 * @param filter optional pre-filter, can be {@literal null}.
	 * @param indexName name of the index to search.
	 * @param limit the result limit.
	 * @param numCandidates optional number of candidates, can be {@literal null}.
	 * @param path the path(s) to search.
	 * @param vector the query vector.
	 * @param searchScore optional name of the score field, can be {@literal null}.
	 * @param scoreCriteria optional callback configuring the score criteria, can be {@literal null}.
	 */
	private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinition filter, String indexName,
			Limit limit, @Nullable Integer numCandidates, QueryPaths path, Vector vector, @Nullable String searchScore,
			@Nullable Consumer<Criteria> scoreCriteria) {

		// what to search
		this.indexName = indexName;
		this.path = path;
		this.vector = vector;

		// how to search
		this.searchType = searchType;
		this.filter = filter;
		this.limit = limit;
		this.numCandidates = numCandidates;

		// score handling
		this.score = searchScore;
		this.scoreCriteria = scoreCriteria;
	}

	/**
	 * Creates a {@link VectorSearchOperation} using {@link SearchType#DEFAULT} without filter, number of candidates,
	 * score field, or score criteria.
	 *
	 * @param indexName name of the index to search.
	 * @param path the path(s) to search.
	 * @param limit the result limit.
	 * @param vector the query vector.
	 */
	VectorSearchOperation(String indexName, QueryPaths path, Limit limit, Vector vector) {
		this(SearchType.DEFAULT, // search type
				null, // filter
				indexName, limit, //
				null, // numCandidates
				path, vector, //
				null, // score field name
				null); // score criteria
	}

	/**
	 * Entrypoint to build a {@link VectorSearchOperation} starting from the {@code index} name to search. Atlas Vector
	 * Search doesn't return results if you misspell the index name or if the specified index doesn't already exist on the
	 * cluster.
	 *
	 * @param index must not be {@literal null} or empty.
	 * @return new instance of {@link VectorSearchOperation.PathContributor}.
	 * @throws IllegalArgumentException if {@code index} is {@literal null} or does not contain text.
	 */
	public static PathContributor search(String index) {

		// fail fast: otherwise a missing index name only surfaces later when the builder is completed
		Assert.hasText(index, "Index name must not be null or empty");

		return new VectorSearchBuilder().index(index);
	}

	/**
	 * Select the search type. {@link SearchType#ENN} requests an exact search ({@code exact=true}),
	 * {@link SearchType#ANN} an approximate one ({@code exact=false}).
	 *
	 * @param searchType must not be null.
	 * @return a new {@link VectorSearchOperation} with {@link SearchType} applied.
	 */
	@Contract("_ -> new")
	public VectorSearchOperation searchType(SearchType searchType) {

		return new VectorSearchOperation(searchType, this.filter, this.indexName, this.limit, this.numCandidates,
				this.path, this.vector, this.score, this.scoreCriteria);
	}

	/**
	 * Pre-filter expressed as {@link CriteriaDefinition} comparing an indexed field with a boolean, date, objectId,
	 * number (not decimals), string, or UUID.
	 * <p>
	 * Atlas Vector Search supports only the filters for the following MQL match expressions:
	 * <ul>
	 * <li>Comparison: $gt, $lt, $gte, $lte, $eq, $ne</li>
	 * <li>Membership: $in, $nin</li>
	 * <li>Logical: $nor, $not, $and, $or</li>
	 * </ul>
	 *
	 * @param filter must not be null.
	 * @return a new {@link VectorSearchOperation} with {@link CriteriaDefinition} applied.
	 */
	@Contract("_ -> new")
	public VectorSearchOperation filter(CriteriaDefinition filter) {

		return new VectorSearchOperation(this.searchType, filter, this.indexName, this.limit, this.numCandidates,
				this.path, this.vector, this.score, this.scoreCriteria);
	}

	/**
	 * Pre-filter expressed as raw {@link Document} comparing an indexed field with a boolean, date, objectId, number
	 * (not decimals), string, or UUID. The document is exposed as-is through a {@link CriteriaDefinition} without a key.
	 * <p>
	 * Atlas Vector Search supports only the filters for the following MQL match expressions:
	 * <ul>
	 * <li>Comparison: $gt, $lt, $gte, $lte, $eq, $ne</li>
	 * <li>Membership: $in, $nin</li>
	 * <li>Logical: $nor, $not, $and, $or</li>
	 * </ul>
	 *
	 * @param filter must not be null.
	 * @return a new {@link VectorSearchOperation} with {@link CriteriaDefinition} applied.
	 */
	@Contract("_ -> new")
	public VectorSearchOperation filter(Document filter) {

		// adapter handing out the given document unchanged
		CriteriaDefinition documentBacked = new CriteriaDefinition() {

			@Override
			public Document getCriteriaObject() {
				return filter;
			}

			@Override
			public @Nullable String getKey() {
				return null;
			}
		};

		return filter(documentBacked);
	}

	/**
	 * Configure the number of nearest neighbors considered during the search. The value must be less than or equal to
	 * ({@code <=}) {@code 10000} and must not be less than the number of documents to return (limit). Required if
	 * {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}.
	 *
	 * @param numCandidates number of nearest neighbors to use during the search
	 * @return a new {@link VectorSearchOperation} with {@code numCandidates} applied.
	 */
	@Contract("_ -> new")
	public VectorSearchOperation numCandidates(int numCandidates) {

		return new VectorSearchOperation(this.searchType, this.filter, this.indexName, this.limit, numCandidates,
				this.path, this.vector, this.score, this.scoreCriteria);
	}

	/**
	 * Include the search score in the results under the default field name {@code score} by adding an
	 * {@link AddFieldsOperation} stage.
	 *
	 * @return a new {@link VectorSearchOperation} with search score applied.
	 * @see #withSearchScore(String)
	 */
	@Contract("-> new")
	public VectorSearchOperation withSearchScore() {

		String defaultScoreFieldName = "score";
		return withSearchScore(defaultScoreFieldName);
	}

	/**
	 * Include the search score in the results under the given {@code scoreFieldName} by adding an
	 * {@link AddFieldsOperation} stage.
	 *
	 * @param scoreFieldName name of the score field.
	 * @return a new {@link VectorSearchOperation} with {@code scoreFieldName} applied.
	 * @see #withSearchScore()
	 */
	@Contract("_ -> new")
	public VectorSearchOperation withSearchScore(String scoreFieldName) {

		return new VectorSearchOperation(this.searchType, this.filter, this.indexName, this.limit, this.numCandidates,
				this.path, this.vector, scoreFieldName, this.scoreCriteria);
	}

	/**
	 * Add a {@link MatchOperation} stage targeting the score field name. Implies that the score field is present by
	 * either reusing a previous {@link AddFieldsOperation} from {@link #withSearchScore()} or
	 * {@link #withSearchScore(String)} or by adding a new {@link AddFieldsOperation} stage using {@code score} as field
	 * name.
	 *
	 * @param score callback configuring the {@link Criteria} that is applied to the score field.
	 * @return a new {@link VectorSearchOperation} with search score filter applied.
	 */
	@Contract("_ -> new")
	public VectorSearchOperation withFilterByScore(Consumer<Criteria> score) {

		// reuse an already configured score field, otherwise fall back to the default field name
		String scoreFieldName = StringUtils.hasText(this.score) ? this.score : "score";

		return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector,
				scoreFieldName, score);
	}

	/**
	 * Add a {@link MatchOperation} stage targeting the score field name.
	 *
	 * @param score callback configuring the {@link Criteria} that is applied to the score field.
	 * @return a new {@link VectorSearchOperation} with search score filter applied.
	 * @deprecated the method name contains a typo, use {@link #withFilterByScore(Consumer)} instead.
	 */
	@Deprecated
	@Contract("_ -> new")
	public VectorSearchOperation withFilterBySore(Consumer<Criteria> score) {
		return withFilterByScore(score);
	}

	/**
	 * Renders the {@code $vectorSearch} stage. Optional settings ({@code exact}, {@code filter}, {@code limit},
	 * {@code numCandidates}) are only included when configured; {@code index}, {@code path} and {@code queryVector} are
	 * always present.
	 *
	 * @param context used to map the filter and a single field path.
	 * @return the stage document keyed by {@link #getOperator()}.
	 */
	@Override
	public Document toDocument(AggregationOperationContext context) {

		Document stage = new Document();

		// DEFAULT omits the flag entirely, only an explicit ANN/ENN choice is rendered
		if (searchType != null && searchType != SearchType.DEFAULT) {
			stage.append("exact", searchType == SearchType.ENN);
		}

		if (filter != null) {
			stage.append("filter", context.getMappedObject(filter.getCriteriaObject()));
		}

		stage.append("index", indexName);

		if (limit.isLimited()) {
			stage.append("limit", limit.max());
		}

		if (numCandidates != null) {
			stage.append("numCandidates", numCandidates);
		}

		Object searchPath = this.path.getPathObject();
		if (searchPath instanceof String fieldName) {
			// run a single field name through the context and use the resulting key as path
			searchPath = context.getMappedObject(new Document(fieldName, 1)).keySet().iterator().next();
		}

		Object queryVector = vector.getSource();
		if (queryVector instanceof float[]) {
			queryVector = vector.toDoubleArray();
		}
		if (queryVector instanceof double[] doubles) {
			// double arrays are rendered as list of boxed values
			queryVector = Arrays.stream(doubles).boxed().collect(Collectors.toList());
		}

		stage.append("path", searchPath).append("queryVector", queryVector);

		return new Document(getOperator(), stage);
	}

	/**
	 * Renders the pipeline stages of this operation: the {@code $vectorSearch} stage itself, followed by a stage adding
	 * the search score if a score field is configured, followed by a match stage on that field if score criteria are
	 * configured.
	 *
	 * @param context used to render the individual stages.
	 * @return one, two, or three stage documents.
	 */
	@Override
	public List<Document> toPipelineStages(AggregationOperationContext context) {

		if (!StringUtils.hasText(score)) {
			return List.of(toDocument(context));
		}

		AddFieldsOperation addScoreField = Aggregation.addFields().addField(score)
				.withValueOfExpression("{\"$meta\":\"vectorSearchScore\"}").build();

		MatchOperation matchOnScore = null;
		if (scoreCriteria != null) {

			Criteria scoreFieldCriteria = Criteria.where(score);
			scoreCriteria.accept(scoreFieldCriteria);
			matchOnScore = Aggregation.match(scoreFieldCriteria);
		}

		Document searchStage = toDocument(context);
		Document scoreStage = addScoreField.toDocument(context);

		return matchOnScore == null ? List.of(searchStage, scoreStage)
				: List.of(searchStage, scoreStage, matchOnScore.toDocument(context));
	}

	/**
	 * @return the {@code $vectorSearch} operator name, used as key of the document rendered by
	 *         {@link #toDocument(AggregationOperationContext)}.
	 */
	@Override
	public String getOperator() {
		return "$vectorSearch";
	}

	/**
	 * Step builder collecting index name, path, and query vector before creating the {@link VectorSearchOperation}
	 * once the limit is provided.
	 */
	private static class VectorSearchBuilder implements PathContributor, VectorContributor, LimitContributor {

		private @Nullable String indexName;
		private @Nullable QueryPath<String> searchPath;
		private @Nullable Vector queryVector;

		PathContributor index(String index) {

			this.indexName = index;
			return this;
		}

		@Override
		public VectorContributor path(String path) {

			this.searchPath = QueryPath.path(path);
			return this;
		}

		@Override
		public LimitContributor vector(Vector vector) {

			this.queryVector = vector;
			return this;
		}

		@Override
		public VectorSearchOperation limit(Limit limit) {

			// all previous steps need to be completed before the operation can be created
			Assert.notNull(indexName, "Index must be set first");
			Assert.notNull(searchPath, "Path must be set first");
			Assert.notNull(queryVector, "Vector must be set first");

			QueryPaths queryPaths = QueryPaths.of(searchPath);
			return new VectorSearchOperation(indexName, queryPaths, limit, queryVector);
		}
	}

	/**
	 * Search type, ANN as approximation or ENN for exact search. Determines whether and how the {@code exact} flag of
	 * the {@code $vectorSearch} stage is rendered.
	 */
	public enum SearchType {

		/** MongoDB Server default (value will be omitted) */
		DEFAULT,
		/** Approximate Nearest Neighbour, rendered as {@code exact: false}. */
		ANN,
		/** Exact Nearest Neighbour, rendered as {@code exact: true}. */
		ENN
	}

	/**
	 * Value object holding the {@link QueryPath query paths} of a vector search.
	 */
	public static class QueryPaths {

		private final Set<QueryPath<?>> paths;

		private QueryPaths(Set<QueryPath<?>> paths) {
			this.paths = paths;
		}

		/**
		 * Create {@link QueryPaths} containing exactly the given {@link QueryPath}.
		 *
		 * @param path the single path to capture.
		 * @return a new {@link QueryPaths} instance.
		 */
		public static QueryPaths of(QueryPath<String> path) {

			Set<QueryPath<?>> singlePath = Set.of(path);
			return new QueryPaths(singlePath);
		}

		/**
		 * @return the plain path value if exactly one path is present, otherwise the list of all path values.
		 */
		Object getPathObject() {

			if (paths.size() != 1) {
				return paths.stream().map(QueryPath::value).collect(Collectors.toList());
			}

			return paths.iterator().next().value();
		}
	}

	/**
	 * Interface describing a query path contract. Query paths might be simple field names, wildcard paths, or
	 * multi-paths.
	 *
	 * @param <T> type of the value representing the path.
	 */
	public interface QueryPath<T> {

		/**
		 * @return the value representing this path.
		 */
		T value();

		/**
		 * Create a {@link QueryPath} for a single field name, backed by a {@link SimplePath}.
		 *
		 * @param field name of the field.
		 * @return a new {@link QueryPath} holding the given {@code field}.
		 */
		static QueryPath<String> path(String field) {
			return new SimplePath(field);
		}
	}

	/**
	 * {@link QueryPath} implementation holding a single field name.
	 */
	public static class SimplePath implements QueryPath<String> {

		// immutable and private: the name is only handed out through value()
		private final String name;

		/**
		 * Create a new {@link SimplePath} for the given field name.
		 *
		 * @param name the field name returned by {@link #value()}.
		 */
		public SimplePath(String name) {
			this.name = name;
		}

		/**
		 * @return the field name this path was created with.
		 */
		@Override
		public String value() {
			return name;
		}
	}

	/**
	 * Fluent API to configure a path on the VectorSearchOperation builder.
	 */
	public interface PathContributor {

		/**
		 * Indexed vector type field to search.
		 *
		 * @param path name of the search path.
		 * @return the {@link VectorContributor} to continue with configuring the query vector.
		 */
		@Contract("_ -> this")
		VectorContributor path(String path);
	}

	/**
	 * Fluent API to configure a vector on the VectorSearchOperation builder. All convenience overloads convert their
	 * argument and delegate to {@link #vector(Vector)}.
	 */
	public interface VectorContributor {

		/**
		 * Array of float numbers that represent the query vector. The number type must match the indexed field value type.
		 * Otherwise, Atlas Vector Search doesn't return any results or errors. The array is wrapped using
		 * {@link Vector#of(float...)}.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		default LimitContributor vector(float... vector) {
			return vector(Vector.of(vector));
		}

		/**
		 * Array of byte numbers that represent the query vector. The number type must match the indexed field value type.
		 * Otherwise, Atlas Vector Search doesn't return any results or errors. The array is converted using
		 * {@link BinaryVector#int8Vector(byte[])}.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		default LimitContributor vector(byte[] vector) {
			return vector(BinaryVector.int8Vector(vector));
		}

		/**
		 * Array of double numbers that represent the query vector. The number type must match the indexed field value type.
		 * Otherwise, Atlas Vector Search doesn't return any results or errors. The array is wrapped using
		 * {@link Vector#of(double...)}.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		default LimitContributor vector(double... vector) {
			return vector(Vector.of(vector));
		}

		/**
		 * Array of numbers that represent the query vector. The number type must match the indexed field value type.
		 * Otherwise, Atlas Vector Search doesn't return any results or errors. The list is wrapped using
		 * {@code Vector.of(…)}.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		default LimitContributor vector(List<? extends Number> vector) {
			return vector(Vector.of(vector));
		}

		/**
		 * Binary vector (BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or int8 type) that
		 * represent the query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector
		 * Search doesn't return any results or errors. The binary vector is wrapped using {@code MongoVector.of(…)}.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		default LimitContributor vector(BinaryVector vector) {
			return vector(MongoVector.of(vector));
		}

		/**
		 * The query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector Search doesn't
		 * return any results or errors.
		 *
		 * @param vector the query vector.
		 * @return the {@link LimitContributor} to continue with configuring the limit.
		 */
		@Contract("_ -> this")
		LimitContributor vector(Vector vector);
	}

	/**
	 * Fluent API to configure a limit on the VectorSearchOperation builder. Providing the limit is the final step and
	 * yields the {@link VectorSearchOperation}.
	 */
	public interface LimitContributor {

		/**
		 * Number (of type int only) of documents to return in the results. This value can't exceed the value of
		 * numCandidates if you specify numCandidates. Delegates to {@link #limit(Limit)} using {@link Limit#of(int)}.
		 *
		 * @param limit maximum number of documents to return.
		 * @return the {@link VectorSearchOperation} with the given limit applied.
		 */
		// the result is a VectorSearchOperation and therefore never the contributor itself
		@Contract("_ -> new")
		default VectorSearchOperation limit(int limit) {
			return limit(Limit.of(limit));
		}

		/**
		 * Number (of type int only) of documents to return in the results. This value can't exceed the value of
		 * numCandidates if you specify numCandidates.
		 *
		 * @param limit maximum number of documents to return.
		 * @return the {@link VectorSearchOperation} with the given limit applied.
		 */
		// the result is a VectorSearchOperation and therefore never the contributor itself
		@Contract("_ -> new")
		VectorSearchOperation limit(Limit limit);
	}

}