// ExtractCommonPredicatesExpressionRewriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.sql.ExpressionUtils.combinePredicates;
import static com.facebook.presto.sql.ExpressionUtils.extractPredicates;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Operator.OR;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
 * Expression rewriter that pulls predicates shared by every operand of an AND/OR expression
 * out of that expression, and that prefers an AND at the root of the rewritten expression
 * when the required distribution does not grow the expression too much.
 * <p>
 * Only deterministic predicates are ever extracted or distributed.
 * <p>
 * This is a static utility; it cannot be instantiated.
 */
public class ExtractCommonPredicatesExpressionRewriter
{
    /**
     * Rewrites the given expression tree, extracting common predicates from every logical
     * binary (AND/OR) expression found in it.
     *
     * @param expression the expression to rewrite; it is treated as the root of the tree
     * @return the rewritten expression
     */
    public static Expression extractCommonPredicates(Expression expression)
    {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression, NodeContext.ROOT_NODE);
    }

    private ExtractCommonPredicatesExpressionRewriter() {}

    /**
     * Rewriter callbacks. The {@link NodeContext} records whether the node currently being
     * visited is the root of the expression handed to
     * {@link ExtractCommonPredicatesExpressionRewriter#extractCommonPredicates(Expression)}.
     */
    private static class Visitor
            extends ExpressionRewriter<NodeContext>
    {
        /**
         * Handles every node type without a more specific callback. A root node is re-submitted
         * with {@link NodeContext#NOT_ROOT_NODE} so that nothing below it is treated as the root.
         */
        @Override
        public Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
        {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }

            // NOTE(review): null presumably asks ExpressionTreeRewriter to apply its default
            // rewrite (recursing into the children) - confirm against ExpressionTreeRewriter.
            return null;
        }

        /**
         * Rewrites the operands of an AND/OR expression first, then extracts the predicates they
         * have in common. When the node is the root and the result is still an OR, the OR is
         * distributed into an AND if {@link #distributeIfPossible} allows it.
         */
        @Override
        public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
        {
            // Operands are never the root, whatever the context of this node is
            Expression expression = combinePredicates(
                    node.getOperator(),
                    extractPredicates(node.getOperator(), node).stream()
                            .map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE))
                            .collect(toImmutableList()));

            // Recombining the rewritten operands may yield something that is no longer an AND/OR
            if (!(expression instanceof LogicalBinaryExpression)) {
                return expression;
            }

            Expression simplified = extractCommonPredicates((LogicalBinaryExpression) expression);

            // Prefer AND LogicalBinaryExpression at the root if possible
            if (context.isRootNode() && simplified instanceof LogicalBinaryExpression && ((LogicalBinaryExpression) simplified).getOperator() == OR) {
                return distributeIfPossible((LogicalBinaryExpression) simplified);
            }

            return simplified;
        }

        /**
         * Factors the deterministic predicates that occur in every operand of {@code node} out of
         * those operands. The result combines, with the flipped operator, the common predicates
         * and the recombined remainders of the operands.
         * <p>
         * Intended effect, assuming {@code extractPredicates} flattens nested expressions of one
         * operator and {@code combinePredicates} joins them back (TODO confirm against
         * ExpressionUtils):
         * ( A & B ) | ( A & C ) => A & ( B | C )
         * <p>
         * The common predicates are emitted in the order in which they appear in the first
         * operand, so the shape of the result does not depend on hash iteration order.
         */
        private static Expression extractCommonPredicates(LogicalBinaryExpression node)
        {
            List<List<Expression>> subPredicates = getSubPredicates(node);

            // Membership only: the iteration order of this set follows hash order and must not
            // leak into the rewritten expression
            Set<Expression> sharedByAllOperands = ImmutableSet.copyOf(subPredicates.stream()
                    .map(Visitor::filterDeterministicPredicates)
                    .reduce(Sets::intersection)
                    .orElse(emptySet()));

            // Every shared predicate occurs in the first operand, so walking that operand visits
            // all of them, in the order in which they were written
            Set<Expression> commonPredicates = subPredicates.stream()
                    .limit(1)
                    .flatMap(List::stream)
                    .filter(sharedByAllOperands::contains)
                    .collect(ImmutableSet.toImmutableSet());

            List<List<Expression>> uncorrelatedSubPredicates = subPredicates.stream()
                    .map(predicateList -> removeAll(predicateList, commonPredicates))
                    .collect(toImmutableList());

            LogicalBinaryExpression.Operator flippedOperator = node.getOperator().flip();

            // Rebuild each operand from what is left of it, then rebuild the original expression
            List<Expression> uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
                    .map(predicate -> combinePredicates(flippedOperator, predicate))
                    .collect(toImmutableList());
            Expression combinedUncorrelatedPredicates = combinePredicates(node.getOperator(), uncorrelatedPredicates);

            return combinePredicates(flippedOperator, ImmutableList.<Expression>builder()
                    .addAll(commonPredicates)
                    .add(combinedUncorrelatedPredicates)
                    .build());
        }

        /**
         * Splits {@code expression} into its operands and each operand into its own predicates.
         * An operand that is not a logical binary expression becomes a single-element list.
         * <p>
         * NOTE(review): the operands that are logical binary expressions presumably carry the
         * flipped operator, since the outer call extracts by the operator of {@code expression}
         * - confirm against ExpressionUtils.extractPredicates.
         */
        private static List<List<Expression>> getSubPredicates(LogicalBinaryExpression expression)
        {
            return extractPredicates(expression.getOperator(), expression).stream()
                    .map(predicate -> predicate instanceof LogicalBinaryExpression ?
                            extractPredicates((LogicalBinaryExpression) predicate) : ImmutableList.of(predicate))
                    .collect(toImmutableList());
        }

        /**
         * Applies the boolean distributive property.
         * <p>
         * For example:
         * ( A & B ) | ( C & D ) => ( A | C ) & ( A | D ) & ( B | C ) & ( B | D)
         * <p>
         * Returns the original expression if the expression is non-deterministic or if the distribution will
         * expand the expression by too much.
         */
        private static Expression distributeIfPossible(LogicalBinaryExpression expression)
        {
            if (!ExpressionDeterminismEvaluator.isDeterministic(expression)) {
                // Do not distribute boolean expressions if there are any non-deterministic elements
                // TODO: This can be optimized further if non-deterministic elements are not repeated
                return expression;
            }
            // Sets drop predicates repeated within one operand; ImmutableSet keeps the first-seen order
            List<Set<Expression>> subPredicates = getSubPredicates(expression).stream()
                    .map(ImmutableSet::copyOf)
                    .collect(toList());

            int originalBaseExpressions = subPredicates.stream()
                    .mapToInt(Set::size)
                    .sum();

            // Size of the cross product times the number of base expressions in each of its elements
            int newBaseExpressions;
            try {
                newBaseExpressions = Math.multiplyExact(subPredicates.stream()
                        .mapToInt(Set::size)
                        .reduce(Math::multiplyExact)
                        .getAsInt(), subPredicates.size());
            }
            catch (ArithmeticException e) {
                // Integer overflow from multiplication means there are too many expressions
                return expression;
            }

            // Compare as long so that doubling the original count cannot wrap around
            if (newBaseExpressions > 2L * originalBaseExpressions) {
                // Do not distribute boolean expressions if it would create 2x more base expressions
                // (e.g. A, B, C, D from the above example). This is just an arbitrary heuristic to
                // avoid cross product expression explosion.
                return expression;
            }

            Set<List<Expression>> crossProduct = Sets.cartesianProduct(subPredicates);

            return combinePredicates(
                    expression.getOperator().flip(),
                    crossProduct.stream()
                            .map(expressions -> combinePredicates(expression.getOperator(), expressions))
                            .collect(toImmutableList()));
        }

        /**
         * Returns the deterministic members of {@code predicates} as a set. The set is meant for
         * membership tests and intersection only: its iteration order is unspecified.
         */
        private static Set<Expression> filterDeterministicPredicates(List<Expression> predicates)
        {
            return predicates.stream()
                    .filter(ExpressionDeterminismEvaluator::isDeterministic)
                    .collect(toSet());
        }

        /**
         * Returns the elements of {@code collection}, in iteration order, that are not contained
         * in {@code elementsToRemove}. Neither argument is modified.
         */
        private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove)
        {
            return collection.stream()
                    .filter(element -> !elementsToRemove.contains(element))
                    .collect(toImmutableList());
        }
    }

    /**
     * Rewrite context: tells a callback whether the node it is visiting is the root of the
     * expression being rewritten.
     */
    private enum NodeContext
    {
        ROOT_NODE,
        NOT_ROOT_NODE;

        boolean isRootNode()
        {
            return this == ROOT_NODE;
        }
    }
}