// SimilarityNormalizer.java

/*
 * Copyright 2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.jpa.repository.query;

import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleUnaryOperator;

import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.VectorScoringFunctions;

/**
 * Normalizes the score returned by a database to a similarity value and vice versa.
 * <p>
 * Each normalizer couples a {@link ScoringFunction} with two conversions: one turning a database score into a
 * similarity ({@link #getSimilarity(double)}) and one turning a similarity back into a database score
 * ({@link #getScore(double)}).
 *
 * @author Mark Paluch
 * @since 4.0
 * @see org.springframework.data.domain.Similarity
 */
public class SimilarityNormalizer {

	/**
	 * Identity normalizer for {@link ScoringFunction#unspecified()} scoring function without altering the score.
	 */
	public static final SimilarityNormalizer IDENTITY = new SimilarityNormalizer(ScoringFunction.unspecified(),
			DoubleUnaryOperator.identity(), DoubleUnaryOperator.identity());

	/**
	 * Normalizer for Euclidean scores using {@code euclidean_distance(…)} as the scoring function.
	 * <p>
	 * Similarity is computed as {@code 1 / (1 + score²)}; the score is computed as {@code sqrt((1 / similarity) - 1)}.
	 * A similarity of {@code 0} maps to {@link Float#MAX_VALUE} instead of dividing by zero.
	 */
	public static final SimilarityNormalizer EUCLIDEAN = new SimilarityNormalizer(VectorScoringFunctions.EUCLIDEAN,
			it -> 1 / (1.0 + Math.pow(it, 2)), it -> it == 0 ? Float.MAX_VALUE : Math.sqrt((1 / it) - 1));

	/**
	 * Normalizer for Cosine scores using {@code cosine_distance(…)} as the scoring function.
	 * <p>
	 * Similarity is computed as {@code (1 + (1 - score)) / 2}; the score is computed as
	 * {@code 1 - ((similarity * 2) - 1)}.
	 */
	public static final SimilarityNormalizer COSINE = new SimilarityNormalizer(VectorScoringFunctions.COSINE,
			it -> (1.0 + (1 - it)) / 2.0, it -> 1 - ((it * 2) - 1));

	/**
	 * Normalizer for Negative Inner Product (Dot) scores using {@code negative_inner_product(…)} as the scoring
	 * function.
	 * <p>
	 * Similarity is computed as {@code (1 - score) / 2}; the score is computed as {@code 1 - (similarity * 2)}.
	 */
	public static final SimilarityNormalizer DOT_PRODUCT = new SimilarityNormalizer(VectorScoringFunctions.DOT_PRODUCT,
			it -> (1 - it) / 2, it -> 1 - (it * 2));

	// Lookup table backing get(…). IDENTITY is intentionally left as-is (not registered here).
	private static final Map<ScoringFunction, SimilarityNormalizer> NORMALIZERS = new HashMap<>();

	static {
		NORMALIZERS.put(EUCLIDEAN.scoringFunction, EUCLIDEAN);
		NORMALIZERS.put(COSINE.scoringFunction, COSINE);
		NORMALIZERS.put(DOT_PRODUCT.scoringFunction, DOT_PRODUCT);
	}

	private final ScoringFunction scoringFunction;
	private final DoubleUnaryOperator similarity;
	private final DoubleUnaryOperator score;

	/**
	 * Constructor for {@link SimilarityNormalizer} using the given {@link DoubleUnaryOperator} for similarity and score
	 * computation.
	 *
	 * @param scoringFunction the scoring function this normalizer is associated with.
	 * @param similarity compute the similarity from the underlying score returned by a database result.
	 * @param score compute the score value from a given {@link org.springframework.data.domain.Similarity} to compare
	 *          against database results.
	 */
	SimilarityNormalizer(ScoringFunction scoringFunction, DoubleUnaryOperator similarity, DoubleUnaryOperator score) {
		this.scoringFunction = scoringFunction;
		this.score = score;
		this.similarity = similarity;
	}

	/**
	 * Lookup a {@link SimilarityNormalizer} for a given {@link ScoringFunction}.
	 * <p>
	 * Only the scoring functions of {@link #EUCLIDEAN}, {@link #COSINE} and {@link #DOT_PRODUCT} are registered for
	 * lookup.
	 *
	 * @param scoringFunction the scoring function to translate, must not be {@literal null}.
	 * @return the {@link SimilarityNormalizer} for the given {@link ScoringFunction}.
	 * @throws IllegalArgumentException if the {@link ScoringFunction} is {@literal null} or not associated with a
	 *           {@link SimilarityNormalizer}.
	 */
	public static SimilarityNormalizer get(ScoringFunction scoringFunction) {

		// Reject null explicitly; otherwise building the "not found" message below would fail with a
		// NullPointerException instead of the documented IllegalArgumentException.
		if (scoringFunction == null) {
			throw new IllegalArgumentException("ScoringFunction must not be null");
		}

		SimilarityNormalizer normalizer = NORMALIZERS.get(scoringFunction);

		if (normalizer == null) {
			throw new IllegalArgumentException("No SimilarityNormalizer found for " + scoringFunction.getName());
		}

		return normalizer;
	}

	/**
	 * Convert a database score into a similarity value.
	 *
	 * @param score score value as returned by the database.
	 * @return the {@link org.springframework.data.domain.Similarity} value.
	 */
	public double getSimilarity(double score) {
		return similarity.applyAsDouble(score);
	}

	/**
	 * Convert a similarity value into a database score.
	 *
	 * @param similarity similarity value as requested by the query mechanism.
	 * @return database score value.
	 */
	public double getScore(double similarity) {
		return score.applyAsDouble(similarity);
	}

	/**
	 * Render the scoring function name along with the scores that similarities {@code 0} and {@code 1} translate to.
	 */
	@Override
	public String toString() {
		return "%s Normalizer: Similarity[0 to 1] -> Score[%f to %f]".formatted(scoringFunction.getName(), getScore(0),
				getScore(1));
	}

}