// RewriteCaseExpressionPredicate.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.isOptimizeCaseExpressionPredicate;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.expressions.LogicalRowExpressions.replaceArguments;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * This rule rewrites a CASE expression predicate into a series of AND/OR clauses.
 * The following CASE expression
 * <p>
 * (CASE
 * WHEN expression=constant1 THEN result1
 * WHEN expression=constant2 THEN result2
 * WHEN expression=constant3 THEN result3
 * ELSE elseResult
 * END) = value
 * <p>
 * can be converted into a series of AND/OR clauses as below
 * <p>
 * (result1 = value AND expression=constant1) OR
 * (result2 = value AND expression=constant2 AND !(expression=constant1)) OR
 * (result3 = value AND expression=constant3 AND !(expression=constant1) AND !(expression=constant2)) OR
 * (elseResult = value AND !(expression=constant1) AND !(expression=constant2) AND !(expression=constant3))
 * <p>
 * The above conversion evaluates the conditions in WHEN clauses multiple times. But if we ensure these conditions are
 * disjoint, we can skip all the NOT of previous WHEN conditions and simplify the expression to:
 * <p>
 * (result1 = value AND expression=constant1) OR
 * (result2 = value AND expression=constant2) OR
 * (result3 = value AND expression=constant3) OR
 * (elseResult = value AND (expression IS NULL OR
 *                          (!(expression=constant1) AND !(expression=constant2) AND !(expression=constant3))))
 * <p>
 * The "expression IS NULL" alternative is required because a NULL expression makes every WHEN comparison NULL:
 * the CASE then evaluates its ELSE branch, whereas the conjunction of negated comparisons alone would be NULL
 * and the ELSE branch would be lost.
 * <p>
 * To ensure the WHEN conditions are disjoint, the following criteria need to be met:
 * 1. Value is either a constant or column reference or input reference and not any function
 * 2. The LHS expression in all WHEN clauses is the same.
 *    For example, if one WHEN clause has an expression using col1 and another using col2, it will not work
 * 3. The relational operator in the WHEN clause is equals. With other operators it is hard to check for exclusivity.
 * 4. All the RHS expressions are non-null constants and unique
 * <p>
 * This conversion is done so that it is easy for the ExpressionInterpreter and other optimizers to further
 * simplify this and construct a domain for the column that can be used by readers,
 * i.e., ExpressionInterpreter can discard all conditions in which result != value and
 *       RowExpressionDomainTranslator can construct a Domain for the column
 */
