LogicalRowExpressions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.expressions;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression.Form;
import com.facebook.presto.spi.relation.VariableReferenceExpression;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static java.lang.Math.min;
import static java.util.Arrays.asList;
import static java.util.Arrays.stream;
import static java.util.Collections.singletonList;
import static java.util.Collections.unmodifiableList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public final class LogicalRowExpressions
{
    public static final ConstantExpression TRUE_CONSTANT = new ConstantExpression(true, BOOLEAN);
    public static final ConstantExpression FALSE_CONSTANT = new ConstantExpression(false, BOOLEAN);
    // 10000 is very conservative estimation
    private static final int ELIMINATE_COMMON_SIZE_LIMIT = 10000;

    private final DeterminismEvaluator determinismEvaluator;
    private final StandardFunctionResolution functionResolution;
    private final FunctionMetadataManager functionMetadataManager;

    public LogicalRowExpressions(DeterminismEvaluator determinismEvaluator, StandardFunctionResolution functionResolution, FunctionMetadataManager functionMetadataManager)
    {
        this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null");
        this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
        this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
    }

    public static List<RowExpression> extractConjuncts(RowExpression expression)
    {
        return extractPredicates(AND, expression);
    }

    public static List<RowExpression> extractDisjuncts(RowExpression expression)
    {
        return extractPredicates(OR, expression);
    }

    public static List<RowExpression> extractPredicates(RowExpression expression)
    {
        if (expression instanceof SpecialFormExpression) {
            Form form = ((SpecialFormExpression) expression).getForm();
            if (form == AND || form == OR) {
                return extractPredicates(form, expression);
            }
        }
        return singletonList(expression);
    }

    public static List<RowExpression> extractPredicates(Form form, RowExpression expression)
    {
        if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == form) {
            SpecialFormExpression specialFormExpression = (SpecialFormExpression) expression;
            if (specialFormExpression.getArguments().size() == 2) {
                List<RowExpression> predicates = new ArrayList<>();
                predicates.addAll(extractPredicates(form, specialFormExpression.getArguments().get(0)));
                predicates.addAll(extractPredicates(form, specialFormExpression.getArguments().get(1)));
                return unmodifiableList(predicates);
            }
            if (specialFormExpression.getArguments().size() == 1 && form == IS_NULL) {
                return singletonList(expression);
            }
            throw new IllegalStateException("Unexpected operands:" + expression + " " + form);
        }

        return singletonList(expression);
    }

    public static RowExpression and(RowExpression... expressions)
    {
        return and(asList(expressions));
    }

    public static RowExpression and(Collection<? extends RowExpression> expressions)
    {
        return binaryExpression(AND, expressions);
    }

    public static RowExpression or(RowExpression... expressions)
    {
        return or(asList(expressions));
    }

    public static RowExpression or(Collection<? extends RowExpression> expressions)
    {
        return binaryExpression(OR, expressions);
    }

    public static RowExpression binaryExpression(Form form, Collection<? extends RowExpression> expressions)
    {
        requireNonNull(form, "operator is null");
        requireNonNull(expressions, "expressions is null");

        if (expressions.isEmpty()) {
            switch (form) {
                case AND:
                    return TRUE_CONSTANT;
                case OR:
                    return FALSE_CONSTANT;
                default:
                    throw new IllegalArgumentException("Unsupported binary expression operator");
            }
        }

        // Build balanced tree for efficient recursive processing that
        // preserves the evaluation order of the input expressions.
        //
        // The tree is built bottom up by combining pairs of elements into
        // binary AND expressions.
        //
        // Example:
        //
        // Initial state:
        //  a b c d e
        //
        // First iteration:
        //
        //  /\    /\   e
        // a  b  c  d
        //
        // Second iteration:
        //
        //    / \    e
        //  /\   /\
        // a  b c  d
        //
        //
        // Last iteration:
        //
        //      / \
        //    / \  e
        //  /\   /\
        // a  b c  d

        Queue<RowExpression> queue = new ArrayDeque<>(expressions);
        while (queue.size() > 1) {
            Queue<RowExpression> buffer = new ArrayDeque<>();

            // combine pairs of elements
            while (queue.size() >= 2) {
                List<RowExpression> arguments = asList(queue.remove(), queue.remove());
                buffer.add(new SpecialFormExpression(form, BOOLEAN, arguments));
            }

            // if there's and odd number of elements, just append the last one
            if (!queue.isEmpty()) {
                buffer.add(queue.remove());
            }

            // continue processing the pairs that were just built
            queue = buffer;
        }

        return queue.remove();
    }

    public RowExpression combinePredicates(Form form, RowExpression... expressions)
    {
        return combinePredicates(form, asList(expressions));
    }

    public RowExpression combinePredicates(Form form, Collection<RowExpression> expressions)
    {
        if (form == AND) {
            return combineConjuncts(expressions);
        }
        return combineDisjuncts(expressions);
    }

    public RowExpression combineConjuncts(RowExpression... expressions)
    {
        return combineConjuncts(asList(expressions));
    }

    public RowExpression combineConjuncts(Collection<RowExpression> expressions)
    {
        requireNonNull(expressions, "expressions is null");

        List<RowExpression> conjuncts = expressions.stream()
                .flatMap(e -> extractConjuncts(e).stream())
                .filter(e -> !e.equals(TRUE_CONSTANT))
                .collect(toList());

        conjuncts = removeDuplicates(conjuncts);

        if (conjuncts.contains(FALSE_CONSTANT)) {
            return FALSE_CONSTANT;
        }

        return and(conjuncts);
    }

    public RowExpression combineDisjuncts(RowExpression... expressions)
    {
        return combineDisjuncts(asList(expressions));
    }

    public RowExpression combineDisjuncts(Collection<RowExpression> expressions)
    {
        return combineDisjunctsWithDefault(expressions, FALSE_CONSTANT);
    }

    public RowExpression combineDisjunctsWithDefault(Collection<RowExpression> expressions, RowExpression emptyDefault)
    {
        requireNonNull(expressions, "expressions is null");

        List<RowExpression> disjuncts = expressions.stream()
                .flatMap(e -> extractDisjuncts(e).stream())
                .filter(e -> !e.equals(FALSE_CONSTANT))
                .collect(toList());

        disjuncts = removeDuplicates(disjuncts);

        if (disjuncts.contains(TRUE_CONSTANT)) {
            return TRUE_CONSTANT;
        }

        return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
    }

    /**
     * Given a logical expression, the goal is to push negation to the leaf nodes.
     * This only applies to propositional logic and comparison. this utility cannot be applied to high-order logic.
     * Examples of non-applicable cases could be f(a AND b) > 5
     *
     * An applicable example:
     *
     *        NOT
     *         |
     *      ___OR_                          AND
     *     /      \                      /      \
     *    NOT     OR        ==>        AND      AND
     *     |     /  \                 /  \     /   \
     *    AND   c   NOT              a    b   NOT  d
     *   /  \        |                         |
     *  a    b       d                         c
     */
    public RowExpression pushNegationToLeaves(RowExpression expression)
    {
        return expression.accept(new PushNegationVisitor(), null);
    }

    /**
     * Given a logical expression, the goal is to convert to conjunctive normal form (CNF).
     * This requires making a call to `pushNegationToLeaves`. There is no guarantee as to
     * the balance of the resulting expression tree.
     *
     * This only applies to propositional logic. this utility cannot be applied to high-order logic.
     * Examples of non-applicable cases could be f(a AND b) > 5
     *
     * NOTE: This may exponentially increase the number of RowExpressions in the expression.
     *
     * An applicable example:
     *
     *        NOT
     *         |
     *      ___OR_                          AND
     *     /      \                      /      \
     *    NOT     OR        ==>        OR      AND
     *     |     /  \                 /  \     /   \
     *    OR   c   NOT              a    b   NOT  d
     *   /  \        |                         |
     *  a    b       d                         c
     */
    public RowExpression convertToConjunctiveNormalForm(RowExpression expression)
    {
        return convertToNormalForm(expression, AND);
    }

    /**
     * Given a logical expression, the goal is to convert to disjunctive normal form (DNF).
     * The same limitations, format, and risks apply as for converting to conjunctive normal form (CNF).
     *
     * An applicable example:
     *
     *        NOT                                    OR
     *         |                                 /        \
     *      ___OR_                          AND            AND
     *     /      \                        /    \        /     \
     *    NOT     OR        ==>          a     AND      b     AND
     *     |     /  \                         /   \          /   \
     *    OR   c   NOT                       NOT  d         NOT  d
     *   /  \        |                         |              |
     *  a    b       d                         c              c
     */
    public RowExpression convertToDisjunctiveNormalForm(RowExpression expression)
    {
        return convertToNormalForm(expression, OR);
    }

    public RowExpression minimalNormalForm(RowExpression expression)
    {
        RowExpression conjunctiveNormalForm = convertToConjunctiveNormalForm(expression);
        RowExpression disjunctiveNormalForm = convertToDisjunctiveNormalForm(expression);
        return numOfClauses(conjunctiveNormalForm) > numOfClauses(disjunctiveNormalForm) ? disjunctiveNormalForm : conjunctiveNormalForm;
    }

    public RowExpression convertToNormalForm(RowExpression expression, Form clauseJoiner)
    {
        return pushNegationToLeaves(expression).accept(new ConvertNormalFormVisitor(), rootContext(clauseJoiner));
    }

    public RowExpression filterDeterministicConjuncts(RowExpression expression)
    {
        return filterConjuncts(expression, this.determinismEvaluator::isDeterministic);
    }

    public RowExpression filterNonDeterministicConjuncts(RowExpression expression)
    {
        return filterConjuncts(expression, predicate -> !this.determinismEvaluator.isDeterministic(predicate));
    }

    public RowExpression filterConjuncts(RowExpression expression, Predicate<RowExpression> predicate)
    {
        List<RowExpression> conjuncts = extractConjuncts(expression).stream()
                .filter(predicate)
                .collect(toList());

        return combineConjuncts(conjuncts);
    }

    /**
     * Removes duplicate deterministic expressions. Preserves the relative order
     * of the expressions in the list.
     */
    private List<RowExpression> removeDuplicates(List<RowExpression> expressions)
    {
        Set<RowExpression> seen = new HashSet<>();

        List<RowExpression> result = new ArrayList<>();
        for (RowExpression expression : expressions) {
            if (determinismEvaluator.isDeterministic(expression)) {
                if (!seen.contains(expression)) {
                    result.add(expression);
                    seen.add(expression);
                }
            }
            else {
                result.add(expression);
            }
        }

        return unmodifiableList(result);
    }

    private boolean isConjunctionOrDisjunction(RowExpression expression)
    {
        if (expression instanceof SpecialFormExpression) {
            Form form = ((SpecialFormExpression) expression).getForm();
            return form == AND || form == OR;
        }
        return false;
    }

    private final class PushNegationVisitor
            implements RowExpressionVisitor<RowExpression, Void>
    {
        @Override
        public RowExpression visitCall(CallExpression call, Void context)
        {
            if (!isNegationExpression(call)) {
                return call;
            }

            checkArgument(call.getArguments().size() == 1, "Not expression should have exactly one argument");
            RowExpression argument = call.getArguments().get(0);

            // eliminate two consecutive negations
            if (isNegationExpression(argument)) {
                return ((CallExpression) argument).getArguments().get(0).accept(new PushNegationVisitor(), null);
            }

            if (isComparisonExpression(argument)) {
                return negateComparison((CallExpression) argument);
            }

            if (!isConjunctionOrDisjunction(argument)) {
                return call;
            }

            // push negation through conjunction or disjunction
            SpecialFormExpression specialForm = ((SpecialFormExpression) argument);
            RowExpression left = specialForm.getArguments().get(0);
            RowExpression right = specialForm.getArguments().get(1);
            if (specialForm.getForm() == AND) {
                // !(a AND b) ==> !a OR !b
                return or(notCallExpression(left).accept(new PushNegationVisitor(), null), notCallExpression(right).accept(this, null));
            }
            // !(a OR b) ==> !a AND !b
            return and(notCallExpression(left).accept(new PushNegationVisitor(), null), notCallExpression(right).accept(this, null));
        }

        private RowExpression negateComparison(CallExpression expression)
        {
            OperatorType newOperator = negate(getOperator(expression).orElse(null));
            if (newOperator == null) {
                return new CallExpression(expression.getSourceLocation(), "NOT", functionResolution.notFunction(), BOOLEAN, singletonList(expression));
            }
            checkArgument(expression.getArguments().size() == 2, "Comparison expression must have exactly two arguments");
            RowExpression left = expression.getArguments().get(0).accept(this, null);
            RowExpression right = expression.getArguments().get(1).accept(this, null);
            return new CallExpression(
                    left.getSourceLocation(),
                    newOperator.getOperator(),
                    functionResolution.comparisonFunction(newOperator, left.getType(), right.getType()),
                    BOOLEAN,
                    asList(left, right));
        }

        @Override
        public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            if (!isConjunctionOrDisjunction(specialForm)) {
                return specialForm;
            }

            RowExpression left = specialForm.getArguments().get(0);
            RowExpression right = specialForm.getArguments().get(1);

            if (specialForm.getForm() == AND) {
                return and(left.accept(this, null), right.accept(this, null));
            }
            return or(left.accept(this, null), right.accept(this, null));
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
        {
            return reference;
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, Void context)
        {
            return literal;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return lambda;
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return reference;
        }
    }

    private static ConvertNormalFormVisitorContext rootContext(Form clauseJoiner)
    {
        return new ConvertNormalFormVisitorContext(clauseJoiner, 0);
    }

    private static class ConvertNormalFormVisitorContext
    {
        private final Form expectedClauseJoiner;
        private final int depth;

        public ConvertNormalFormVisitorContext(Form expectedClauseJoiner, int depth)
        {
            this.expectedClauseJoiner = expectedClauseJoiner;
            this.depth = depth;
        }

        public ConvertNormalFormVisitorContext childContext()
        {
            return new ConvertNormalFormVisitorContext(expectedClauseJoiner, depth + 1);
        }
    }

    private class ConvertNormalFormVisitor
            implements RowExpressionVisitor<RowExpression, ConvertNormalFormVisitorContext>
    {
        @Override
        public RowExpression visitSpecialForm(SpecialFormExpression specialFormExpression, ConvertNormalFormVisitorContext context)
        {
            if (!isConjunctionOrDisjunction(specialFormExpression)) {
                return specialFormExpression;
            }
            // Attempt to convert sub expression to expected normal form, deduplicate and fold constants.
            RowExpression rewritten = combinePredicates(
                    specialFormExpression.getForm(),
                    extractPredicates(specialFormExpression.getForm(), specialFormExpression).stream()
                            .map(subPredicate -> subPredicate.accept(this, context.childContext()))
                            .collect(toList()));

            if (!isConjunctionOrDisjunction(rewritten)) {
                return rewritten;
            }

            SpecialFormExpression rewrittenSpecialForm = (SpecialFormExpression) rewritten;
            Form expressionClauseJoiner = rewrittenSpecialForm.getForm();
            List<List<RowExpression>> groupedClauses = getGroupedClauses(rewrittenSpecialForm);

            if (groupedClauses.stream().mapToInt(List::size).sum() > ELIMINATE_COMMON_SIZE_LIMIT) {
                return rewritten;
            }
            groupedClauses = eliminateCommonPredicates(groupedClauses);

            // extractCommonPredicates can produce opposite expectedClauseJoiner
            List<List<RowExpression>> groupedClausesWithFlippedJoiner = extractCommonPredicates(expressionClauseJoiner, groupedClauses);
            if (groupedClausesWithFlippedJoiner != null) {
                groupedClauses = groupedClausesWithFlippedJoiner;
                expressionClauseJoiner = flip(expressionClauseJoiner);
            }

            int numClauses = groupedClauses.stream().mapToInt(List::size).sum();

            int numClausesProducedByDistributiveLaw = groupedClauses.size();
            for (List<RowExpression> group : groupedClauses) {
                numClausesProducedByDistributiveLaw *= group.size();
                // If distributive rule will produce too many sub expressions, return what we have instead.
                if (context.depth > 0 || numClausesProducedByDistributiveLaw > numClauses * 2) {
                    return combineGroupedClauses(expressionClauseJoiner, groupedClauses);
                }
            }
            // size unchanged means distributive law will not apply, we can save an unnecessary crossProduct call.
            // For example, distributive law cannot apply to (a || b || c).
            if (numClausesProducedByDistributiveLaw == numClauses) {
                return combineGroupedClauses(expressionClauseJoiner, groupedClauses);
            }

            // TODO if the non-deterministic operation only appears in the only sub-predicates that has size >1, we can still expand it.
            // For example: a && b && c && (d || e) can still be expanded if d or e is non-deterministic.
            boolean deterministic = groupedClauses.stream()
                    .flatMap(List::stream)
                    .allMatch(determinismEvaluator::isDeterministic);

            // Do not apply distributive law if there is non-deterministic element or we have already got expected expectedClauseJoiner.
            if (expressionClauseJoiner == context.expectedClauseJoiner || !deterministic) {
                return combineGroupedClauses(expressionClauseJoiner, groupedClauses);
            }

            // else, we apply distributive law and rewrite based on distributive property of Boolean algebra, for example
            // (l1 OR l2) AND (r1 OR r2) <=> (l1 AND r1) OR (l1 AND r2) OR (l2 AND r1) OR (l2 AND r2)
            groupedClauses = crossProduct(groupedClauses);
            return combineGroupedClauses(context.expectedClauseJoiner, groupedClauses);
        }

        @Override
        public RowExpression visitCall(CallExpression call, ConvertNormalFormVisitorContext context)
        {
            return call;
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, ConvertNormalFormVisitorContext context)
        {
            return reference;
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, ConvertNormalFormVisitorContext context)
        {
            return literal;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, ConvertNormalFormVisitorContext context)
        {
            return lambda;
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, ConvertNormalFormVisitorContext context)
        {
            return reference;
        }
    }

    private boolean isNegationExpression(RowExpression expression)
    {
        return expression instanceof CallExpression && ((CallExpression) expression).getFunctionHandle().equals(functionResolution.notFunction());
    }

    private boolean isComparisonExpression(RowExpression expression)
    {
        return expression instanceof CallExpression && functionResolution.isComparisonFunction(((CallExpression) expression).getFunctionHandle());
    }

    public boolean isEqualsExpression(RowExpression expression)
    {
        return expression instanceof CallExpression && functionResolution.isEqualsFunction(((CallExpression) expression).getFunctionHandle());
    }

    public boolean isCastExpression(RowExpression expression)
    {
        return expression instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) expression).getFunctionHandle());
    }

    public boolean isCaseExpression(RowExpression expression)
    {
        return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm().equals(SWITCH);
    }

    /**
     * Extract the component predicates as a list of list in which is grouped so that the outer level has same conjunctive/disjunctive joiner as original predicate and
     * inner level has opposite joiner.
     * For example, (a or b) and (a or c) or ( a or c) returns [[a,b], [a,c], [a,c]]
     */
    private List<List<RowExpression>> getGroupedClauses(SpecialFormExpression expression)
    {
        return extractPredicates(expression.getForm(), expression).stream()
                .map(LogicalRowExpressions::extractPredicates)
                .collect(toList());
    }

    private int numOfClauses(RowExpression expression)
    {
        if (isConjunctionOrDisjunction(expression)) {
            return getGroupedClauses((SpecialFormExpression) expression).stream().mapToInt(List::size).sum();
        }
        return 1;
    }

    /**
     * Eliminate a sub predicate if its sub predicates contain its peer.
     * For example: (a || b) && a = a, (a && b) || b = b
     */
    private List<List<RowExpression>> eliminateCommonPredicates(List<List<RowExpression>> groupedClauses)
    {
        if (groupedClauses.size() < 2) {
            return groupedClauses;
        }
        // initialize to self
        int[] reduceTo = IntStream.range(0, groupedClauses.size()).toArray();
        for (int i = 0; i < groupedClauses.size(); i++) {
            // Do not eliminate predicates contain non-deterministic value
            // (a || b) && a should be kept same if a is non-deterministic.
            // TODO We can eliminate (a || b) && a if a is deterministic even b is not.
            if (groupedClauses.get(i).stream().allMatch(determinismEvaluator::isDeterministic)) {
                for (int j = 0; j < groupedClauses.size(); j++) {
                    if (isSuperSet(groupedClauses.get(reduceTo[i]), groupedClauses.get(j))) {
                        reduceTo[i] = j; //prefer smaller set
                    }
                    else if (isSameSet(groupedClauses.get(reduceTo[i]), groupedClauses.get(j))) {
                        reduceTo[i] = min(reduceTo[i], j); //prefer predicates that appears earlier.
                    }
                }
            }
        }

        return unmodifiableList(stream(reduceTo)
                .distinct()
                .boxed()
                .map(groupedClauses::get)
                .collect(toList()));
    }

    /**
     * Eliminate a sub predicate if its component predicates contain its peer. Will return null if cannot extract common predicates otherwise return a nested list with flipped form
     * For example:
     * (a || b || c || d) && (a || b || e || f)  -> a || b || ((c || d) && (e || f))
     * (a || b) && (c || d) -> null
     */
    private List<List<RowExpression>> extractCommonPredicates(Form rootClauseJoiner, List<List<RowExpression>> groupedPredicates)
    {
        if (groupedPredicates.isEmpty()) {
            return null;
        }
        Set<RowExpression> commonPredicates = new LinkedHashSet<>(groupedPredicates.get(0));
        for (int i = 1; i < groupedPredicates.size(); i++) {
            // remove all non-common predicates
            commonPredicates.retainAll(groupedPredicates.get(i));
        }

        if (commonPredicates.isEmpty()) {
            return null;
        }
        // extract the component predicates that are not in common predicates: [(c || d), (e || f)]
        List<RowExpression> remainingPredicates = new ArrayList<>();
        for (List<RowExpression> group : groupedPredicates) {
            List<RowExpression> remaining = group.stream()
                    .filter(predicate -> !commonPredicates.contains(predicate))
                    .collect(toList());
            remainingPredicates.add(combinePredicates(flip(rootClauseJoiner), remaining));
        }
        // combine common predicates and remaining predicates to flipped nested form. For example: [[a], [b], [ (c || d), (e || f)]
        return Stream.concat(commonPredicates.stream().map(predicate -> singletonList(predicate)), Stream.of(remainingPredicates))
                .collect(toList());
    }

    private RowExpression combineGroupedClauses(Form clauseJoiner, List<List<RowExpression>> nestedPredicates)
    {
        return combinePredicates(clauseJoiner, nestedPredicates.stream()
                .map(predicate -> combinePredicates(flip(clauseJoiner), predicate))
                .collect(toList()));
    }

    /**
     * Cartesian cross product of List of List.
     * For example, [[a], [b, c], [d]] becomes [[a,b,d], [a,c,d]]
     */
    private static List<List<RowExpression>> crossProduct(List<List<RowExpression>> groupedPredicates)
    {
        checkArgument(groupedPredicates.size() > 0, "Must contains more than one child");
        List<List<RowExpression>> result = groupedPredicates.get(0).stream().map(Collections::singletonList).collect(toList());
        for (int i = 1; i < groupedPredicates.size(); i++) {
            result = crossProduct(result, groupedPredicates.get(i));
        }
        return result;
    }

    private static List<List<RowExpression>> crossProduct(List<List<RowExpression>> previousCrossProduct, List<RowExpression> clauses)
    {
        List<List<RowExpression>> result = new ArrayList<>();
        for (List<RowExpression> previousClauses : previousCrossProduct) {
            for (RowExpression newClause : clauses) {
                List<RowExpression> newClauses = new ArrayList<>(previousClauses);
                newClauses.add(newClause);
                result.add(newClauses);
            }
        }
        return result;
    }

    private static Form flip(Form binaryLogicalOperation)
    {
        switch (binaryLogicalOperation) {
            case AND:
                return OR;
            case OR:
                return AND;
        }
        throw new UnsupportedOperationException("Invalid binary logical operation: " + binaryLogicalOperation);
    }

    private Optional<OperatorType> getOperator(RowExpression expression)
    {
        if (expression instanceof CallExpression) {
            return functionMetadataManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getOperatorType();
        }
        return Optional.empty();
    }

    public RowExpression notCallExpression(RowExpression argument)
    {
        return new CallExpression(argument.getSourceLocation(), "not", functionResolution.notFunction(), BOOLEAN, singletonList(argument));
    }

    public RowExpression equalsCallExpression(RowExpression left, RowExpression right)
    {
        return new CallExpression(
                EQUAL.name(),
                functionResolution.comparisonFunction(EQUAL, left.getType(), right.getType()),
                BOOLEAN,
                asList(left, right));
    }

    public static RowExpression replaceArguments(CallExpression expression, RowExpression... arguments)
    {
        return new CallExpression(
                expression.getDisplayName(),
                expression.getFunctionHandle(),
                expression.getType(),
                asList(arguments));
    }

    private static OperatorType negate(OperatorType operator)
    {
        switch (operator) {
            case EQUAL:
                return NOT_EQUAL;
            case NOT_EQUAL:
                return EQUAL;
            case GREATER_THAN:
                return LESS_THAN_OR_EQUAL;
            case LESS_THAN:
                return GREATER_THAN_OR_EQUAL;
            case LESS_THAN_OR_EQUAL:
                return GREATER_THAN;
            case GREATER_THAN_OR_EQUAL:
                return LESS_THAN;
        }
        return null;
    }

    private static void checkArgument(boolean condition, String message, Object... arguments)
    {
        if (!condition) {
            throw new IllegalArgumentException(String.format(message, arguments));
        }
    }

    private static <T> boolean isSuperSet(Collection<T> a, Collection<T> b)
    {
        // We assumes a, b both are de-duplicated collections.
        return a.size() > b.size() && a.containsAll(b);
    }

    private static <T> boolean isSameSet(Collection<T> a, Collection<T> b)
    {
        // We assumes a, b both are de-duplicated collections.
        return a.size() == b.size() && a.containsAll(b) && b.containsAll(a);
    }
}