QuerydslKeyValuePredicateExecutor.java
/*
* Copyright 2021-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.keyvalue.repository.support;
import static org.springframework.data.keyvalue.repository.support.KeyValueQuerydslUtils.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.jspecify.annotations.Nullable;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.convert.DtoInstantiatingConverter;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.keyvalue.core.IterableConverter;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.mapping.PersistentEntity;
import org.springframework.data.mapping.PersistentProperty;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mapping.model.EntityInstantiators;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.querydsl.EntityPathResolver;
import org.springframework.data.querydsl.ListQuerydslPredicateExecutor;
import org.springframework.data.querydsl.QuerydslPredicateExecutor;
import org.springframework.data.querydsl.SimpleEntityPathResolver;
import org.springframework.data.repository.core.EntityInformation;
import org.springframework.data.repository.query.FluentQuery;
import org.springframework.util.Assert;
import com.querydsl.collections.AbstractCollQuery;
import com.querydsl.collections.CollQuery;
import com.querydsl.core.NonUniqueResultException;
import com.querydsl.core.QueryResults;
import com.querydsl.core.types.EntityPath;
import com.querydsl.core.types.OrderSpecifier;
import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.dsl.PathBuilder;
/**
 * {@link QuerydslPredicateExecutor} capable of applying {@link Predicate}s using {@link CollQuery}.
 * <p>
 * Predicates are evaluated in memory: each query execution loads all entities of the managed type through
 * {@code KeyValueOperations.findAll(…)} and then filters, sorts and pages them with Querydsl Collections.
 *
 * @author Mark Paluch
 * @since 2.6
 */
public class QuerydslKeyValuePredicateExecutor<T> implements ListQuerydslPredicateExecutor<T> {

	private static final EntityPathResolver DEFAULT_ENTITY_PATH_RESOLVER = SimpleEntityPathResolver.INSTANCE;

	/** Mapping context of the backing {@link KeyValueOperations}; used when instantiating DTO projections. */
	private final MappingContext<? extends PersistentEntity<?, ?>, ? extends PersistentProperty<?>> context;

	/** Root path of the managed entity; {@link Sort} properties are translated into order specifiers against it. */
	private final PathBuilder<T> builder;

	/** Loads every entity of the managed type; invoked once per query preparation. */
	private final Supplier<List<T>> findAll;

	private final EntityInformation<T, ?> entityInformation;

	/** Creates interface-based projections for the fluent query API. */
	private final ProjectionFactory projectionFactory;

	private final EntityInstantiators entityInstantiators = new EntityInstantiators();

	/**
	 * Creates a new {@link QuerydslKeyValuePredicateExecutor} for the given {@link EntityInformation}.
	 *
	 * @param entityInformation must not be {@literal null}.
	 * @param operations must not be {@literal null}.
	 */
	public QuerydslKeyValuePredicateExecutor(EntityInformation<T, ?> entityInformation, KeyValueOperations operations) {
		this(entityInformation, new SpelAwareProxyProjectionFactory(), operations, DEFAULT_ENTITY_PATH_RESOLVER);
	}

	/**
	 * Creates a new {@link QuerydslKeyValuePredicateExecutor} for the given {@link EntityInformation}, and
	 * {@link EntityPathResolver}.
	 *
	 * @param entityInformation must not be {@literal null}.
	 * @param projectionFactory must not be {@literal null}.
	 * @param operations must not be {@literal null}.
	 * @param resolver must not be {@literal null}.
	 */
	public QuerydslKeyValuePredicateExecutor(EntityInformation<T, ?> entityInformation,
			ProjectionFactory projectionFactory, KeyValueOperations operations,
			EntityPathResolver resolver) {

		Assert.notNull(entityInformation, "EntityInformation must not be null");
		Assert.notNull(projectionFactory, "ProjectionFactory must not be null");
		Assert.notNull(operations, "KeyValueOperations must not be null");
		Assert.notNull(resolver, "EntityPathResolver must not be null");

		this.projectionFactory = projectionFactory;
		this.context = operations.getMappingContext();

		EntityPath<T> path = resolver.createPath(entityInformation.getJavaType());
		this.builder = new PathBuilder<>(path.getType(), path.getMetadata());
		this.entityInformation = entityInformation;

		// Deferred: the store is read anew each time a query is prepared, not once at construction.
		this.findAll = () -> IterableConverter.toList(operations.findAll(entityInformation.getJavaType()));
	}