public class RewriteCaseExpressionPredicate
        extends RowExpressionRewriteRuleSet
{
    public RewriteCaseExpressionPredicate(FunctionAndTypeManager functionAndTypeManager)
    {
        super(new Rewriter(functionAndTypeManager));
    }

    /**
     * Adapter that runs {@link CaseExpressionPredicateRewriter} over a whole expression tree.
     */
    private static class Rewriter
            implements PlanRowExpressionRewriter
    {
        private final CaseExpressionPredicateRewriter caseExpressionPredicateRewriter;

        public Rewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.caseExpressionPredicateRewriter = new CaseExpressionPredicateRewriter(functionAndTypeManager);
        }

        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context)
        {
            return RowExpressionTreeRewriter.rewriteWith(caseExpressionPredicateRewriter, expression);
        }
    }

    /**
     * Rewrites binary comparisons in which one side is a CASE expression (optionally wrapped in a CAST) and the
     * other side is a constant, variable or input reference (optionally wrapped in CASTs).
     */
    private static class CaseExpressionPredicateRewriter
            extends RowExpressionRewriter<Void>
    {
        private final FunctionResolution functionResolution;
        private final LogicalRowExpressions logicalRowExpressions;

        private CaseExpressionPredicateRewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            this.logicalRowExpressions = new LogicalRowExpressions(
                    new RowExpressionDeterminismEvaluator(functionAndTypeManager),
                    functionResolution,
                    functionAndTypeManager);
        }

        /**
         * Rewrites {@code node} when it is a two-argument comparison between a CASE expression and a simple value.
         *
         * @return the rewritten predicate, or null when the node is not rewritten (null is returned to the tree
         * rewriter as-is; presumably it then falls back to its default handling of the node)
         */
        @Override
        public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
        {
            if (!functionResolution.isComparisonFunction(node.getFunctionHandle()) || node.getArguments().size() != 2) {
                return null;
            }

            RowExpression left = node.getArguments().get(0);
            RowExpression right = node.getArguments().get(1);
            if (isCaseExpression(left) && isSimpleExpression(right)) {
                // CASE on the left: each branch result replaces the left argument of the comparison
                return processCaseExpression(left, expression -> replaceArguments(node, expression, right), right);
            }
            if (isCaseExpression(right) && isSimpleExpression(left)) {
                // CASE on the right: each branch result replaces the right argument of the comparison
                return processCaseExpression(right, expression -> replaceArguments(node, left, expression), left);
            }
            return null;
        }

        /**
         * Returns true for a SWITCH special form, looking through at most one enclosing CAST.
         */
        private boolean isCaseExpression(RowExpression expression)
        {
            if (logicalRowExpressions.isCastExpression(expression)) {
                expression = ((CallExpression) expression).getArguments().get(0);
            }
            return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(SWITCH);
        }

        /**
         * Returns true for a constant, variable reference or input reference, looking through any number of CASTs.
         */
        private boolean isSimpleExpression(RowExpression expression)
        {
            if (logicalRowExpressions.isCastExpression(expression)) {
                return isSimpleExpression(((CallExpression) expression).getArguments().get(0));
            }
            return expression instanceof ConstantExpression ||
                    expression instanceof VariableReferenceExpression ||
                    expression instanceof InputReferenceExpression;
        }

        /**
         * Unwraps an optional CAST around the CASE expression and delegates to the main rewrite.
         *
         * @param expression a CASE expression, or a CAST whose argument is a CASE expression
         * @param comparisonExpressionGenerator builds the original comparison with the CASE side replaced by the given expression
         * @param value the non-CASE side of the comparison
         * @return the rewritten predicate, or null when the CASE expression cannot be rewritten
         * @throws IllegalArgumentException if {@code expression} has neither of the two supported shapes
         */
        private RowExpression processCaseExpression(RowExpression expression,
                                                    Function<RowExpression, RowExpression> comparisonExpressionGenerator,
                                                    RowExpression value)
        {
            if (expression instanceof SpecialFormExpression) {
                checkArgument(logicalRowExpressions.isCaseExpression(expression), "expression must be a CASE expression");
                return processCaseExpression(
                        (SpecialFormExpression) expression,
                        Optional.empty(),
                        comparisonExpressionGenerator,
                        value);
            }

            checkArgument(logicalRowExpressions.isCastExpression(expression), "expression must be a CAST expression");
            CallExpression cast = (CallExpression) expression;
            RowExpression castArgument = cast.getArguments().get(0);
            checkArgument(logicalRowExpressions.isCaseExpression(castArgument), "expression argument must be a CASE expression");
            return processCaseExpression(
                    (SpecialFormExpression) castArgument,
                    Optional.of(cast),
                    comparisonExpressionGenerator,
                    value);
        }

        /**
         * RowExpression representation of Case Statement:
         * SpecialFormExpression:
         *  form: SWITCH
         *  arguments:
         *      [0]: RowExpression (or) ConstantExpression(TRUE)    // SimpleCaseExpression (or) SearchedCaseExpression
         *      [1..n-1 (or) n]: SpecialFormExpression(form: WHEN)  // else clause is present (or) absent
         *      [n]: RowExpression                                  // available if else clause is present
         * <p>
         * Produces one "(branch comparison AND when condition)" term per WHEN clause plus one term for the ELSE
         * branch, all combined with OR.
         *
         * @param castExpression the CAST that wrapped the CASE expression, re-applied to every branch result
         * @return the rewritten predicate, or null when the WHEN conditions are not provably disjoint
         */
        private RowExpression processCaseExpression(SpecialFormExpression caseExpression,
                                                    Optional<CallExpression> castExpression,
                                                    Function<RowExpression, RowExpression> comparisonExpressionGenerator,
                                                    RowExpression value)
        {
            List<RowExpression> arguments = caseExpression.getArguments();
            Optional<RowExpression> caseOperand = getCaseOperand(arguments.get(0));
            int argumentsSize = arguments.size();
            RowExpression last = arguments.get(argumentsSize - 1);

            List<RowExpression> whenClauses;
            Optional<RowExpression> elseResult;
            if (isWhenClause(last)) {
                // no ELSE clause: every argument after the operand is a WHEN clause
                whenClauses = arguments.subList(1, argumentsSize);
                elseResult = Optional.empty();
            }
            else {
                whenClauses = arguments.subList(1, argumentsSize - 1);
                elseResult = Optional.of(last);
            }

            boolean canRewrite = caseOperand.isPresent() ?
                    canRewriteSimpleCaseExpression(whenClauses) :
                    canRewriteSearchedCaseExpression(whenClauses);
            if (!canRewrite) {
                return null;
            }

            ImmutableList.Builder<RowExpression> andExpressions = new ImmutableList.Builder<>();
            ImmutableList.Builder<RowExpression> invertedOperands = new ImmutableList.Builder<>();

            for (RowExpression whenClause : whenClauses) {
                RowExpression whenOperand = ((SpecialFormExpression) whenClause).getArguments().get(0);
                if (caseOperand.isPresent()) {
                    // simple CASE: "WHEN constant" stands for "operand = constant"
                    whenOperand = logicalRowExpressions.equalsCallExpression(caseOperand.get(), whenOperand);
                }
                RowExpression whenResult = ((SpecialFormExpression) whenClause).getArguments().get(1);
                if (castExpression.isPresent()) {
                    whenResult = replaceArguments(castExpression.get(), whenResult);
                }

                RowExpression comparisonExpression = comparisonExpressionGenerator.apply(whenResult);
                andExpressions.add(and(comparisonExpression, whenOperand));
                invertedOperands.add(logicalRowExpressions.notCallExpression(whenOperand));
            }

            // The expression every WHEN condition compares against a constant. For a searched CASE the
            // canRewriteSearchedCaseExpression check above guarantees at least one WHEN clause of the form "lhs = constant".
            RowExpression comparedExpression = caseOperand.isPresent() ?
                    caseOperand.get() :
                    getComparedExpression(whenClauses.get(0));

            // ELSE is taken when no WHEN condition is TRUE. If the compared expression is NULL, every WHEN condition
            // (and its negation) is NULL rather than FALSE, so that case has to be accepted explicitly.
            RowExpression comparedExpressionIsNull = new SpecialFormExpression(IS_NULL, BOOLEAN, comparedExpression);
            RowExpression allWhenConditionsFalse = and(invertedOperands.build());
            RowExpression noWhenClauseMatches = or(ImmutableList.of(comparedExpressionIsNull, allWhenConditionsFalse));

            andExpressions.add(and(
                    getElseExpression(castExpression, value, elseResult, comparisonExpressionGenerator),
                    noWhenClauseMatches));

            return or(andExpressions.build());
        }

        /**
         * Builds the comparison for the ELSE branch. With an explicit ELSE the comparison generator is applied to
         * the ELSE result (wrapped in the original CAST when there is one); without it, {@code value IS NULL} is used.
         * <p>
         * NOTE(review): a CASE without ELSE yields NULL, and ordinary comparisons against NULL are never TRUE, so
         * {@code value IS NULL} looks more permissive than the original predicate when {@code value} is NULL.
         * Left unchanged here; confirm the intended semantics.
         */
        private RowExpression getElseExpression(Optional<CallExpression> castExpression,
                                                RowExpression value,
                                                Optional<RowExpression> elseValue,
                                                Function<RowExpression, RowExpression> comparisonExpressionGenerator)
        {
            if (!elseValue.isPresent()) {
                return new SpecialFormExpression(IS_NULL, BOOLEAN, value);
            }
            RowExpression elseResult = elseValue.get();
            if (castExpression.isPresent()) {
                elseResult = replaceArguments(castExpression.get(), elseResult);
            }
            return comparisonExpressionGenerator.apply(elseResult);
        }

        /**
         * Returns the operand of a simple CASE, or empty for a searched CASE, which is recognised by a
         * boolean TRUE constant in the operand position.
         */
        private Optional<RowExpression> getCaseOperand(RowExpression expression)
        {
            // equals() rather than == so that the check does not depend on the identity of a boxed Boolean
            boolean searchedCase = expression instanceof ConstantExpression &&
                    BOOLEAN.equals(expression.getType()) &&
                    Boolean.TRUE.equals(((ConstantExpression) expression).getValue());
            return searchedCase ? Optional.empty() : Optional.of(expression);
        }

        /**
         * A simple CASE can be rewritten when all WHEN operands are distinct non-null constants.
         */
        private boolean canRewriteSimpleCaseExpression(List<RowExpression> whenClauses)
        {
            List<RowExpression> whenOperands = whenClauses.stream()
                    .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0))
                    .collect(Collectors.toList());
            return allExpressionsAreConstantAndUnique(whenOperands);
        }

        /**
         * A searched CASE can be rewritten when every WHEN condition is an equality, all of them share the same
         * left-hand side, and the right-hand sides are distinct non-null constants.
         */
        private boolean canRewriteSearchedCaseExpression(List<RowExpression> whenClauses)
        {
            if (!allAreEqualsExpression(whenClauses) || !allLHSOperandsAreSame(whenClauses)) {
                return false;
            }
            List<RowExpression> rhsExpressions = whenClauses.stream()
                    .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0))
                    .map(whenOperand -> ((CallExpression) whenOperand).getArguments().get(1))
                    .collect(Collectors.toList());
            return allExpressionsAreConstantAndUnique(rhsExpressions);
        }

        /**
         * Returns true when there is at least one WHEN clause and the first argument of every WHEN condition
         * is the same expression. Callers must have verified that every WHEN condition is a call.
         */
        private boolean allLHSOperandsAreSame(List<RowExpression> whenClauses)
        {
            return whenClauses.stream()
                    .map(CaseExpressionPredicateRewriter::getComparedExpression)
                    .distinct()
                    .count() == 1;
        }

        private boolean allAreEqualsExpression(List<RowExpression> whenClauses)
        {
            return whenClauses.stream()
                    .map(whenClause -> ((SpecialFormExpression) whenClause).getArguments().get(0))
                    .allMatch(logicalRowExpressions::isEqualsExpression);
        }

        /**
         * Returns true when every expression is a non-null constant (possibly wrapped in CASTs) and no two
         * expressions are equal.
         */
        private boolean allExpressionsAreConstantAndUnique(List<RowExpression> expressions)
        {
            Set<RowExpression> seen = new HashSet<>();
            for (RowExpression expression : expressions) {
                // Set.add returns false for a duplicate
                if (!isNonNullConstant(expression) || !seen.add(expression)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns true for a constant whose value is not null, looking through any number of CASTs.
         * NULL constants are rejected because "expression = NULL" is never TRUE or FALSE, so its negation could
         * not be used to express "this WHEN clause did not match".
         */
        private boolean isNonNullConstant(RowExpression expression)
        {
            if (logicalRowExpressions.isCastExpression(expression)) {
                return isNonNullConstant(((CallExpression) expression).getArguments().get(0));
            }
            return expression instanceof ConstantExpression && ((ConstantExpression) expression).getValue() != null;
        }

        private static boolean isWhenClause(RowExpression expression)
        {
            return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(WHEN);
        }

        /**
         * Returns the first argument of the call that forms the condition of the given WHEN clause.
         */
        private static RowExpression getComparedExpression(RowExpression whenClause)
        {
            RowExpression whenOperand = ((SpecialFormExpression) whenClause).getArguments().get(0);
            return ((CallExpression) whenOperand).getArguments().get(0);
        }
    }

    @Override
    public boolean isRewriterEnabled(Session session)
    {
        return isOptimizeCaseExpressionPredicate(session);
    }

    /**
     * The rewrite is applied to filter and join predicates only.
     */
    @Override
    public Set<Rule<?>> rules()
    {
        return ImmutableSet.of(
                filterRowExpressionRewriteRule(),
                joinRowExpressionRewriteRule());
    }
}