InequalityInference.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.EqualityInference.getLeft;
import static com.facebook.presto.sql.planner.EqualityInference.getRight;
import static com.facebook.presto.sql.planner.EqualityInference.isOperation;
import static com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence.swapPair;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.filter;
import static java.util.Objects.requireNonNull;
/**
 * Infers new inequality predicates from a set of existing {@code <} / {@code <=} predicates.
 * <p>
 * For example, from {@code a < b} and {@code b < 5} the predicate {@code a < 5} is inferred.
 * Only predicates in which exactly one side references variables are produced, and, when
 * {@code outerVariables} is present, predicates referencing any of those variables are discarded.
 * <p>
 * Instances are normally created through {@link Builder}, which filters candidate conjuncts
 * and canonicalizes {@code >} / {@code >=} into {@code <} / {@code <=}.
 */
public class InequalityInference
{
    // inequalityExpressions include the inequalities from the current join predicate
    // and those inherited from the join's parent
    private final Set<RowExpression> inequalityExpressions;
    private final FunctionAndTypeManager functionAndTypeManager;
    private final ExpressionEquivalence expressionEquivalence;
    // 'outerVariables' is the variable set projected from the outer input for left or right join,
    // or EMPTY if the join is an inner join
    private final Optional<Collection<VariableReferenceExpression>> outerVariables;