	/**
	 * Returns the single entity matching {@code predicate}, if any.
	 *
	 * @param predicate must not be {@literal null}.
	 * @return the matching entity or {@link Optional#empty()} if none matches.
	 * @throws IncorrectResultSizeDataAccessException if more than one entity matches.
	 */
	@Override
	public Optional<T> findOne(Predicate predicate) {

		Assert.notNull(predicate, "Predicate must not be null");

		try {
			return Optional.ofNullable(prepareQuery(predicate).fetchOne());
		} catch (NonUniqueResultException o_O) {
			throw new IncorrectResultSizeDataAccessException("Expected one or no result but found more than one", 1, o_O);
		}
	}

	/**
	 * Returns all entities matching {@code predicate}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @return the matching entities, never {@literal null}.
	 */
	@Override
	public List<T> findAll(Predicate predicate) {

		Assert.notNull(predicate, "Predicate must not be null");

		// fetch() rather than fetchResults(): no total count is needed, so avoid evaluating the query twice.
		return prepareQuery(predicate).fetch();
	}

	/**
	 * Returns all entities matching {@code predicate}, ordered by {@code orders}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @param orders must not be {@literal null}.
	 * @return the matching entities in the requested order, never {@literal null}.
	 */
	@Override
	public List<T> findAll(Predicate predicate, OrderSpecifier<?>... orders) {

		Assert.notNull(predicate, "Predicate must not be null");
		Assert.notNull(orders, "OrderSpecifiers must not be null");

		AbstractCollQuery<T, ?> query = prepareQuery(predicate);
		query.orderBy(orders);

		return query.fetch();
	}

	/**
	 * Returns all entities matching {@code predicate}, ordered by {@code sort}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @param sort must not be {@literal null}.
	 * @return the matching entities in the requested order, never {@literal null}.
	 */
	@Override
	public List<T> findAll(Predicate predicate, Sort sort) {

		Assert.notNull(predicate, "Predicate must not be null");
		Assert.notNull(sort, "Sort must not be null");

		return findAll(predicate, toOrderSpecifier(sort, builder));
	}

	/**
	 * Returns a page of entities matching {@code predicate}.
	 * <p>
	 * Offset and limit are only applied for paged {@link Pageable}s. The sort of {@code pageable} is applied whenever
	 * it is sorted, including for unpaged {@link Pageable}s.
	 *
	 * @param predicate must not be {@literal null}.
	 * @param pageable must not be {@literal null}.
	 * @return the requested page, never {@literal null}.
	 */
	@Override
	public Page<T> findAll(Predicate predicate, Pageable pageable) {

		Assert.notNull(predicate, "Predicate must not be null");
		Assert.notNull(pageable, "Pageable must not be null");

		AbstractCollQuery<T, ?> query = prepareQuery(predicate);
		applyPageable(query, pageable);

		return new PageImpl<>(query.fetch(), pageable, count(predicate));
	}

	/**
	 * Returns all entities, ordered by {@code orders}.
	 * <p>
	 * Without any order specifier, all entities are returned as loaded from the store.
	 *
	 * @param orders must not be {@literal null}.
	 * @return all entities, never {@literal null}.
	 */
	@Override
	public List<T> findAll(OrderSpecifier<?>... orders) {

		Assert.notNull(orders, "OrderSpecifiers must not be null");

		if (orders.length == 0) {
			return findAll.get();
		}

		AbstractCollQuery<T, ?> query = prepareQuery(null);
		query.orderBy(orders);

		return query.fetch();
	}

	/**
	 * Counts the entities matching {@code predicate}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @return the number of matching entities.
	 */
	@Override
	public long count(Predicate predicate) {

		Assert.notNull(predicate, "Predicate must not be null");

		return prepareQuery(predicate).fetchCount();
	}

