JSqlParserQueryEnhancer.java
/*
* Copyright 2022-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.jpa.repository.query;
import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*;
import static org.springframework.data.jpa.repository.query.QueryUtils.*;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.parser.CCJSqlParser;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.parser.ParseException;
import net.sf.jsqlparser.parser.feature.Feature;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.merge.Merge;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.OrderByElement;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.Values;
import net.sf.jsqlparser.statement.update.Update;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.jspecify.annotations.Nullable;
import org.springframework.data.domain.Sort;
import org.springframework.data.util.Predicates;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.SerializationUtils;
import org.springframework.util.StringUtils;
/**
* The implementation of {@link QueryEnhancer} using JSqlParser.
*
* @author Diego Krupitza
* @author Greg Turnquist
* @author Geoffrey Deremetz
* @author Yanming Zhou
* @author Christoph Strobl
* @author Diego Pedregal
* @since 2.7.0
*/
public class JSqlParserQueryEnhancer implements QueryEnhancer {

    /** The query this enhancer was created for. */
    private final QueryProvider query;

    /** Kind of the top-level statement of {@link #query}. */
    private final ParsedType parsedType;

    /** Result of {@link QueryUtils#hasConstructorExpression(String)} for the query string. */
    private final boolean hasConstructorExpression;

    /** Alias detected by {@link #detectAlias(ParsedType, Statement)}; may be {@literal null}. */
    private final @Nullable String primaryAlias;

    /** Comma-joined select items of the query, or an empty string if there are none. */
    private final String projection;

    /** Unmodifiable set of aliases declared on the joins of the query. */
    private final Set<String> joinAliases;

    /** Unmodifiable set of aliases declared on the select items of the query. */
    private final Set<String> selectAliases;

    /**
     * Serialized form of the parsed statement. Operations that mutate the statement (sorting, count query creation)
     * deserialize a fresh copy from these bytes instead of touching a shared instance.
     */
    private final byte @Nullable [] serialized;

    /**
     * Parses the query once and captures everything derived from the parsed statement.
     *
     * @param query the query we want to enhance. Must not be {@literal null}.
     * @throws IllegalArgumentException if the query string cannot be parsed.
     */
    public JSqlParserQueryEnhancer(QueryProvider query) {

        this.query = query;

        Statement statement = parseStatement(query.getQueryString(), Statement.class);

        this.parsedType = detectParsedType(statement);
        this.hasConstructorExpression = QueryUtils.hasConstructorExpression(query.getQueryString());
        this.primaryAlias = detectAlias(this.parsedType, statement);
        this.projection = detectProjection(statement);
        this.selectAliases = Collections.unmodifiableSet(getSelectionAliases(statement));
        this.joinAliases = Collections.unmodifiableSet(getJoinAliases(statement));
        this.serialized = SerializationUtils.serialize(statement);
    }

    /**
     * Parses a query string with JSqlParser.
     * <p>
     * Parsing is attempted with complex parsing disabled first. Only if that fails, complex parsing is enabled in the
     * parser configuration and the nesting depth of the query does not exceed
     * {@link CCJSqlParserUtil#ALLOWED_NESTING_DEPTH}, the query is parsed a second time with complex parsing enabled.
     *
     * @param sql the query to parse
     * @param classOfT the expected type of the parsed statement
     * @return the parsed query
     * @throws IllegalArgumentException if the query cannot be parsed
     */
    static <T extends Statement> T parseStatement(String sql, Class<T> classOfT) {

        try {

            CCJSqlParser parser = CCJSqlParserUtil.newParser(sql);
            boolean allowComplex = parser.getConfiguration().getAsBoolean(Feature.allowComplexParsing);

            try {
                // first attempt: simple parsing; complex parsing is only used as a fallback below
                return classOfT.cast(parser.withAllowComplexParsing(false).Statement());
            } catch (ParseException ex) {

                if (allowComplex && CCJSqlParserUtil.getNestingDepth(sql) <= CCJSqlParserUtil.ALLOWED_NESTING_DEPTH) {
                    // beware: the parser must not be reused, but needs to be re-initiated
                    parser = CCJSqlParserUtil.newParser(sql);
                    return classOfT.cast(parser.withAllowComplexParsing(true).Statement());
                } else {
                    throw ex;
                }
            }
        } catch (ParseException e) {
            throw new IllegalArgumentException("The query you provided is not a valid SQL Query", e);
        }
    }

