EqualityInference.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.expressions.RowExpressionNodeInliner;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.util.DisjointSet;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.SetMultimap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.uniqueSubExpressions;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.equalTo;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.filter;
import static java.util.Objects.requireNonNull;

/**
 * Tracks equivalence classes of deterministic {@link RowExpression}s collected from equality conjuncts,
 * and uses them to rewrite expressions in terms of a restricted set of variables.
 * <p>
 * Instances are created through {@link Builder}. Each equivalence class is indexed by its canonical
 * member, which is the minimum element of the class under {@code CANONICAL_ORDERING}.
 */
public class EqualityInference
{
    // Ordering used to determine RowExpression preference when determining canonicals
    private static final Ordering<RowExpression> CANONICAL_ORDERING = Ordering.from((expression1, expression2) -> {
        // Current cost heuristic:
        // 1) Prefer fewer input symbols
        // 2) Prefer smaller expression trees
        // 3) Sort the expressions alphabetically - creates a stable consistent ordering (extremely useful for unit testing)
        // TODO: be more precise in determining the cost of a RowExpression
        return ComparisonChain.start()
                .compare(VariablesExtractor.extractAll(expression1).size(), VariablesExtractor.extractAll(expression2).size())
                .compare(uniqueSubExpressions(expression1).size(), uniqueSubExpressions(expression2).size())
                .compare(expression1.toString(), expression2.toString())
                .result();
    });

    private final SetMultimap<RowExpression, RowExpression> equalitySets; // Indexed by canonical RowExpression
    private final Map<RowExpression, RowExpression> canonicalMap; // Map each known RowExpression to canonical RowExpression
    // Expressions synthesized by Builder#generateMoreEquivalences; they are skipped when partitioning equalities
    private final Set<RowExpression> derivedExpressions;
    private final RowExpressionDeterminismEvaluator determinismEvaluator;
    private final FunctionAndTypeManager functionAndTypeManager;

    private EqualityInference(
            Iterable<Set<RowExpression>> equalityGroups,
            Set<RowExpression> derivedExpressions,
            RowExpressionDeterminismEvaluator determinismEvaluator,
            FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(equalityGroups, "equalityGroups is null");
        requireNonNull(derivedExpressions, "derivedExpressions is null");
        this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");

        // Index every non-empty group by its most preferred (canonical) member
        ImmutableSetMultimap.Builder<RowExpression, RowExpression> setBuilder = ImmutableSetMultimap.builder();
        for (Set<RowExpression> equalityGroup : equalityGroups) {
            if (!equalityGroup.isEmpty()) {
                setBuilder.putAll(CANONICAL_ORDERING.min(equalityGroup), equalityGroup);
            }
        }
        equalitySets = setBuilder.build();

        // Invert the multimap so each known expression points at its canonical.
        // ImmutableMap.Builder#build rejects duplicate keys, so this relies on the groups being disjoint.
        ImmutableMap.Builder<RowExpression, RowExpression> mapBuilder = ImmutableMap.builder();
        for (Map.Entry<RowExpression, RowExpression> entry : equalitySets.entries()) {
            mapBuilder.put(entry.getValue(), entry.getKey());
        }
        canonicalMap = mapBuilder.build();

        this.derivedExpressions = ImmutableSet.copyOf(derivedExpressions);
    }

    /**
     * Convenience factory: extracts the inference candidates from every given expression and builds the inference.
     */
    public static EqualityInference createEqualityInference(Metadata metadata, RowExpression... equalityInferences)
    {
        return new Builder(metadata)
                .addEqualityInference(equalityInferences)
                .build();
    }

    /**
     * Attempts to rewrite a RowExpression in terms of the variables allowed by the variable scope,
     * given the known equalities. Returns null if unsuccessful.
     *
     * @throws IllegalArgumentException if {@code expression} itself is not deterministic
     */
    public RowExpression rewriteExpression(RowExpression expression, Predicate<VariableReferenceExpression> variableScope)
    {
        checkArgument(determinismEvaluator.isDeterministic(expression), "Only deterministic expressions may be considered for rewrite");
        return rewriteExpression(expression, variableScope, true);
    }