    /**
     * @param inequalityExpressions expressions to infer from; every element must be a {@code <} or {@code <=} operation
     * @param functionAndTypeManager used to look up and resolve comparison operators
     * @param expressionEquivalence used to decide whether two operands are the same expression
     * @param outerVariables variables that an inferred predicate must not reference, or empty for no restriction
     * @throws NullPointerException if any argument is null
     * @throws PrestoException if any expression is not a {@code <} or {@code <=} operation
     */
    public InequalityInference(Set<RowExpression> inequalityExpressions, FunctionAndTypeManager functionAndTypeManager, ExpressionEquivalence expressionEquivalence, Optional<Collection<VariableReferenceExpression>> outerVariables)
    {
        // Validate nullness before touching any argument so failures carry a descriptive message
        requireNonNull(inequalityExpressions, "inequalityExpressions is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.expressionEquivalence = requireNonNull(expressionEquivalence, "expressionEquivalence is null");
        this.outerVariables = requireNonNull(outerVariables, "outerVariables is null");

        if (inequalityExpressions.stream()
                .anyMatch(e -> !isOperation(e, LESS_THAN, functionAndTypeManager) && !isOperation(e, LESS_THAN_OR_EQUAL, functionAndTypeManager))) {
            throw new PrestoException(GENERIC_INTERNAL_ERROR, "all inequality expressions for inference must be < or <=");
        }

        // Defensive copy: the caller (e.g. Builder) may keep mutating the set it handed in
        this.inequalityExpressions = ImmutableSet.copyOf(inequalityExpressions);
    }

    /**
     * Computes the predicates that can be inferred from the input inequalities.
     * <p>
     * Basic idea is to pairwise compare each predicate with all other predicates.
     * Since any inferred predicates may trigger other inferences, keep comparing until
     * no new predicates are generated.
     *
     * @return the inferred predicates, excluding the inequalities originally supplied;
     * empty if fewer than two inequalities were supplied
     */
    public Set<RowExpression> inferInequalities()
    {
        if (inequalityExpressions.size() < 2) {
            return ImmutableSet.of();
        }

        Set<RowExpression> allInferredInequalities = new HashSet<>();
        Set<RowExpression> inequalitiesInferredInCurrentTraversal = new HashSet<>(inequalityExpressions);
        // Pairs that were already compared; avoids re-evaluating them on later traversals
        Set<Set<RowExpression>> exploredCombinations = new HashSet<>();

        // Stop once a traversal produces nothing that is not already known
        while (!allInferredInequalities.containsAll(inequalitiesInferredInCurrentTraversal)) {
            allInferredInequalities.addAll(inequalitiesInferredInCurrentTraversal);

            Set<Set<RowExpression>> newCombinations = Sets.combinations(allInferredInequalities, 2).stream()
                    .filter(subset -> !exploredCombinations.contains(subset))
                    .collect(toImmutableSet());

            inequalitiesInferredInCurrentTraversal = newCombinations.stream()
                    .map(pair -> {
                        Iterator<RowExpression> it = pair.iterator();
                        RowExpression first = it.next();
                        RowExpression second = it.next();
                        return compareAndExtractInequalities(first, second);
                    })
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(toImmutableSet());

            exploredCombinations.addAll(newCombinations);
        }

        // we want to exclude all the original inequalities we passed in from the set we return
        allInferredInequalities.removeAll(inequalityExpressions);
        return ImmutableSet.copyOf(allInferredInequalities);
    }

    /**
     * Attempts to infer a new inequality from a pair of inequalities that share an operand.
     * <p>
     * If the join is an outer join, we need to make sure any inferred predicate is on the inner side only
     * (i.e. does not reference any 'outerVariables').
     *
     * @return the inferred predicate, or empty if the pair does not chain, if the result would have
     * variables on both sides or on neither side, or if it references an outer variable
     */
    private Optional<RowExpression> compareAndExtractInequalities(RowExpression expression1, RowExpression expression2)
    {
        CallExpression firstConjunct = (CallExpression) expression1;
        OperatorType firstOperatorType = functionAndTypeManager.getFunctionMetadata(firstConjunct.getFunctionHandle()).getOperatorType().get();
        CallExpression secondConjunct = (CallExpression) expression2;
        OperatorType secondOperatorType = functionAndTypeManager.getFunctionMetadata(secondConjunct.getFunctionHandle()).getOperatorType().get();

        // Get the inference operator, or empty if no inference possible.
        // The result does not depend on the order of the two operators.
        Optional<OperatorType> inferenceOperatorType = getComparisonInferenceOperatorType(secondOperatorType, firstOperatorType);
        if (!inferenceOperatorType.isPresent()) {
            return Optional.empty();
        }

        // For two inequalities "a < b" and "b < 5" we want to infer "a < 5"
        // Compare the call expressions' arguments to determine equivalence
        // Make sure one input has no column/variable references since we only want to infer predicates with constants
        // LHS : Left hand side
        // RHS : Right hand side
        RowExpression firstConjunctLHS = firstConjunct.getArguments().get(0);
        RowExpression firstConjunctRHS = firstConjunct.getArguments().get(1);
        RowExpression secondConjunctLHS = secondConjunct.getArguments().get(0);
        RowExpression secondConjunctRHS = secondConjunct.getArguments().get(1);

        Optional<RowExpression> inferredFirstArgument = Optional.empty();
        Optional<RowExpression> inferredSecondArgument = Optional.empty();
        Set<VariableReferenceExpression> variablesReferencedInInferredPredicate = ImmutableSet.of();

        if (expressionEquivalence.areExpressionsEquivalent(firstConjunctRHS, secondConjunctLHS)) {
            // first: x OP m, second: m OP y  =>  x OP y
            variablesReferencedInInferredPredicate = getVariablesReferencedInInferredPredicate(firstConjunctLHS, secondConjunctRHS);
            if (!variablesReferencedInInferredPredicate.isEmpty()) {
                inferredFirstArgument = Optional.of(firstConjunctLHS);
                inferredSecondArgument = Optional.of(secondConjunctRHS);
            }
        }
        else if (expressionEquivalence.areExpressionsEquivalent(firstConjunctLHS, secondConjunctRHS)) {
            // second: x OP m, first: m OP y  =>  x OP y
            variablesReferencedInInferredPredicate = getVariablesReferencedInInferredPredicate(firstConjunctRHS, secondConjunctLHS);
            if (!variablesReferencedInInferredPredicate.isEmpty()) {
                inferredFirstArgument = Optional.of(secondConjunctLHS);
                inferredSecondArgument = Optional.of(firstConjunctRHS);
            }
        }

        if (!inferredFirstArgument.isPresent() ||
                !inferredSecondArgument.isPresent() ||
                (outerVariables.isPresent() && Iterables.any(variablesReferencedInInferredPredicate, in(outerVariables.get())))) {
            return Optional.empty();
        }

        // Build and return the new predicate
        OperatorType inferredOperator = inferenceOperatorType.get();
        FunctionHandle inferredComparatorFunctionHandle = functionAndTypeManager.resolveOperator(inferredOperator, fromTypes(inferredFirstArgument.get().getType(), inferredSecondArgument.get().getType()));
        // Use the operator's own name as display name (same as canonicalizeInequality), not Optional's toString()
        return Optional.of(new CallExpression(inferredOperator.getOperator(), inferredComparatorFunctionHandle, BOOLEAN, ImmutableList.of(inferredFirstArgument.get(), inferredSecondArgument.get())));
    }

    /**
     * Get comparison operator for inferring comparison predicates given operators from
     * two comparison predicates. E.g., '1<=a1 AND a1<a2' will return '<' as the operator.
     *
     * @return {@code LESS_THAN} if both operators are {@code <}/{@code <=} and at least one is {@code <},
     * {@code LESS_THAN_OR_EQUAL} if both are {@code <=}, and empty if no inference is possible
     */
    private static Optional<OperatorType> getComparisonInferenceOperatorType(OperatorType operator1, OperatorType operator2)
    {
        Optional<OperatorType> inferenceType = Optional.empty();
        if (operator1.equals(LESS_THAN)) {
            if (operator2.equals(LESS_THAN) || operator2.equals(LESS_THAN_OR_EQUAL)) {
                inferenceType = Optional.of(LESS_THAN);
            }
        }
        else if (operator1.equals(LESS_THAN_OR_EQUAL)) {
            if (operator2.equals(LESS_THAN)) {
                inferenceType = Optional.of(LESS_THAN);
            }
            else if (operator2.equals(LESS_THAN_OR_EQUAL)) {
                inferenceType = Optional.of(LESS_THAN_OR_EQUAL);
            }
        }
        return inferenceType;
    }

    /**
     * Returns the variables referenced by the one operand that references any, provided the other
     * operand references none. Returns an empty set when both operands reference variables or when
     * neither does, which callers treat as "no valid inferred predicate".
     */
    private static Set<VariableReferenceExpression> getVariablesReferencedInInferredPredicate(RowExpression firstConjunct, RowExpression secondConjunct)
    {
        Set<VariableReferenceExpression> firstConjunctReferencedVariables = VariablesExtractor.extractUnique(firstConjunct);
        if (firstConjunctReferencedVariables.isEmpty()) {
            return VariablesExtractor.extractUnique(secondConjunct);
        }

        Set<VariableReferenceExpression> secondConjunctReferencedVariables = VariablesExtractor.extractUnique(secondConjunct);
        if (secondConjunctReferencedVariables.isEmpty()) {
            return firstConjunctReferencedVariables;
        }

        // inferred predicates cannot reference variables from both expressions
        // return empty to indicate that there is no valid inferred predicate
        return ImmutableSet.of();
    }

    /**
     * Collects inequality conjuncts suitable for inference and builds an {@link InequalityInference}.
     */
    public static class Builder
    {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final NullabilityAnalyzer nullabilityAnalyzer;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;
        private final ExpressionEquivalence expressionEquivalence;
        private final Set<RowExpression> inequalityExpressions = new HashSet<>();
        private final Optional<Collection<VariableReferenceExpression>> outerVariables;

        /**
         * @throws NullPointerException if any argument is null
         */
        public Builder(FunctionAndTypeManager functionAndTypeManager, ExpressionEquivalence expressionEquivalence, Optional<Collection<VariableReferenceExpression>> outerVariables)
        {
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
            this.expressionEquivalence = requireNonNull(expressionEquivalence, "expressionEquivalence is null");
            this.nullabilityAnalyzer = new NullabilityAnalyzer(functionAndTypeManager);
            this.outerVariables = requireNonNull(outerVariables, "outerVariables is null");
        }

        /**
         * Creates an {@link InequalityInference} over the candidates collected so far.
         */
        public InequalityInference build()
        {
            return new InequalityInference(inequalityExpressions, functionAndTypeManager, expressionEquivalence, outerVariables);
        }

        /**
         * Extracts the conjuncts of each expression and records those that are inequality
         * inference candidates; all other conjuncts are ignored.
         *
         * @return this builder
         */
        public Builder addInequalityInferences(RowExpression... expressions)
        {
            for (RowExpression expression : expressions) {
                extractInequalityInferenceCandidates(expression);
            }
            return this;
        }

        // Records every conjunct of 'expression' that passes isInequalityInferenceCandidate()
        private Builder extractInequalityInferenceCandidates(RowExpression expression)
        {
            Iterable<RowExpression> candidates = filter(extractConjuncts(expression), isInequalityInferenceCandidate());
            for (RowExpression conjunct : candidates) {
                addInequalityInferenceCandidate(conjunct);
            }
            return this;
        }

        // Stores a single candidate in canonical (< or <=) form; rejects non-candidates
        private InequalityInference.Builder addInequalityInferenceCandidate(RowExpression expression)
        {
            checkArgument(isInequalityInferenceCandidate().apply(expression), "RowExpression: " + expression + " is not an inequality inference candidate");
            inequalityExpressions.add(canonicalizeInequality(expression));
            return this;
        }

        /**
         * Convert all inequalities to LESS_THAN or LESS_THAN_OR_EQUAL.
         * A {@code >} or {@code >=} call is rewritten with the flipped operator and swapped arguments;
         * any other expression is returned unchanged.
         */
        RowExpression canonicalizeInequality(RowExpression expression)
        {
            if (isOperation(expression, GREATER_THAN, functionAndTypeManager) ||
                    isOperation(expression, GREATER_THAN_OR_EQUAL, functionAndTypeManager)) {
                CallExpression callExpression = (CallExpression) expression;
                OperatorType operatorType = functionAndTypeManager.getFunctionMetadata(callExpression.getFunctionHandle()).getOperatorType().get();
                operatorType = OperatorType.flip(operatorType);
                // Argument types are swapped together with the arguments so the operator resolves for the new order
                FunctionHandle functionHandle = functionAndTypeManager.resolveOperator(operatorType, swapPair(fromTypes(callExpression.getArguments().stream().map(RowExpression::getType).collect(toImmutableList()))));
                expression = new CallExpression(operatorType.getOperator(), functionHandle, BOOLEAN, swapPair(callExpression.getArguments()));
            }
            return expression;
        }

        // A candidate is a <, <=, > or >= operation that is deterministic, does not return null
        // on non-null input, and whose two operands are not equal to each other
        private Predicate<RowExpression> isInequalityInferenceCandidate()
        {
            return expression -> (isOperation(expression, GREATER_THAN_OR_EQUAL, functionAndTypeManager) ||
                    isOperation(expression, GREATER_THAN, functionAndTypeManager) ||
                    isOperation(expression, LESS_THAN_OR_EQUAL, functionAndTypeManager) ||
                    isOperation(expression, LESS_THAN, functionAndTypeManager)) &&
                    determinismEvaluator.isDeterministic(expression) &&
                    !nullabilityAnalyzer.mayReturnNullOnNonNullInput(expression) &&
                    !getLeft(expression).equals(getRight(expression));
        }
    }
}