VectorIndex.java

/*
 * Copyright 2024-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.core.index;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

import org.bson.Document;
import org.jspecify.annotations.Nullable;

import org.springframework.data.core.TypeInformation;
import org.springframework.data.core.TypedPropertyPath;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.lang.Contract;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

/**
 * {@link SearchIndexDefinition} for creating MongoDB
 * <a href="https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/">Vector Index</a> required to
 * run {@code $vectorSearch} queries.
 * <p>
 * Fields are collected in insertion order via {@link #addVector(String, Consumer)} and {@link #addFilter(String)} and
 * rendered into the {@code fields} array of the index definition by
 * {@link #getDefinition(TypeInformation, MappingContext)}.
 *
 * @author Christoph Strobl
 * @author Mark Paluch
 * @since 4.5
 */
public class VectorIndex implements SearchIndexDefinition {

	private final String name;
	private final List<SearchField> fields = new ArrayList<>();

	/**
	 * Create a new {@link VectorIndex} instance.
	 *
	 * @param name The name of the index.
	 */
	public VectorIndex(String name) {
		this.name = name;
	}

	/**
	 * Add a filter field.
	 *
	 * @param path dot notation to field/property used for filtering.
	 * @return this.
	 */
	@Contract("_ -> this")
	public VectorIndex addFilter(String path) {

		Assert.hasText(path, "Path must not be null or empty");

		return addField(new VectorFilterField(path, "filter"));
	}

	/**
	 * Add a filter field.
	 *
	 * @param property The property used for filtering. Must not be {@literal null}.
	 * @return this.
	 * @since 5.1
	 */
	@Contract("_ -> this")
	public <T, P> VectorIndex addFilter(TypedPropertyPath<T, P> property) {

		Assert.notNull(property, "Property must not be null");

		return addFilter(property.toDotPath());
	}

	/**
	 * Add a vector field and accept a {@link VectorFieldBuilder} customizer.
	 *
	 * @param path dot notation to the field/property holding the vector.
	 * @param customizer customizer function used to configure dimensions, similarity and quantization. Must not be
	 *          {@literal null}.
	 * @return this.
	 */
	@Contract("_, _ -> this")
	public VectorIndex addVector(String path, Consumer<VectorFieldBuilder> customizer) {

		Assert.hasText(path, "Path must not be null or empty");
		Assert.notNull(customizer, "Customizer must not be null");

		VectorFieldBuilder builder = new VectorFieldBuilder(path, "vector");
		customizer.accept(builder);
		return addField(builder.build());
	}

	/**
	 * Add a vector field and accept a {@link VectorFieldBuilder} customizer.
	 *
	 * @param property the property holding the vector. Must not be {@literal null}.
	 * @param customizer customizer function. Must not be {@literal null}.
	 * @return this.
	 * @since 5.1
	 */
	@Contract("_, _ -> this")
	public <T, P> VectorIndex addVector(TypedPropertyPath<T, P> property, Consumer<VectorFieldBuilder> customizer) {

		Assert.notNull(property, "Property must not be null");

		return addVector(property.toDotPath(), customizer);
	}

	@Override
	public String getName() {
		return name;
	}

	@Override
	public String getType() {
		return "vectorSearch";
	}

	/**
	 * Render the index definition as {@code { fields: [ ... ] }}. Field paths are mapped to their stored field names
	 * when both {@code entity} and {@code mappingContext} are given; otherwise paths are used as-is.
	 */
	@Override
	public Document getDefinition(@Nullable TypeInformation<?> entity,
			@Nullable MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext) {

		MongoPersistentEntity<?> persistentEntity = entity != null && mappingContext != null
				? mappingContext.getPersistentEntity(entity)
				: null;

		List<Document> mappedFields = new ArrayList<>(this.fields.size());

		for (SearchField field : this.fields) {

			Document fieldDefinition = new Document("type", field.type());
			fieldDefinition.put("path", resolvePath(field.path(), persistentEntity, mappingContext));

			if (field instanceof VectorIndexField vectorField) {

				fieldDefinition.put("numDimensions", vectorField.dimensions());
				fieldDefinition.put("similarity", vectorField.similarity());

				// quantization is optional; only emit it when explicitly configured
				if (StringUtils.hasText(vectorField.quantization())) {
					fieldDefinition.put("quantization", vectorField.quantization());
				}
			}
			mappedFields.add(fieldDefinition);
		}

		return new Document("fields", mappedFields);
	}

	@Contract("_ -> this")
	private VectorIndex addField(SearchField field) {

		fields.add(field);
		return this;
	}