    /**
     * Attempts to rewrite a RowExpression in terms of the variables allowed by the variable scope,
     * given the known equalities. Returns null if unsuccessful.
     * Unlike {@link #rewriteExpression(RowExpression, Predicate)}, the input is not required to be deterministic.
     */
    public RowExpression rewriteExpressionAllowNonDeterministic(RowExpression expression, Predicate<VariableReferenceExpression> variableScope)
    {
        return rewriteExpression(expression, variableScope, true);
    }

    /**
     * Rewrites {@code expression} by replacing each known sub-expression with its in-scope canonical.
     *
     * @param allowFullReplacement when false, the expression as a whole is never replaced; only its proper sub-expressions are considered
     * @return the rewritten expression, or null if the result still references variables outside {@code variableScope}
     */
    private RowExpression rewriteExpression(RowExpression expression, Predicate<VariableReferenceExpression> variableScope, boolean allowFullReplacement)
    {
        Iterable<RowExpression> subExpressions = uniqueSubExpressions(expression);
        if (!allowFullReplacement) {
            subExpressions = filter(subExpressions, not(equalTo(expression)));
        }

        ImmutableMap.Builder<RowExpression, RowExpression> expressionRemap = ImmutableMap.builder();
        for (RowExpression subExpression : subExpressions) {
            RowExpression canonical = getScopedCanonical(subExpression, variableScope);
            if (canonical != null) {
                expressionRemap.put(subExpression, canonical);
            }
        }

        // Perform a naive single-pass traversal to try to rewrite non-compliant portions of the tree. Prefers to replace
        // larger subtrees over smaller subtrees
        // TODO: this rewrite can probably be made more sophisticated
        RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(new RowExpressionNodeInliner(expressionRemap.build()), expression);
        if (!variableToExpressionPredicate(variableScope).apply(rewritten)) {
            // If the rewritten is still not compliant with the symbol scope, just give up
            return null;
        }
        return rewritten;
    }

    /**
     * Dumps the inference equalities as equality expressions that are partitioned by the variableScope.
     * All stored equalities are returned in a compact set and will be classified into three groups as determined by the symbol scope:
     * <ol>
     * <li>equalities that fit entirely within the symbol scope</li>
     * <li>equalities that fit entirely outside of the symbol scope</li>
     * <li>equalities that straddle the symbol scope</li>
     * </ol>
     * <pre>
     * Example:
     *   Stored Equalities:
     *     a = b = c
     *     d = e = f = g
     *
     *   Symbol Scope:
     *     a, b, d, e
     *
     *   Output EqualityPartition:
     *     Scope Equalities:
     *       a = b
     *       d = e
     *     Complement Scope Equalities
     *       f = g
     *     Scope Straddling Equalities
     *       a = c
     *       d = f
     * </pre>
     */
    public EqualityPartition generateEqualitiesPartitionedBy(Predicate<VariableReferenceExpression> variableScope)
    {
        ImmutableSet.Builder<RowExpression> scopeEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<RowExpression> scopeComplementEqualities = ImmutableSet.builder();
        ImmutableSet.Builder<RowExpression> scopeStraddlingEqualities = ImmutableSet.builder();

        for (Collection<RowExpression> equalitySet : equalitySets.asMap().values()) {
            Set<RowExpression> scopeExpressions = new LinkedHashSet<>();
            Set<RowExpression> scopeComplementExpressions = new LinkedHashSet<>();
            Set<RowExpression> scopeStraddlingExpressions = new LinkedHashSet<>();

            // Try to push each non-derived expression into one side of the scope
            for (RowExpression expression : filter(equalitySet, not(derivedExpressions::contains))) {
                RowExpression scopeRewritten = rewriteExpression(expression, variableScope, false);
                if (scopeRewritten != null) {
                    scopeExpressions.add(scopeRewritten);
                }
                RowExpression scopeComplementRewritten = rewriteExpression(expression, not(variableScope), false);
                if (scopeComplementRewritten != null) {
                    scopeComplementExpressions.add(scopeComplementRewritten);
                }
                if (scopeRewritten == null && scopeComplementRewritten == null) {
                    scopeStraddlingExpressions.add(expression);
                }
            }

            // Compile the equality expressions on each side of the scope
            RowExpression matchingCanonical = getCanonical(scopeExpressions);
            addEqualitiesAgainstCanonical(scopeEqualities, matchingCanonical, scopeExpressions);

            RowExpression complementCanonical = getCanonical(scopeComplementExpressions);
            addEqualitiesAgainstCanonical(scopeComplementEqualities, complementCanonical, scopeComplementExpressions);

            // Compile the scope straddling equality expressions; the per-side canonicals act as the connection points
            List<RowExpression> connectingExpressions = new ArrayList<>();
            connectingExpressions.add(matchingCanonical);
            connectingExpressions.add(complementCanonical);
            connectingExpressions.addAll(scopeStraddlingExpressions);
            connectingExpressions = ImmutableList.copyOf(filter(connectingExpressions, Predicates.notNull()));
            addEqualitiesAgainstCanonical(scopeStraddlingEqualities, getCanonical(connectingExpressions), connectingExpressions);
        }

        return new EqualityPartition(scopeEqualities.build(), scopeComplementEqualities.build(), scopeStraddlingEqualities.build());
    }

