ExpressionTreeUtils.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.analyzer;
import com.facebook.presto.UnknownTypeException;
import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeWithName;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeLocation;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
/**
 * Static helpers for searching and inspecting SQL expression trees during analysis.
 */
public final class ExpressionTreeUtils
{
    private ExpressionTreeUtils() {}

    /**
     * Returns every {@link FunctionCall} reachable from {@code nodes} that counts as an aggregation.
     * A call counts when either of the following holds:
     * <ul>
     * <li>(its resolved function kind is {@code AGGREGATE} or it has a FILTER) and it has no window</li>
     * <li>it has an in-call ORDER BY</li>
     * </ul>
     *
     * @param functionHandles resolved function handles keyed by call site; each visited call is looked up here
     *        (NOTE(review): presumably every call in {@code nodes} must have an entry — confirm)
     * @param nodes roots to search
     * @param functionAndTypeResolver used to fetch the metadata of each call's function handle
     * @return matching calls, in the order produced by {@link #extractExpressions(Iterable, Class)}
     */
    static List<FunctionCall> extractAggregateFunctions(
            Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles,
            Iterable<? extends Node> nodes,
            FunctionAndTypeResolver functionAndTypeResolver)
    {
        return extractExpressions(nodes, FunctionCall.class, isAggregationPredicate(functionHandles, functionAndTypeResolver));
    }

    /**
     * Returns every {@link FunctionCall} reachable from {@code nodes} that carries a window (OVER clause).
     */
    static List<FunctionCall> extractWindowFunctions(Iterable<? extends Node> nodes)
    {
        return extractExpressions(nodes, FunctionCall.class, ExpressionTreeUtils::isWindowFunction);
    }

    /**
     * Returns every {@link FunctionCall} reachable from {@code nodes} whose resolved implementation type
     * reports {@code isExternalExecution()}.
     */
    static List<FunctionCall> extractExternalFunctions(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, Iterable<? extends Node> nodes, FunctionAndTypeResolver functionAndTypeResolver)
    {
        return extractExpressions(nodes, FunctionCall.class, isExternalFunctionPredicate(functionHandles, functionAndTypeResolver));
    }

    /**
     * Returns every node of type {@code clazz} reachable from {@code nodes}, including the roots themselves.
     * Roots are processed in iteration order. Within one root, a node appears after the descendants
     * the traversal visits beneath it.
     *
     * @param nodes roots to search; must not be null
     * @param clazz expression type to collect; must not be null
     * @return an immutable list of matches
     */
    public static <T extends Expression> List<T> extractExpressions(
            Iterable<? extends Node> nodes,
            Class<T> clazz)
    {
        return extractExpressions(nodes, clazz, alwaysTrue());
    }

    private static Predicate<FunctionCall> isAggregationPredicate(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, FunctionAndTypeResolver functionAndTypeResolver)
    {
        // Parenthesized to spell out Java precedence ('&&' binds tighter than '||'): a call with an
        // in-call ORDER BY is classified as an aggregation even if it also has a window.
        // The function-kind lookup is always evaluated first, for every visited call.
        return functionCall -> ((functionAndTypeResolver.getFunctionMetadata(functionHandles.get(NodeRef.of(functionCall))).getFunctionKind() == AGGREGATE
                || functionCall.getFilter().isPresent())
                && !functionCall.getWindow().isPresent())
                || functionCall.getOrderBy().isPresent();
    }

    private static boolean isWindowFunction(FunctionCall functionCall)
    {
        return functionCall.getWindow().isPresent();
    }

    private static Predicate<FunctionCall> isExternalFunctionPredicate(Map<NodeRef<FunctionCall>, FunctionHandle> functionHandles, FunctionAndTypeResolver functionAndTypeResolver)
    {
        return functionCall -> functionAndTypeResolver.getFunctionMetadata(functionHandles.get(NodeRef.of(functionCall))).getImplementationType().isExternalExecution();
    }