	@Override
	public String toString() {
		return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}';
	}

	/**
	 * Parse the {@link Document} into a {@link VectorIndex}.
	 * <p>
	 * Reads {@code latestDefinition} when present, otherwise {@code definition}. A document without a definition or
	 * without a {@code fields} list yields an index without fields.
	 *
	 * @param document the index description document.
	 * @return the parsed {@link VectorIndex}.
	 */
	static VectorIndex of(Document document) {

		VectorIndex index = new VectorIndex(document.getString("name"));

		String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition";
		Document definition = document.get(definitionKey, Document.class);

		if (definition == null) {
			return index;
		}

		List<?> fieldEntries = definition.get("fields", List.class);

		if (fieldEntries == null) {
			return index;
		}

		for (Object entry : fieldEntries) {
			if (entry instanceof Document field) {
				index.addField(parseField(field));
			}
		}

		return index;
	}

	/**
	 * Convert a single entry of the {@code fields} array into a {@link SearchField}. Entries of type {@code vector}
	 * become {@link VectorIndexField}; any other type is treated as a {@link VectorFilterField}.
	 */
	private static SearchField parseField(Document field) {

		String path = field.getString("path");

		if (ObjectUtils.nullSafeEquals(field.get("type"), "vector")) {
			return new VectorIndexField(path, "vector", readDimensions(field), field.getString("similarity"),
					field.getString("quantization"));
		}

		return new VectorFilterField(path, "filter");
	}

	/**
	 * Read {@code numDimensions} leniently: accept any numeric representation instead of requiring an
	 * {@link Integer} (avoids {@link ClassCastException} for int64/double values) and fall back to {@code 0}, the
	 * {@link VectorFieldBuilder} default, when the value is absent.
	 */
	private static int readDimensions(Document field) {
		return field.get("numDimensions") instanceof Number number ? number.intValue() : 0;
	}

	/**
	 * Map the given property path to its stored field name using the entity metadata, or return it unchanged when no
	 * metadata is available.
	 */
	private String resolvePath(String path, @Nullable MongoPersistentEntity<?> persistentEntity,
			@Nullable MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext) {

		if (persistentEntity == null || mappingContext == null) {
			return path;
		}

		QueryMapper.MetadataBackedField mbf = new QueryMapper.MetadataBackedField(path, persistentEntity, mappingContext);

		return mbf.getMappedKey();
	}

	/**
	 * A single entry of the index definition's {@code fields} array.
	 */
	interface SearchField {

		String path();

		String type();
	}

	/**
	 * Field used to pre-filter documents.
	 */
	record VectorFilterField(String path, String type) implements SearchField {
	}

	/**
	 * Field holding the vector embeddings along with its index options.
	 */
	record VectorIndexField(String path, String type, int dimensions, @Nullable String similarity,
			@Nullable String quantization) implements SearchField {
	}

	/**
	 * Builder to create a vector field.
	 */
	public static class VectorFieldBuilder {

		private final String path;
		private final String type;

		private int dimensions;
		private @Nullable String similarity;
		private @Nullable String quantization;

		VectorFieldBuilder(String path, String type) {

			this.path = path;
			this.type = type;
		}

		/**
		 * Number of vector dimensions enforced at index- and query-time. The value is not validated here; it is sent as
		 * {@code numDimensions}, and the supported range is defined by the server.
		 *
		 * @param dimensions number of vector dimensions.
		 * @return this.
		 */
		@Contract("_ -> this")
		public VectorFieldBuilder dimensions(int dimensions) {
			this.dimensions = dimensions;
			return this;
		}

		/**
		 * Use similarity based on the angle between vectors.
		 *
		 * @return this.
		 */
		@Contract(" -> this")
		public VectorFieldBuilder cosine() {
			return similarity(SimilarityFunction.COSINE);
		}

		/**
		 * Use similarity based on the distance between vector ends.
		 *
		 * @return this.
		 */
		@Contract(" -> this")
		public VectorFieldBuilder euclidean() {
			return similarity(SimilarityFunction.EUCLIDEAN);
		}

		/**
		 * Use similarity based on both angle and magnitude of the vectors.
		 *
		 * @return this.
		 */
		@Contract(" -> this")
		public VectorFieldBuilder dotProduct() {
			return similarity(SimilarityFunction.DOT_PRODUCT);
		}

		/**
		 * Similarity function used.
		 *
		 * @param similarity should be one of {@literal euclidean | cosine | dotProduct}.
		 * @return this.
		 * @see SimilarityFunction
		 * @see #similarity(SimilarityFunction)
		 */
		@Contract("_ -> this")
		public VectorFieldBuilder similarity(String similarity) {

			this.similarity = similarity;
			return this;
		}

		/**
		 * Similarity function used.
		 *
		 * @param similarity must not be {@literal null}.
		 * @return this.
		 */
		@Contract("_ -> this")
		public VectorFieldBuilder similarity(SimilarityFunction similarity) {

			Assert.notNull(similarity, "SimilarityFunction must not be null");

			return similarity(similarity.getFunctionName());
		}

		/**
		 * Quantization used.
		 *
		 * @param quantization should be one of {@literal none | scalar | binary}.
		 * @return this.
		 * @see Quantization
		 * @see #quantization(Quantization)
		 */
		@Contract("_ -> this")
		public VectorFieldBuilder quantization(String quantization) {

			this.quantization = quantization;
			return this;
		}

		/**
		 * Quantization used.
		 *
		 * @param quantization must not be {@literal null}.
		 * @return this.
		 */
		@Contract("_ -> this")
		public VectorFieldBuilder quantization(Quantization quantization) {

			Assert.notNull(quantization, "Quantization must not be null");

			return quantization(quantization.getQuantizationName());
		}

		VectorIndexField build() {
			return new VectorIndexField(this.path, this.type, this.dimensions, this.similarity, this.quantization);
		}
	}

	/**
	 * Similarity function used to calculate vector distance.
	 */
	public enum SimilarityFunction {

		DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean");

		final String functionName;

		SimilarityFunction(String functionName) {
			this.functionName = functionName;
		}

		public String getFunctionName() {
			return functionName;
		}
	}

	/**
	 * Vector quantization. Quantization reduces vector sizes while preserving performance.
	 */
	public enum Quantization {

		NONE("none"),

		/**
		 * Converting a floating point value into an integer.
		 */
		SCALAR("scalar"),

		/**
		 * Converting a floating point value into a single bit.
		 */
		BINARY("binary");

		final String quantizationName;

		Quantization(String quantizationName) {
			this.quantizationName = quantizationName;
		}

		public String getQuantizationName() {
			return quantizationName;
		}
	}
}