    /**
     * Adds {@code canonical = expression} to {@code equalities} for every expression other than {@code canonical} itself.
     * Does nothing when {@code canonical} is null, which is what {@link #getCanonical} returns for an empty input.
     */
    private void addEqualitiesAgainstCanonical(ImmutableSet.Builder<RowExpression> equalities, RowExpression canonical, Iterable<RowExpression> expressions)
    {
        if (canonical == null) {
            return;
        }
        for (RowExpression expression : filter(expressions, not(equalTo(canonical)))) {
            equalities.add(buildEqualsExpression(functionAndTypeManager, canonical, expression));
        }
    }

    /**
     * Returns the most preferable expression to be used as the canonical expression, or null if {@code expressions} is empty.
     */
    private static RowExpression getCanonical(Iterable<RowExpression> expressions)
    {
        if (Iterables.isEmpty(expressions)) {
            return null;
        }
        return CANONICAL_ORDERING.min(expressions);
    }

    /**
     * Returns a canonical expression that is fully contained by the variableScope and that is equivalent
     * to the specified expression. Returns null if unable to find a canonical.
     */
    @VisibleForTesting
    RowExpression getScopedCanonical(RowExpression expression, Predicate<VariableReferenceExpression> variableScope)
    {
        RowExpression canonicalIndex = canonicalMap.get(expression);
        if (canonicalIndex == null) {
            return null;
        }
        return getCanonical(filter(equalitySets.get(canonicalIndex), variableToExpressionPredicate(variableScope)));
    }

    /**
     * Lifts a variable predicate to an expression predicate that holds when every variable referenced by the expression satisfies it.
     */
    private static Predicate<RowExpression> variableToExpressionPredicate(Predicate<VariableReferenceExpression> variableScope)
    {
        return expression -> Iterables.all(VariablesExtractor.extractUnique(expression), variableScope);
    }

    /**
     * Result of {@link #generateEqualitiesPartitionedBy}: equalities inside the scope, outside the scope, and straddling it.
     */
    public static class EqualityPartition
    {
        private final List<RowExpression> scopeEqualities;
        private final List<RowExpression> scopeComplementEqualities;
        private final List<RowExpression> scopeStraddlingEqualities;

        public EqualityPartition(Iterable<RowExpression> scopeEqualities, Iterable<RowExpression> scopeComplementEqualities, Iterable<RowExpression> scopeStraddlingEqualities)
        {
            this.scopeEqualities = ImmutableList.copyOf(requireNonNull(scopeEqualities, "scopeEqualities is null"));
            this.scopeComplementEqualities = ImmutableList.copyOf(requireNonNull(scopeComplementEqualities, "scopeComplementEqualities is null"));
            this.scopeStraddlingEqualities = ImmutableList.copyOf(requireNonNull(scopeStraddlingEqualities, "scopeStraddlingEqualities is null"));
        }

        public List<RowExpression> getScopeEqualities()
        {
            return scopeEqualities;
        }

        public List<RowExpression> getScopeComplementEqualities()
        {
            return scopeComplementEqualities;
        }

        public List<RowExpression> getScopeStraddlingEqualities()
        {
            return scopeStraddlingEqualities;
        }
    }

    /**
     * Accumulates equalities between deterministic expressions and produces an {@link EqualityInference}.
     * Holds mutable state ({@link DisjointSet}, {@link LinkedHashSet}) without any synchronization.
     */
    public static class Builder
    {
        private final DisjointSet<RowExpression> equalities = new DisjointSet<>();
        private final Set<RowExpression> derivedExpressions = new LinkedHashSet<>();
        private final FunctionAndTypeManager functionAndTypeManager;
        private final NullabilityAnalyzer nullabilityAnalyzer;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;