	/**
	 * Checks whether at least one entity matches {@code predicate}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @return {@literal true} if the count of matching entities is greater than zero.
	 */
	@Override
	public boolean exists(Predicate predicate) {

		Assert.notNull(predicate, "Predicate must not be null");

		return count(predicate) > 0;
	}

	/**
	 * Runs {@code queryFunction} against a fluent query over the entities matching {@code predicate}.
	 *
	 * @param predicate must not be {@literal null}.
	 * @param queryFunction must not be {@literal null}.
	 * @return the result of {@code queryFunction}.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <S extends T, R> R findBy(Predicate predicate,
			Function<FluentQuery.FetchableFluentQuery<S>, R> queryFunction) {

		Assert.notNull(predicate, "Predicate must not be null");
		Assert.notNull(queryFunction, "Query function must not be null");

		return queryFunction.apply(new FluentQuerydsl<>(predicate, (Class<S>) entityInformation.getJavaType()));
	}

	/**
	 * Creates an executable query over all entities of the managed type, filtered by the given {@link Predicate}.
	 *
	 * @param predicate the filter to apply; may be {@literal null} to select all entities.
	 * @return a new query; each call loads the entities from the store again.
	 */
	protected AbstractCollQuery<T, ?> prepareQuery(@Nullable Predicate predicate) {

		CollQuery<T> query = new CollQuery<>();
		query.from(builder, findAll.get());

		return predicate != null ? query.where(predicate) : query;
	}

	/**
	 * Applies the pagination and sorting described by {@code pageable} to {@code query}.
	 * <p>
	 * Offset and limit are applied only when {@code pageable} is paged. Unpaged {@link Pageable}s must not be asked
	 * for their offset or page size, because their Javadoc allows them to throw {@link UnsupportedOperationException}.
	 * The sort is applied whenever it is sorted, independently of paging.
	 *
	 * @param query the query to restrict.
	 * @param pageable the paging/sorting information.
	 */
	private void applyPageable(AbstractCollQuery<T, ?> query, Pageable pageable) {

		if (pageable.isPaged()) {
			query.offset(pageable.getOffset());
			query.limit(pageable.getPageSize());
		}

		Sort pageSort = pageable.getSort();
		if (pageSort.isSorted()) {
			query.orderBy(toOrderSpecifier(pageSort, builder));
		}
	}

	/**
	 * {@link org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery} using Querydsl
	 * {@link Predicate}.
	 * <p>
	 * Instances are immutable: {@link #sortBy(Sort)}, {@link #as(Class)} and {@link #project(Collection)} return new
	 * instances. Results are mapped to {@code resultType} using the projection factory (for interfaces) or a
	 * {@link DtoInstantiatingConverter} (for classes) unless the entity type is assignable to it.
	 *
	 * @author Mark Paluch
	 * @since 2.6
	 */
	class FluentQuerydsl<R> implements FluentQuery.FetchableFluentQuery<R> {

		private final Predicate predicate;
		private final Sort sort;
		private final Class<?> entityType;
		private final Class<R> resultType;

		/** Requested properties; carried along but not used to restrict what is loaded or returned. */
		private final List<String> fieldsToInclude;

		FluentQuerydsl(Predicate predicate, Class<R> resultType) {
			this(predicate, Sort.unsorted(), resultType, resultType, Collections.emptyList());
		}

		public FluentQuerydsl(Predicate predicate, Sort sort, Class<?> entityType, Class<R> resultType,
				List<String> fieldsToInclude) {
			this.predicate = predicate;
			this.sort = sort;
			this.entityType = entityType;
			this.resultType = resultType;
			this.fieldsToInclude = fieldsToInclude;
		}

		@Override
		public FluentQuery.FetchableFluentQuery<R> sortBy(Sort sort) {

			Assert.notNull(sort, "Sort must not be null");

			return new FluentQuerydsl<>(predicate, sort, entityType, resultType, fieldsToInclude);
		}

		@Override
		public <NR> FluentQuery.FetchableFluentQuery<NR> as(Class<NR> projection) {

			Assert.notNull(projection, "Projection target type must not be null");

			return new FluentQuerydsl<>(predicate, sort, entityType, projection, fieldsToInclude);
		}

