MongoAotRepositoryFragmentSupport.java
/*
* Copyright 2025-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import org.bson.BsonRegularExpression;
import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.Score;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.expression.ValueEvaluationContextProvider;
import org.springframework.data.expression.ValueExpression;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Polygon;
import org.springframework.data.geo.Shape;
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.geo.GeoJson;
import org.springframework.data.mongodb.core.geo.Sphere;
import org.springframework.data.mongodb.core.mapping.FieldName;
import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.MongoRegexCreator;
import org.springframework.data.mongodb.core.query.MongoRegexCreator.MatchMode;
import org.springframework.data.mongodb.repository.query.MongoParameters;
import org.springframework.data.mongodb.repository.query.MongoParametersParameterAccessor;
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.query.ParametersSource;
import org.springframework.data.repository.query.ValueExpressionDelegate;
import org.springframework.data.util.Lazy;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ConcurrentLruCache;
import org.springframework.util.ObjectUtils;
/**
* Support class for MongoDB AOT repository fragments.
*
* @author Christoph Strobl
* @since 5.0
*/
public class MongoAotRepositoryFragmentSupport {
private static final ParameterBindingDocumentCodec CODEC = new ParameterBindingDocumentCodec();
private final RepositoryMetadata repositoryMetadata;
private final MongoOperations mongoOperations;
private final MongoConverter mongoConverter;
private final ProjectionFactory projectionFactory;
private final ValueExpressionDelegate valueExpressions;
private final Lazy<ConcurrentLruCache<String, ValueExpression>> expressions;
private final Lazy<ConcurrentLruCache<Method, MongoParameters>> mongoParameters;
private final Lazy<ConcurrentLruCache<Method, ValueEvaluationContextProvider>> contextProviders;
protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations,
RepositoryFactoryBeanSupport.FragmentCreationContext context) {
this(mongoOperations, context.getRepositoryMetadata(), context.getValueExpressionDelegate(),
context.getProjectionFactory());
}
protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryMetadata repositoryMetadata,
ValueExpressionDelegate valueExpressions, ProjectionFactory projectionFactory) {
this.mongoOperations = mongoOperations;
this.mongoConverter = mongoOperations.getConverter();
this.repositoryMetadata = repositoryMetadata;
this.projectionFactory = projectionFactory;
this.valueExpressions = valueExpressions;
this.expressions = Lazy.of(() -> new ConcurrentLruCache<>(32, valueExpressions::parse));
this.mongoParameters = Lazy
.of(() -> new ConcurrentLruCache<>(32, it -> new MongoParameters(ParametersSource.of(repositoryMetadata, it))));
this.contextProviders = Lazy.of(() -> new ConcurrentLruCache<>(32,
it -> valueExpressions.createValueContextProvider(mongoParameters.get().get(it))));
}
protected Document parse(String json) {
return CODEC.decode(json);
}
protected Document bindParameters(Method method, String source, Object... args) {
expandGeoShapes(args);
MongoParameters mongoParameters = this.mongoParameters.get().get(method);
MongoParametersParameterAccessor parametersParameterAccessor = new MongoParametersParameterAccessor(mongoParameters,
args);
ParameterBindingContext bindingContext = new ParameterBindingContext(parametersParameterAccessor::getBindableValue,
new ValueExpressionEvaluator() {
@Override
@SuppressWarnings("unchecked")
public <T> @Nullable T evaluate(String expression) {
return (T) MongoAotRepositoryFragmentSupport.this.evaluate(method, expression, args);
}
});
return CODEC.decode(source, bindingContext);
}
protected @Nullable Object evaluate(Method method, String source, Object... args) {
expandGeoShapes(args);
ValueExpression expression = this.expressions.get().get(source);
ValueEvaluationContextProvider contextProvider = this.contextProviders.get().get(method);
return expression.evaluate(contextProvider.getEvaluationContext(args, expression.getExpressionDependencies()));
}
/**
* Expand geo shapes in the given arguments to a format that can be handled by the MongoDB converter without us
* passing in the actual {@link Shape} object (except for {@link GeoJson}).
*
* @param args
*/
private static void expandGeoShapes(Object[] args) {
for (int i = 0; i < args.length; i++) {
// renders as generic $geometry, thus can be handled by the converter when parsing
if (args[i] instanceof GeoJson) {
continue;
}
if (args[i] instanceof Circle c) {
args[i] = List.of(List.of(c.getCenter().getX(), c.getCenter().getY()), c.getRadius().getNormalizedValue());
} else if (args[i] instanceof Sphere s) {
args[i] = List.of(List.of(s.getCenter().getX(), s.getCenter().getY()), s.getRadius().getNormalizedValue());
} else if (args[i] instanceof Box b) {
args[i] = List.of(List.of(b.getFirst().getX(), b.getFirst().getY()),
List.of(b.getSecond().getX(), b.getSecond().getY()));
} else if (args[i] instanceof Polygon p) {
args[i] = p.getPoints().stream().map(it -> List.of(it.getX(), it.getY())).toList();
}
}
}
protected Consumer<Criteria> scoreBetween(Range.Bound<? extends Score> lower, Range.Bound<? extends Score> upper) {
return criteria -> {
if (lower.isBounded()) {
double value = lower.getValue().get().getValue();
if (lower.isInclusive()) {
criteria.gte(value);
} else {
criteria.gt(value);
}
}
if (upper.isBounded()) {
double value = upper.getValue().get().getValue();
if (upper.isInclusive()) {
criteria.lte(value);
} else {
criteria.lt(value);
}
}
};
}
protected ScoringFunction scoringFunction(@Nullable Range<? extends Score> scoreRange) {
if (scoreRange != null) {
if (scoreRange.getUpperBound().isBounded()) {
return scoreRange.getUpperBound().getValue().get().getFunction();
}
if (scoreRange.getLowerBound().isBounded()) {
return scoreRange.getLowerBound().getValue().get().getFunction();
}
}
return ScoringFunction.unspecified();
}
protected Collation collationOf(@Nullable Object source) {
if (source == null) {
return Collation.simple();
}
if (source instanceof CharSequence) {
return Collation.parse(source.toString());
}
if (source instanceof Locale locale) {
return Collation.of(locale);
}
if (source instanceof Document document) {
return Collation.from(document);
}
if (source instanceof Collation collation) {
return collation;
}
throw new IllegalArgumentException(
"Unsupported collation source [%s]".formatted(ObjectUtils.nullSafeClassName(source)));
}
protected Object toRegex(Object source) {
return toRegex(source, null);
}
protected Object toRegex(Object source, @Nullable String options) {
if (source instanceof String sv) {
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, MatchMode.LIKE), options);
}
if (source instanceof Pattern pattern) {
return pattern;
}
if (source instanceof Collection<?> collection) {
return collection.stream().map(it -> toRegex(it, options)).toList();
}
if (ObjectUtils.isArray(source)) {
return toRegex(List.of(source), options);
}
return source;
}
protected BasicQuery createQuery(Method method, String queryString, Object... parameters) {
Document queryDocument = bindParameters(method, queryString, parameters);
return new BasicQuery(queryDocument);
}
@SuppressWarnings("NullAway")
protected AggregationPipeline createPipeline(List<Object> rawStages) {
if (rawStages.isEmpty()) {
return new AggregationPipeline(List.of());
}
int size = rawStages.size();
List<AggregationOperation> stages = new ArrayList<>(size);
Object firstElement = CollectionUtils.firstElement(rawStages);
stages.add(rawToAggregationOperation(firstElement, true));
if (size == 1) {
return new AggregationPipeline(stages);
}
for (int i = 1; i < size; i++) {
stages.add(rawToAggregationOperation(rawStages.get(i), false));
}
return new AggregationPipeline(stages);
}
private static AggregationOperation rawToAggregationOperation(Object rawStage, boolean requiresMapping) {
if (rawStage instanceof Document stageDocument) {
if (requiresMapping) {
return (ctx) -> ctx.getMappedObject(stageDocument);
} else {
return (ctx) -> stageDocument;
}
}
if (rawStage instanceof AggregationOperation aggregationOperation) {
return aggregationOperation;
}
throw new RuntimeException("%s cannot be converted to AggregationOperation".formatted(rawStage.getClass()));
}
protected List<Object> convertSimpleRawResults(Class<?> targetType, List<Document> rawResults) {
List<Object> list = new ArrayList<>(rawResults.size());
for (Document it : rawResults) {
list.add(extractSimpleTypeResult(it, targetType, mongoConverter));
}
return list;
}
protected @Nullable Object convertSimpleRawResult(Class<?> targetType, Document rawResult) {
return extractSimpleTypeResult(rawResult, targetType, mongoConverter);
}
private static <T> @Nullable T extractSimpleTypeResult(@Nullable Document source, Class<T> targetType,
MongoConverter converter) {
if (ObjectUtils.isEmpty(source)) {
return null;
}
if (source.size() == 1) {
return getPotentiallyConvertedSimpleTypeValue(converter, source.values().iterator().next(), targetType);
}
Document intermediate = new Document(source);
intermediate.remove(FieldName.ID.name());
if (intermediate.size() == 1) {
return getPotentiallyConvertedSimpleTypeValue(converter, intermediate.values().iterator().next(), targetType);
}
for (Map.Entry<String, Object> entry : intermediate.entrySet()) {
if (entry != null && ClassUtils.isAssignable(targetType, entry.getValue().getClass())) {
return targetType.cast(entry.getValue());
}
}
throw new IllegalArgumentException(
String.format("o_O no entry of type %s found in %s.", targetType.getSimpleName(), source.toJson()));
}
@Nullable
@SuppressWarnings("unchecked")
private static <T> T getPotentiallyConvertedSimpleTypeValue(MongoConverter converter, @Nullable Object value,
Class<T> targetType) {
if (value == null) {
return null;
}
if (ClassUtils.isAssignableValue(targetType, value)) {
return (T) value;
}
return converter.getConversionService().convert(value, targetType);
}
}