        public Builder(FunctionAndTypeManager functionAndTypeManager)
        {
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
            this.nullabilityAnalyzer = new NullabilityAnalyzer(functionAndTypeManager);
        }

        public Builder(Metadata metadata)
        {
            this(metadata.getFunctionAndTypeManager());
        }

        /**
         * Returns a predicate that accepts expressions which may be successfully applied to the equality inference.
         * The predicate first rewrites a single-value IN predicate to an equality, then requires a deterministic EQUAL call
         * that may not return null on non-null input and whose two operands differ.
         */
        public Predicate<RowExpression> isInferenceCandidate()
        {
            return expression -> {
                expression = normalizeInPredicateToEquality(expression);
                if (isOperation(expression, EQUAL, functionAndTypeManager) &&
                        determinismEvaluator.isDeterministic(expression) &&
                        !nullabilityAnalyzer.mayReturnNullOnNonNullInput(expression)) {
                    // We should only consider equalities that have distinct left and right components
                    return !getLeft(expression).equals(getRight(expression));
                }
                return false;
            };
        }

        public static Predicate<RowExpression> isInferenceCandidate(Metadata metadata)
        {
            return new Builder(metadata).isInferenceCandidate();
        }

        /**
         * Rewrites a single-value IN predicate ({@code x IN (y)}) as the equality {@code x = y}.
         * Any other expression is returned unchanged.
         */
        private RowExpression normalizeInPredicateToEquality(RowExpression expression)
        {
            if (!isInPredicate(expression)) {
                return expression;
            }
            List<RowExpression> arguments = ((SpecialFormExpression) expression).getArguments();
            // Argument 0 is the probe value; the remaining arguments form the IN list
            int size = arguments.size() - 1;
            checkArgument(size >= 1, "InList cannot be empty");
            if (size == 1) {
                return buildEqualsExpression(functionAndTypeManager, arguments.get(0), arguments.get(1));
            }
            return expression;
        }

        /**
         * Provides a convenience Iterable of RowExpression conjuncts which have not been added to the inference
         */
        public Iterable<RowExpression> nonInferableConjuncts(RowExpression expression)
        {
            return filter(extractConjuncts(expression), not(isInferenceCandidate()));
        }

        public static Iterable<RowExpression> nonInferableConjuncts(Metadata metadata, RowExpression expression)
        {
            return new Builder(metadata).nonInferableConjuncts(expression);
        }

        /**
         * Adds every inference-candidate conjunct of each given expression; other conjuncts are ignored.
         */
        public Builder addEqualityInference(RowExpression... expressions)
        {
            for (RowExpression expression : expressions) {
                extractInferenceCandidates(expression);
            }
            return this;
        }

        /**
         * Adds every conjunct of {@code expression} that passes {@link #isInferenceCandidate()}.
         */
        public Builder extractInferenceCandidates(RowExpression expression)
        {
            return addAllEqualities(filter(extractConjuncts(expression), isInferenceCandidate()));
        }

        public EqualityInference.Builder addAllEqualities(Iterable<RowExpression> expressions)
        {
            for (RowExpression expression : expressions) {
                addEquality(expression);
            }
            return this;
        }

        /**
         * Adds a single equality expression.
         *
         * @throws IllegalArgumentException if the expression (after IN normalization) is not an inference candidate
         */
        public EqualityInference.Builder addEquality(RowExpression expression)
        {
            expression = normalizeInPredicateToEquality(expression);
            checkArgument(isInferenceCandidate().apply(expression), "RowExpression must be a simple equality: " + expression);
            addEquality(getLeft(expression), getRight(expression));
            return this;
        }

        /**
         * Records that two distinct deterministic expressions are equal.
         *
         * @throws IllegalArgumentException if the expressions are equal to each other or either is non-deterministic
         */
        public EqualityInference.Builder addEquality(RowExpression expression1, RowExpression expression2)
        {
            checkArgument(!expression1.equals(expression2), "Need to provide equality between different expressions");
            checkArgument(determinismEvaluator.isDeterministic(expression1), "RowExpression must be deterministic: " + expression1);
            checkArgument(determinismEvaluator.isDeterministic(expression2), "RowExpression must be deterministic: " + expression2);

            equalities.findAndUnion(expression1, expression2);
            return this;
        }