		@Override
		public FluentQuery.FetchableFluentQuery<R> project(Collection<String> properties) {

			Assert.notNull(properties, "Projection properties must not be null");

			return new FluentQuerydsl<>(predicate, sort, entityType, resultType, new ArrayList<>(properties));
		}

		/**
		 * Returns the single matching result, or {@literal null} if nothing matches.
		 *
		 * @throws IncorrectResultSizeDataAccessException if more than one entity matches.
		 */
		@Override
		public @Nullable R oneValue() {

			// Fetching two elements is enough to detect a non-unique result.
			List<T> results = createQuery().limit(2).fetch();

			if (results.isEmpty()) {
				return null;
			}

			if (results.size() > 1) {
				throw new IncorrectResultSizeDataAccessException(1);
			}

			T one = results.get(0);
			return getConversionFunction().apply(one);
		}

		/**
		 * Returns the first matching result (according to the configured sort), or {@literal null} if nothing matches.
		 */
		@Override
		public @Nullable R firstValue() {

			List<T> results = createQuery().limit(1).fetch();

			if (results.isEmpty()) {
				return null;
			}

			T one = results.get(0);
			return getConversionFunction().apply(one);
		}

		@Override
		public List<R> all() {

			List<T> results = createQuery().fetch();

			return mapResults(results);
		}

		/**
		 * Returns a page of matching results.
		 * <p>
		 * The sort configured via {@link #sortBy(Sort)} is applied first, followed by the sort of {@code pageable}.
		 * Offset and limit are only applied for paged {@link Pageable}s.
		 *
		 * @param pageable must not be {@literal null}.
		 */
		@Override
		public Page<R> page(Pageable pageable) {

			Assert.notNull(pageable, "Pageable must not be null");

			AbstractCollQuery<T, ?> query = createQuery();
			applyPageable(query, pageable);

			// fetchResults() supplies both the page content and the total number of matches.
			QueryResults<T> results = query.fetchResults();
			return new PageImpl<>(mapResults(results.getResults()), pageable, results.getTotal());
		}

		@Override
		public Stream<R> stream() {
			return createQuery().stream().map(getConversionFunction());
		}

		@Override
		public long count() {
			return createQuery().fetchCount();
		}

		@Override
		public boolean exists() {
			return count() > 0;
		}

		/** Creates a query filtered by {@link #predicate} and ordered by {@link #sort} if it is sorted. */
		private AbstractCollQuery<T, ?> createQuery() {

			AbstractCollQuery<T, ?> query = prepareQuery(predicate);

			if (sort.isSorted()) {
				query.orderBy(toOrderSpecifier(sort, builder));
			}

			return query;
		}

		/** Maps entities to {@code resultType}; returns the input list unchanged when no projection is requested. */
		@SuppressWarnings("unchecked")
		private List<R> mapResults(List<T> results) {

			if (entityType == resultType) {
				return (List<R>) results;
			}

			List<R> mapped = new ArrayList<>(results.size());
			Function<Object, R> converter = getConversionFunction();

			for (T result : results) {
				mapped.add(converter.apply(result));
			}

			return mapped;
		}

		/**
		 * Returns a function converting {@code inputType} instances into {@code targetType}: identity if assignable,
		 * a projection proxy for interfaces, otherwise a {@link DtoInstantiatingConverter}.
		 */
		@SuppressWarnings("unchecked")
		private <P> Function<Object, P> getConversionFunction(Class<?> inputType, Class<P> targetType) {

			if (targetType.isAssignableFrom(inputType)) {
				return (Function<Object, P>) Function.identity();
			}

			if (targetType.isInterface()) {
				return o -> projectionFactory.createProjection(targetType, o);
			}

			DtoInstantiatingConverter converter = new DtoInstantiatingConverter(targetType, context, entityInstantiators);

			return o -> (P) converter.convert(o);
		}

		private Function<Object, R> getConversionFunction() {
			return getConversionFunction(entityType, resultType);
		}
	}
}