    /**
     * Resolves the primary alias of the given statement. For a {@link Merge} this is the alias of the {@code USING}
     * clause, for a select it is the alias of the {@code FROM} item. Any other statement type has no alias.
     *
     * @param parsedType the detected type of {@code statement}.
     * @param statement the parsed statement.
     * @return Might return {@literal null}.
     */
    private static @Nullable String detectAlias(ParsedType parsedType, Statement statement) {

        if (ParsedType.MERGE.equals(parsedType)) {

            Merge mergeStatement = (Merge) statement;
            Alias alias = mergeStatement.getUsingAlias();
            return alias == null ? null : alias.getName();
        }

        if (ParsedType.SELECT.equals(parsedType)) {
            // skip (and yield null) when there is no FROM item or the FROM item carries no alias
            return doWithPlainSelect(statement, it -> it.getFromItem() == null || it.getFromItem().getAlias() == null,
                    it -> it.getFromItem().getAlias().getName(), () -> null);
        }

        return null;
    }

    /**
     * Returns the aliases used inside the selection part in the query. For a {@link SetOperationList} the first select
     * is inspected.
     *
     * @param statement the parsed statement.
     * @return a {@literal Set} containing all found aliases. Guaranteed to be not {@literal null}.
     */
    private static Set<String> getSelectionAliases(Statement statement) {

        if (statement instanceof SetOperationList sel) {
            statement = sel.getSelect(0);
        }

        return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {

            Set<String> set = new HashSet<>(it.getSelectItems().size(), 1.0f);

            for (SelectItem<?> selectItem : it.getSelectItems()) {
                Alias alias = selectItem.getAlias();
                if (alias != null) {
                    set.add(alias.getName());
                }
            }

            return set;
        }, Collections::emptySet);
    }

    /**
     * Returns the aliases used for {@code join}s. For a {@link SetOperationList} the first select is inspected.
     *
     * @param statement the parsed statement.
     * @return a {@literal Set} of aliases used in the query. Guaranteed to be not {@literal null}.
     */
    private static Set<String> getJoinAliases(Statement statement) {

        if (statement instanceof SetOperationList sel) {
            statement = sel.getSelect(0);
        }

        return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getJoins()), it -> {

            Set<String> set = new HashSet<>(it.getJoins().size(), 1.0f);

            for (Join join : it.getJoins()) {
                Alias alias = join.getRightItem().getAlias();
                if (alias != null) {
                    set.add(alias.getName());
                }
            }

            return set;
        }, Collections::emptySet);
    }

    /**
     * Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
     * {@link Statement} if it is or contains a {@link PlainSelect}. The mapping is never skipped.
     *
     * @param statement the statement to inspect.
     * @param mapper function applied to the plain select.
     * @param fallback supplies the result when the statement does not provide a plain select.
     * @param <T> result type.
     * @return the mapped value or the fallback value.
     */
    private static <T> T doWithPlainSelect(Statement statement, java.util.function.Function<PlainSelect, T> mapper,
            Supplier<T> fallback) {

        Predicate<PlainSelect> neverSkip = Predicates.isFalse();
        return doWithPlainSelect(statement, neverSkip, mapper, fallback);
    }

    /**
     * Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
     * {@link Statement} if it is or contains a {@link PlainSelect}.
     * <p>
     * The operation is only applied if {@link Predicate skipIf} returns {@literal false} for the plain select;
     * otherwise the value from {@code fallback} is returned. The fallback is also used when the statement is not a
     * {@link Select} or when no plain select can be obtained from it.
     *
     * @param statement the statement to inspect.
     * @param skipIf predicate deciding whether the mapping should be skipped.
     * @param mapper function applied to the plain select.
     * @param fallback supplies the result when the mapping is not applied.
     * @param <T> result type.
     * @return the mapped value or the fallback value.
     */
    private static <T> T doWithPlainSelect(Statement statement, Predicate<PlainSelect> skipIf,
            java.util.function.Function<PlainSelect, T> mapper, Supplier<T> fallback) {

        if (!(statement instanceof Select select)) {
            return fallback.get();
        }

        PlainSelect plainSelect;

        // only the lookup is guarded so that a ClassCastException raised by skipIf or mapper is not swallowed
        try {
            plainSelect = select.getPlainSelect();
        }
        // e.g. SetOperationList is a subclass of Select but it is not a PlainSelect
        catch (ClassCastException ignored) {
            return fallback.get();
        }

        if (skipIf.test(plainSelect)) {
            return fallback.get();
        }

        return mapper.apply(plainSelect);
    }

    /**
     * Renders the projection (the select items joined by {@code ", "}) of the given statement.
     *
     * @param statement the parsed statement.
     * @return the projection, or an empty string for non-select statements, {@link Values} statements and selects
     *         without select items.
     */
    private static String detectProjection(Statement statement) {

        if (!(statement instanceof Select select)) {
            return "";
        }

        if (select instanceof Values) {
            return "";
        }

        Select selectBody = select;

        if (select instanceof SetOperationList setOperationList) {
            // using the first one since for setoperations the projection has to be the same
            selectBody = setOperationList.getSelects().get(0);
        }

        return doWithPlainSelect(selectBody, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {

            StringJoiner joiner = new StringJoiner(", ");

            for (SelectItem<?> selectItem : it.getSelectItems()) {
                joiner.add(selectItem.toString());
            }

            return joiner.toString().trim();
        }, () -> "");
    }

    /**
     * Detects what type of query is provided.
     *
     * @param statement the parsed statement.
     * @return the parsed type
     */
    private static ParsedType detectParsedType(Statement statement) {

        if (statement instanceof Insert) {
            return ParsedType.INSERT;
        } else if (statement instanceof Update) {
            return ParsedType.UPDATE;
        } else if (statement instanceof Delete) {
            return ParsedType.DELETE;
        } else if (statement instanceof Select) {
            return ParsedType.SELECT;
        } else if (statement instanceof Merge) {
            return ParsedType.MERGE;
        } else {
            return ParsedType.OTHER;
        }
    }

    @Override
    public boolean hasConstructorExpression() {
        return hasConstructorExpression;
    }

    @Override
    public @Nullable String detectAlias() {
        return this.primaryAlias;
    }

    @Override
    public String getProjection() {
        return this.projection;
    }

    /**
     * @return the unmodifiable set of aliases declared on the select items of the query.
     */
    public Set<String> getSelectionAliases() {
        return selectAliases;
    }

    @Override
    public QueryProvider getQuery() {
        return this.query;
    }

    @Override
    public String rewrite(QueryRewriteInformation rewriteInformation) {
        return doApplySorting(rewriteInformation.getSort(), primaryAlias);
    }

    /**
     * Applies the given {@link Sort} to the query. Non-select queries and unsorted requests return the original query
     * string unchanged.
     *
     * @param sort the sort to apply.
     * @param alias the alias used to qualify sort properties. May be {@literal null}.
     * @return the query string with the sort applied.
     */
    private String doApplySorting(Sort sort, @Nullable String alias) {

        String queryString = query.getQueryString();
        Assert.hasText(queryString, "Query must not be null or empty");

        if (this.parsedType != ParsedType.SELECT || sort.isUnsorted()) {
            return queryString;
        }

        // sorting mutates the statement, hence operate on a fresh copy
        return applySorting(deserializeRequired(this.serialized, Select.class), sort, alias);
    }

    /**
     * Appends order-by elements for the given {@link Sort} to the select statement and renders it.
     *
     * @param selectStatement the statement to sort. Must not be {@literal null}.
     * @param sort the sort to apply.
     * @param alias the alias used to qualify sort properties. May be {@literal null}.
     * @return the rendered statement.
     */
    private String applySorting(@Nullable Select selectStatement, Sort sort, @Nullable String alias) {

        Assert.notNull(selectStatement, "SelectStatement must not be null");

        if (selectStatement instanceof SetOperationList setOperationList) {
            return applySortingToSetOperationList(setOperationList, sort);
        }

        // the mapper is used for its side effect on the plain select only; its result is ignored
        doWithPlainSelect(selectStatement, it -> {

            List<OrderByElement> orderByElements = new ArrayList<>(16);
            for (Sort.Order order : sort) {
                orderByElements.add(getOrderClause(joinAliases, selectAliases, alias, order));
            }

            if (CollectionUtils.isEmpty(it.getOrderByElements())) {
                it.setOrderByElements(orderByElements);
            } else {
                it.getOrderByElements().addAll(orderByElements);
            }

            return null;
        }, () -> "");

        return selectStatement.toString();
    }

    /**
     * Creates a count query for the underlying query. Non-select queries are returned unchanged, as are selects that do
     * not provide a {@link PlainSelect}.
     *
     * @param countProjection the projection to count. May be {@literal null} or empty, in which case the counted
     *          expression is derived from the query.
     * @return the count query string.
     */
    @Override
    @SuppressWarnings("NullAway")
    public String createCountQueryFor(@Nullable String countProjection) {

        if (this.parsedType != ParsedType.SELECT) {
            return this.query.getQueryString();
        }

        Assert.hasText(this.query.getQueryString(), "OriginalQuery must not be null or empty");

        // the count query is built by mutating the statement, hence operate on a fresh copy
        Statement statement = (Statement) deserialize(this.serialized);

        return doWithPlainSelect(statement, it -> createCountQueryFor(it, countProjection, primaryAlias),
                this.query::getQueryString);
    }

    /**
     * Turns the given {@link PlainSelect} into a count query by dropping its order-by clause and replacing its select
     * items with a single count function. The given select is modified.
     *
     * @param selectBody the select to convert.
     * @param countProjection the projection to count. May be {@literal null} or empty.
     * @param primaryAlias the primary alias of the query. May be {@literal null}.
     * @return the rendered count query.
     */
    private static String createCountQueryFor(PlainSelect selectBody, @Nullable String countProjection,
            @Nullable String primaryAlias) {

        // remove order by
        selectBody.setOrderByElements(null);

        if (StringUtils.hasText(countProjection)) {
            selectBody.setSelectItems(
                    Collections.singletonList(SelectItem.from(getJSqlCount(Collections.singletonList(countProjection), false))));
        } else {

            boolean distinct = selectBody.getDistinct() != null;
            selectBody.setDistinct(null); // reset possible distinct

            Function jSqlCount = getJSqlCount(
                    Collections.singletonList(countPropertyNameForSelection(selectBody.getSelectItems(), distinct, primaryAlias)),
                    distinct);
            selectBody.setSelectItems(Collections.singletonList(SelectItem.from(jSqlCount)));
        }

        return selectBody.toString();
    }

    /**
     * Returns the {@link SetOperationList} as a string query with {@link Sort}s applied in the right order. If any of
     * the contained selects is a {@link Values} statement, the statement is rendered without applying the sort.
     *
     * @param setOperationListStatement the set operation to sort.
     * @param sort the sort to apply.
     * @return the rendered statement.
     */
    private static String applySortingToSetOperationList(SetOperationList setOperationListStatement, Sort sort) {

        // special case: ValuesStatements are detected as nested OperationListStatements
        for (Select select : setOperationListStatement.getSelects()) {
            if (select instanceof Values) {
                return setOperationListStatement.toString();
            }
        }

        List<OrderByElement> orderByElements = new ArrayList<>(16);
        for (Sort.Order order : sort) {
            // no aliases are passed, so sort properties are used as given
            orderByElements.add(getOrderClause(Collections.emptySet(), Collections.emptySet(), null, order));
        }

        if (setOperationListStatement.getOrderByElements() == null) {
            setOperationListStatement.setOrderByElements(orderByElements);
        } else {
            setOperationListStatement.getOrderByElements().addAll(orderByElements);
        }

        return setOperationListStatement.toString();
    }

    /**
     * Returns the order clause for the given {@link Sort.Order}. A property that matches a selection alias is used as
     * is. Any other property is prefixed with the given alias unless it already starts with a join alias (i.e. starts
     * with {@code $joinAlias.}), contains a function call, or no alias is available. The
     * {@link Sort.Order#getNullHandling() null handling} of the order is applied in all cases.
     *
     * @param joinAliases the join aliases of the original query. Must not be {@literal null}.
     * @param selectionAliases the selection aliases of the original query. Must not be {@literal null}.
     * @param alias the alias for the root entity. May be {@literal null}.
     * @param order the order object to build the clause for. Must not be {@literal null}.
     * @return a {@link OrderByElement} containing an order clause. Guaranteed to be not {@literal null}.
     */
    private static OrderByElement getOrderClause(Set<String> joinAliases, Set<String> selectionAliases,
            @Nullable String alias, Sort.Order order) {

        OrderByElement orderByElement = new OrderByElement();
        orderByElement.setAsc(order.getDirection().isAscending());
        orderByElement.setAscDescPresent(true);

        String property = order.getProperty();

        checkSortExpression(order);

        String reference = selectionAliases.contains(property) ? property
                : qualifyProperty(joinAliases, alias, property);

        Expression orderExpression = order.isIgnoreCase() ? getJSqlLower(reference) : new Column(reference);
        orderByElement.setExpression(orderExpression);

        // applied on the single exit path so that sorting by a selection alias honors null handling as well
        switch (order.getNullHandling()) {
            case NULLS_FIRST -> orderByElement.setNullOrdering(OrderByElement.NullOrdering.NULLS_FIRST);
            case NULLS_LAST -> orderByElement.setNullOrdering(OrderByElement.NullOrdering.NULLS_LAST);
            default -> {
                // do nothing
            }
        }

        return orderByElement;
    }

    /**
     * Prefixes the given property with the root alias unless the property contains a function call, no alias text is
     * available, or the property already starts with one of the join aliases followed by a dot.
     *
     * @param joinAliases the join aliases of the original query. Must not be {@literal null}.
     * @param alias the alias for the root entity. May be {@literal null}.
     * @param property the sort property.
     * @return the property, qualified with the alias where applicable.
     */
    private static String qualifyProperty(Set<String> joinAliases, @Nullable String alias, String property) {

        if (property.contains("(") || !StringUtils.hasText(alias)) {
            return property;
        }

        for (String joinAlias : joinAliases) {
            if (property.startsWith(joinAlias.concat("."))) {
                return property;
            }
        }

        return alias + "." + property;
    }

    /**
     * Determines the expression to count. If the selection consists of exactly one plain column, its fully qualified
     * name is used. Otherwise a distinct query counts {@literal *}, prefixed with the table alias if one is given, and
     * a non-distinct query counts {@literal 1}.
     *
     * @param selectItems items from the select.
     * @param distinct indicator if query for distinct values.
     * @param tableAlias the table alias which can be {@literal null}.
     * @return the expression to pass to the count function.
     */
    private static String countPropertyNameForSelection(List<SelectItem<?>> selectItems, boolean distinct,
            @Nullable String tableAlias) {

        if (onlyASingleColumnProjection(selectItems)) {

            SelectItem<?> singleProjection = selectItems.get(0);
            Column column = (Column) singleProjection.getExpression();
            return column.getFullyQualifiedName();
        }

        return distinct ? ((tableAlias != null ? tableAlias + "." : "") + "*") : "1";
    }

    /**
     * Checks whether a given projection only contains a single column definition (aka without functions, etc.)
     *
     * @param projection the projection to analyse.
     * @return {@code true} when the projection only contains a single column definition otherwise {@code false}.
     */
    private static boolean onlyASingleColumnProjection(List<SelectItem<?>> projection) {

        // this is unfortunately the only way to check without any hacky & hard string regex magic
        return projection.size() == 1 && projection.get(0) != null
                && projection.get(0).getExpression() instanceof Column;
    }

    /**
     * An enum to represent the top level parsed statement of the provided query.
     * <ul>
     * <li>{@code ParsedType.DELETE}: means the top level statement is {@link Delete}</li>
     * <li>{@code ParsedType.UPDATE}: means the top level statement is {@link Update}</li>
     * <li>{@code ParsedType.SELECT}: means the top level statement is {@link Select}</li>
     * <li>{@code ParsedType.INSERT}: means the top level statement is {@link Insert}</li>
     * <li>{@code ParsedType.MERGE}: means the top level statement is {@link Merge}</li>
     * <li>{@code ParsedType.OTHER}: means the top level statement is a different top-level type</li>
     * </ul>
     */
    enum ParsedType {
        DELETE, UPDATE, SELECT, INSERT, MERGE, OTHER
    }

    /**
     * Deserialize the byte array into an object. Within this class the bytes are the ones produced in the constructor
     * from the parsed statement.
     *
     * @param bytes a serialized object
     * @return the result of deserializing the bytes, or {@literal null} if {@code bytes} is {@literal null} or empty
     * @throws IllegalArgumentException if reading the bytes fails
     * @throws IllegalStateException if the class of the serialized object cannot be found
     */
    private static @Nullable Object deserialize(byte @Nullable [] bytes) {

        if (ObjectUtils.isEmpty(bytes)) {
            return null;
        }

        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return ois.readObject();
        } catch (IOException ex) {
            throw new IllegalArgumentException("Failed to deserialize object", ex);
        } catch (ClassNotFoundException ex) {
            throw new IllegalStateException("Failed to deserialize object type", ex);
        }
    }

    /**
     * Deserialize the byte array into an object of the given type, failing if there is nothing to deserialize.
     *
     * @param bytes a serialized object
     * @param type the expected type of the deserialized object
     * @return the deserialized object, never {@literal null}
     * @throws IllegalStateException if {@code bytes} is {@literal null} or empty
     */
    private static <T> T deserializeRequired(byte @Nullable [] bytes, Class<T> type) {

        Object deserialize = deserialize(bytes);

        if (deserialize != null) {
            return type.cast(deserialize);
        }

        throw new IllegalStateException("Failed to deserialize object type");
    }
}