        /**
         * Performs one pass of generating more equivalences by rewriting sub-expressions in terms of known equivalences.
         */
        private void generateMoreEquivalences()
        {
            Collection<Set<RowExpression>> equivalentClasses = equalities.getEquivalentClasses();

            // Map every expression to the set of equivalent expressions
            ImmutableMap.Builder<RowExpression, Set<RowExpression>> mapBuilder = ImmutableMap.builder();
            for (Set<RowExpression> expressions : equivalentClasses) {
                expressions.forEach(expression -> mapBuilder.put(expression, expressions));
            }

            // For every non-derived expression, extract the sub-expressions and see if they can be rewritten as other expressions. If so,
            // use this new information to update the known equalities. The map is an immutable snapshot, so unions made below
            // do not affect this pass's iteration.
            Map<RowExpression, Set<RowExpression>> map = mapBuilder.build();
            for (RowExpression expression : map.keySet()) {
                if (!derivedExpressions.contains(expression)) {
                    for (RowExpression subExpression : filter(uniqueSubExpressions(expression), not(equalTo(expression)))) {
                        Set<RowExpression> equivalentSubExpressions = map.get(subExpression);
                        if (equivalentSubExpressions != null) {
                            for (RowExpression equivalentSubExpression : filter(equivalentSubExpressions, not(equalTo(subExpression)))) {
                                RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(new RowExpressionNodeInliner(ImmutableMap.of(subExpression, equivalentSubExpression)), expression);
                                equalities.findAndUnion(expression, rewritten);
                                derivedExpressions.add(rewritten);
                            }
                        }
                    }
                }
            }
        }

        /**
         * Runs one pass of equivalence generation over the accumulated equalities and snapshots the result.
         * Because the pass mutates this builder's state, calling {@code build()} again runs another pass.
         */
        public EqualityInference build()
        {
            generateMoreEquivalences();
            return new EqualityInference(equalities.getEquivalentClasses(), derivedExpressions, determinismEvaluator, functionAndTypeManager);
        }
    }

    /**
     * Returns the first argument of a binary call expression.
     *
     * @throws IllegalArgumentException if {@code expression} is not a two-argument {@link CallExpression}
     */
    protected static RowExpression getLeft(RowExpression expression)
    {
        return asBinaryCall(expression).getArguments().get(0);
    }

    /**
     * Returns the second argument of a binary call expression.
     *
     * @throws IllegalArgumentException if {@code expression} is not a two-argument {@link CallExpression}
     */
    protected static RowExpression getRight(RowExpression expression)
    {
        return asBinaryCall(expression).getArguments().get(1);
    }

    // Shared precondition for getLeft/getRight
    private static CallExpression asBinaryCall(RowExpression expression)
    {
        checkArgument(expression instanceof CallExpression && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression");
        return (CallExpression) expression;
    }

    /**
     * Returns true if {@code expression} is a call whose function metadata reports the operator {@code type}.
     */
    protected static boolean isOperation(RowExpression expression, OperatorType type, FunctionAndTypeManager functionAndTypeManager)
    {
        if (!(expression instanceof CallExpression)) {
            return false;
        }
        Optional<OperatorType> operatorType = functionAndTypeManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getOperatorType();
        return operatorType.isPresent() && operatorType.get() == type;
    }

    private static boolean isInPredicate(RowExpression expression)
    {
        return expression instanceof SpecialFormExpression
                && ((SpecialFormExpression) expression).getForm() == SpecialFormExpression.Form.IN;
    }

    private static CallExpression buildEqualsExpression(FunctionAndTypeManager functionAndTypeManager, RowExpression left, RowExpression right)
    {
        return binaryOperation(functionAndTypeManager, EQUAL, left, right);
    }

    /**
     * Builds a BOOLEAN-typed call to the operator {@code type}, resolved against the operand types.
     * The source location is taken from {@code left}.
     */
    private static CallExpression binaryOperation(FunctionAndTypeManager functionAndTypeManager, OperatorType type, RowExpression left, RowExpression right)
    {
        return call(
                left.getSourceLocation(),
                type.getFunctionName().getObjectName(),
                functionAndTypeManager.resolveOperator(type, fromTypes(left.getType(), right.getType())),
                BOOLEAN,
                left,
                right);
    }
}