    /**
     * Flattens each root, keeps the nodes that are instances of {@code clazz}, and filters them with
     * {@code predicate}.
     */
    private static <T extends Expression> List<T> extractExpressions(
            Iterable<? extends Node> nodes,
            Class<T> clazz,
            Predicate<T> predicate)
    {
        requireNonNull(nodes, "nodes is null");
        requireNonNull(clazz, "clazz is null");
        requireNonNull(predicate, "predicate is null");

        return ImmutableList.copyOf(nodes).stream()
                .flatMap(node -> linearizeNodes(node).stream())
                .filter(clazz::isInstance)
                .map(clazz::cast)
                .filter(predicate)
                .collect(toImmutableList());
    }

    /**
     * Collects {@code node} and every node that {@link DefaultExpressionTraversalVisitor} visits beneath it.
     * Each node is recorded only after {@code super.process} returns, so visited descendants precede
     * their ancestor in the result.
     */
    private static List<Node> linearizeNodes(Node node)
    {
        ImmutableList.Builder<Node> nodes = ImmutableList.builder();
        new DefaultExpressionTraversalVisitor<Node, Void>()
        {
            @Override
            public Node process(Node node, Void context)
            {
                Node result = super.process(node, context);
                nodes.add(node);
                return result;
            }
        }.process(node, null);
        return nodes.build();
    }

    /**
     * Returns {@code true} when {@code expression} is a {@link ComparisonExpression} using the {@code EQUAL} operator.
     */
    public static boolean isEqualComparisonExpression(Expression expression)
    {
        return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == ComparisonExpression.Operator.EQUAL;
    }

    /**
     * Tries to interpret {@code qualifiedName} as an enum literal of the form {@code MyEnum.my_key}.
     *
     * @return the named enum type if the prefix resolves to a {@link TypeWithName} wrapping an
     *         {@link EnumType} whose enum map contains the upper-cased suffix; otherwise empty
     */
    public static Optional<TypeWithName> tryResolveEnumLiteral(QualifiedName qualifiedName, FunctionAndTypeResolver functionAndTypeResolver)
    {
        Optional<QualifiedName> prefix = qualifiedName.getPrefix();
        if (!prefix.isPresent()) {
            // an enum literal should be of the form `MyEnum.my_key`
            return Optional.empty();
        }
        try {
            Type baseType = functionAndTypeResolver.getType(parseTypeSignature(prefix.get().toString()));
            if (baseType instanceof TypeWithName) {
                Type underlyingType = ((TypeWithName) baseType).getType();
                if (underlyingType instanceof EnumType
                        && ((EnumType<?>) underlyingType).getEnumMap().containsKey(qualifiedName.getSuffix().toUpperCase(ENGLISH))) {
                    return Optional.of((TypeWithName) baseType);
                }
            }
        }
        catch (IllegalArgumentException | UnknownTypeException ignored) {
            // e.g. the prefix cannot be parsed or resolved as a type: not an enum literal
            return Optional.empty();
        }
        return Optional.empty();
    }

    /**
     * Resolves the value of an enum literal whose type was already determined to be {@code nodeType}.
     *
     * @param node the dereference naming the enum key (its last name component)
     * @param nodeType a {@link TypeWithName} wrapping an {@link EnumType}
     * @return the enum value; {@code String} values are converted to a UTF-8 {@code Slice}
     * @throws IllegalArgumentException if the upper-cased key is not in the enum
     */
    public static Object resolveEnumLiteral(DereferenceExpression node, Type nodeType)
    {
        QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
        EnumType<?> enumType = (EnumType<?>) ((TypeWithName) nodeType).getType();
        String enumKey = qualifiedName.getSuffix().toUpperCase(ENGLISH);
        checkArgument(enumType.getEnumMap().containsKey(enumKey), format("No key '%s' in enum '%s'", enumKey, nodeType.getDisplayName()));
        Object enumValue = enumType.getEnumMap().get(enumKey);
        return enumValue instanceof String ? utf8Slice((String) enumValue) : enumValue;
    }

    /**
     * Returns the single {@link FieldId} recorded for {@code expression}.
     *
     * @throws IllegalStateException if there is no entry, or more than one entry, for {@code expression}
     */
    public static FieldId checkAndGetColumnReferenceField(Expression expression, Multimap<NodeRef<Expression>, FieldId> columnReferences)
    {
        NodeRef<Expression> key = NodeRef.of(expression);
        checkState(columnReferences.containsKey(key), "Missing field reference for expression");
        checkState(columnReferences.get(key).size() == 1, "Multiple field references for expression");
        return columnReferences.get(key).iterator().next();
    }

    /**
     * Like {@link #isConstant(Expression)}, but rejects a top-level NULL (after stripping casts).
     */
    public static boolean isNonNullConstant(Expression expression)
    {
        Expression uncast = stripCasts(expression);
        if (uncast instanceof NullLiteral) {
            return false;
        }
        // Only the top level is checked for NULL; nested values in ARRAY[...], ROW(...) etc. are
        // delegated to isConstant, where NULL is OK.
        return isConstant(uncast);
    }

    /**
     * Returns {@code true} when {@code expression}, ignoring any enclosing casts, is one of the following:
     * <ul>
     * <li>a {@link Literal}</li>
     * <li>an ARRAY constructor, ROW, or {@code map(...)} call whose elements are all constant by this
     * definition</li>
     * </ul>
     * Everything else is treated as non-constant.
     */
    public static boolean isConstant(Expression expression)
    {
        Expression uncast = stripCasts(expression);
        if (uncast instanceof Literal) {
            return true;
        }
        if (uncast instanceof ArrayConstructor) {
            return ((ArrayConstructor) uncast).getValues().stream().allMatch(ExpressionTreeUtils::isConstant);
        }
        // ROW and MAP are not literals, so they are special-cased here.
        if (uncast instanceof Row) {
            return ((Row) uncast).getItems().stream().allMatch(ExpressionTreeUtils::isConstant);
        }
        if (uncast instanceof FunctionCall) {
            FunctionCall call = (FunctionCall) uncast;
            // Hack to just allow the map constructor: matches any call whose last name component is
            // "map" (case-insensitive), regardless of qualifier.
            if (call.getName().getSuffix().equalsIgnoreCase("map")) {
                return call.getArguments().stream().allMatch(ExpressionTreeUtils::isConstant);
            }
        }
        // Everything else is considered non-constant.
        return false;
    }

    /**
     * Unwraps any number of nested {@link Cast}s and returns the innermost non-cast expression.
     */
    private static Expression stripCasts(Expression expression)
    {
        Expression current = expression;
        while (current instanceof Cast) {
            current = ((Cast) current).getExpression();
        }
        return current;
    }

    /**
     * Converts a parser {@link NodeLocation} into an SPI {@link SourceLocation} (same line and column).
     */
    public static Optional<SourceLocation> getSourceLocation(Optional<NodeLocation> nodeLocation)
    {
        return nodeLocation.map(location -> new SourceLocation(location.getLineNumber(), location.getColumnNumber()));
    }

    /**
     * Returns the source location of {@code node}.
     * If the node has none, falls back to the first direct child that has one; grandchildren are not
     * searched.
     */
    public static Optional<SourceLocation> getSourceLocation(Node node)
    {
        Optional<NodeLocation> nodeLocation = node.getLocation();
        if (!nodeLocation.isPresent()) {
            // See if any direct child has a location
            nodeLocation = node.getChildren().stream()
                    .map(Node::getLocation)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .findFirst();
        }
        return getSourceLocation(nodeLocation);
    }

    /**
     * Converts an SPI {@link SourceLocation} back into a parser {@link NodeLocation} (same line and column).
     */
    public static Optional<NodeLocation> getNodeLocation(Optional<SourceLocation> sourceLocation)
    {
        return sourceLocation.map(location -> new NodeLocation(location.getLine(), location.getColumn()));
    }

    /**
     * Builds a {@link SymbolReference} with the variable's name and its source location, if any.
     */
    public static SymbolReference createSymbolReference(VariableReferenceExpression variableReferenceExpression)
    {
        return new SymbolReference(getNodeLocation(variableReferenceExpression.getSourceLocation()), variableReferenceExpression.getName());
    }
}