ExpressionVerifier.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Expression visitor which verifies whether a given expression (actual) matches another expression passed as the
 * visitor context (expected). The visitor returns true if the two expressions match each other.
 * <p>
 * Note that the actual expression uses real name references (table columns etc.) while the expected expression uses
 * symbol aliases. A given symbol alias can point to only one real name reference.
 * <p>
 * Example:
 * <pre>
 * NOT (orderkey = 3 AND custkey = 3 AND orderkey &lt; 10)
 * </pre>
 * will match:
 * <pre>
 * NOT (X = 3 AND Y = 3 AND X &lt; 10)
 * </pre>
 * but will not match:
 * <pre>
 * NOT (X = 3 AND Y = 3 AND Z &lt; 10)
 * </pre>
 * nor
 * <pre>
 * NOT (X = 3 AND X = 3 AND X &lt; 10)
 * </pre>
 */
final class ExpressionVerifier
        extends AstVisitor<Boolean, Node>
{
    private final SymbolAliases symbolAliases;

    ExpressionVerifier(SymbolAliases symbolAliases)
    {
        this.symbolAliases = requireNonNull(symbolAliases, "symbolAliases is null");
    }

    /**
     * Fallback for every node type without a dedicated visit method below: such nodes cannot be verified.
     *
     * @throws IllegalStateException always
     */
    @Override
    protected Boolean visitNode(Node node, Node context)
    {
        throw new IllegalStateException(format("Node %s is not supported", node));
    }

    @Override
    protected Boolean visitTryExpression(TryExpression actual, Node expected)
    {
        if (!(expected instanceof TryExpression)) {
            return false;
        }

        return process(actual.getInnerExpression(), ((TryExpression) expected).getInnerExpression());
    }

    @Override
    protected Boolean visitCast(Cast actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof Cast)) {
            return false;
        }

        Cast expected = (Cast) expectedExpression;

        // Both the target type and the cast operand have to agree.
        return actual.getType().equals(expected.getType())
                && process(actual.getExpression(), expected.getExpression());
    }

    @Override
    protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof IsNullPredicate)) {
            return false;
        }

        IsNullPredicate expected = (IsNullPredicate) expectedExpression;

        return process(actual.getValue(), expected.getValue());
    }

    @Override
    protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof IsNotNullPredicate)) {
            return false;
        }

        IsNotNullPredicate expected = (IsNotNullPredicate) expectedExpression;

        return process(actual.getValue(), expected.getValue());
    }

    /**
     * Matches an IN predicate. When the actual value list is not an {@link InListExpression}, the expected value
     * list must be a single-element {@link InListExpression}, whose only element is compared with the actual
     * value list.
     *
     * @throws IllegalStateException if the actual value list is not a list and the expected value list is either
     * not an {@link InListExpression} or does not contain exactly one element
     */
    @Override
    protected Boolean visitInPredicate(InPredicate actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof InPredicate)) {
            return false;
        }

        InPredicate expected = (InPredicate) expectedExpression;

        if (actual.getValueList() instanceof InListExpression) {
            return process(actual.getValue(), expected.getValue()) && process(actual.getValueList(), expected.getValueList());
        }

        checkState(
                expected.getValueList() instanceof InListExpression,
                "ExpressionVerifier doesn't support unpacked expected values. Feel free to add support if needed");
        /*
         * If the expected value is a value list, but the actual is e.g. a SymbolReference,
         * we need to unpack the value from the list so that when we hit visitSymbolReference, the
         * expected.toString() call returns something that the symbolAliases actually contains.
         * For example, InListExpression.toString returns "(onlyitem)" rather than "onlyitem".
         *
         * This is required because actual passes through the analyzer, planner, and possibly optimizers,
         * one of which sometimes takes the liberty of unpacking the InListExpression.
         *
         * Since the expected value doesn't go through all of that, we have to deal with the case
         * of the actual value being unpacked, but the expected value being an InListExpression.
         */
        List<Expression> values = ((InListExpression) expected.getValueList()).getValues();
        // One placeholder per argument; report the actual value list, which is what the message talks about.
        checkState(
                values.size() == 1,
                "Multiple expressions in expected value list %s, but actual value is not a list: %s",
                values,
                actual.getValueList());
        Expression onlyExpectedExpression = values.get(0);
        return process(actual.getValue(), expected.getValue()) && process(actual.getValueList(), onlyExpectedExpression);
    }

    @Override
    protected Boolean visitComparisonExpression(ComparisonExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof ComparisonExpression)) {
            return false;
        }

        ComparisonExpression expected = (ComparisonExpression) expectedExpression;

        // Operands are compared positionally; "a = b" is not treated as equal to "b = a".
        return actual.getOperator() == expected.getOperator()
                && process(actual.getLeft(), expected.getLeft())
                && process(actual.getRight(), expected.getRight());
    }

    @Override
    protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof ArithmeticBinaryExpression)) {
            return false;
        }

        ArithmeticBinaryExpression expected = (ArithmeticBinaryExpression) expectedExpression;

        return actual.getOperator() == expected.getOperator()
                && process(actual.getLeft(), expected.getLeft())
                && process(actual.getRight(), expected.getRight());
    }

    /**
     * Matches generic literals by their value text.
     * NOTE(review): only {@code getValue()} is compared; the literal's type name is not checked here.
     */
    @Override
    protected Boolean visitGenericLiteral(GenericLiteral actual, Node expected)
    {
        if (!(expected instanceof GenericLiteral)) {
            return false;
        }

        return getValueFromLiteral(actual).equals(getValueFromLiteral(expected));
    }

    @Override
    protected Boolean visitLongLiteral(LongLiteral actual, Node expected)
    {
        if (!(expected instanceof LongLiteral)) {
            return false;
        }

        return getValueFromLiteral(actual).equals(getValueFromLiteral(expected));
    }

    @Override
    protected Boolean visitDoubleLiteral(DoubleLiteral actual, Node expected)
    {
        if (!(expected instanceof DoubleLiteral)) {
            return false;
        }

        return getValueFromLiteral(actual).equals(getValueFromLiteral(expected));
    }

    @Override
    protected Boolean visitDecimalLiteral(DecimalLiteral actual, Node expected)
    {
        if (!(expected instanceof DecimalLiteral)) {
            return false;
        }

        return getValueFromLiteral(actual).equals(getValueFromLiteral(expected));
    }

    @Override
    protected Boolean visitBooleanLiteral(BooleanLiteral actual, Node expected)
    {
        if (!(expected instanceof BooleanLiteral)) {
            return false;
        }

        return getValueFromLiteral(actual).equals(getValueFromLiteral(expected));
    }

    /**
     * Returns the string form of the value held by a supported literal node, so that literals of the same kind
     * can be compared with {@link String#equals(Object)}.
     *
     * @throws IllegalArgumentException if the node is not one of the supported literal types
     */
    private static String getValueFromLiteral(Node expression)
    {
        if (expression instanceof LongLiteral) {
            return String.valueOf(((LongLiteral) expression).getValue());
        }
        if (expression instanceof BooleanLiteral) {
            return String.valueOf(((BooleanLiteral) expression).getValue());
        }
        if (expression instanceof DoubleLiteral) {
            return String.valueOf(((DoubleLiteral) expression).getValue());
        }
        if (expression instanceof DecimalLiteral) {
            return String.valueOf(((DecimalLiteral) expression).getValue());
        }
        if (expression instanceof GenericLiteral) {
            return ((GenericLiteral) expression).getValue();
        }
        throw new IllegalArgumentException("Unsupported literal expression type: " + expression.getClass().getName());
    }

    @Override
    protected Boolean visitStringLiteral(StringLiteral actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof StringLiteral)) {
            return false;
        }

        StringLiteral expected = (StringLiteral) expectedExpression;
        return actual.getValue().equals(expected.getValue());
    }

    @Override
    protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof LogicalBinaryExpression)) {
            return false;
        }

        LogicalBinaryExpression expected = (LogicalBinaryExpression) expectedExpression;

        return actual.getOperator() == expected.getOperator()
                && process(actual.getLeft(), expected.getLeft())
                && process(actual.getRight(), expected.getRight());
    }

    @Override
    protected Boolean visitBetweenPredicate(BetweenPredicate actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof BetweenPredicate)) {
            return false;
        }

        BetweenPredicate expected = (BetweenPredicate) expectedExpression;

        return process(actual.getValue(), expected.getValue())
                && process(actual.getMin(), expected.getMin())
                && process(actual.getMax(), expected.getMax());
    }

    @Override
    protected Boolean visitNotExpression(NotExpression actual, Node expected)
    {
        if (!(expected instanceof NotExpression)) {
            return false;
        }

        return process(actual.getValue(), ((NotExpression) expected).getValue());
    }

    /**
     * Matches a symbol reference either through the alias registered under the expected symbol's name,
     * or by direct equality of the two references.
     */
    @Override
    protected Boolean visitSymbolReference(SymbolReference actual, Node expected)
    {
        if (!(expected instanceof SymbolReference)) {
            return false;
        }

        // The alias lookup is evaluated first; direct equality is only consulted when the alias does not match.
        return symbolAliases.get(((SymbolReference) expected).getName()).equals(actual) ||
                expected.equals(actual);
    }

    @Override
    protected Boolean visitCoalesceExpression(CoalesceExpression actual, Node expected)
    {
        if (!(expected instanceof CoalesceExpression)) {
            return false;
        }

        CoalesceExpression expectedCoalesce = (CoalesceExpression) expected;

        // Shared list helper: checks the operand count and compares operands pairwise, stopping at the first mismatch.
        return process(actual.getOperands(), expectedCoalesce.getOperands());
    }

    @Override
    protected Boolean visitSimpleCaseExpression(SimpleCaseExpression actual, Node expected)
    {
        if (!(expected instanceof SimpleCaseExpression)) {
            return false;
        }

        SimpleCaseExpression expectedCase = (SimpleCaseExpression) expected;

        return process(actual.getOperand(), expectedCase.getOperand())
                && process(actual.getWhenClauses(), expectedCase.getWhenClauses())
                && process(actual.getDefaultValue(), expectedCase.getDefaultValue());
    }

    @Override
    protected Boolean visitWhenClause(WhenClause actual, Node expected)
    {
        if (!(expected instanceof WhenClause)) {
            return false;
        }

        WhenClause expectedWhenClause = (WhenClause) expected;

        return process(actual.getOperand(), expectedWhenClause.getOperand())
                && process(actual.getResult(), expectedWhenClause.getResult());
    }

    /**
     * Matches a function call on its DISTINCT flag, name, arguments, filter and window, in that order.
     */
    @Override
    protected Boolean visitFunctionCall(FunctionCall actual, Node expected)
    {
        if (!(expected instanceof FunctionCall)) {
            return false;
        }

        FunctionCall expectedFunction = (FunctionCall) expected;

        if (actual.isDistinct() != expectedFunction.isDistinct()) {
            return false;
        }

        if (!actual.getName().equals(expectedFunction.getName())) {
            return false;
        }

        if (!process(actual.getArguments(), expectedFunction.getArguments())) {
            return false;
        }

        if (!process(actual.getFilter(), expectedFunction.getFilter())) {
            return false;
        }

        return process(actual.getWindow(), expectedFunction.getWindow());
    }

    @Override
    protected Boolean visitNullLiteral(NullLiteral node, Node expected)
    {
        return expected instanceof NullLiteral;
    }

    @Override
    protected Boolean visitInListExpression(InListExpression actual, Node expected)
    {
        if (!(expected instanceof InListExpression)) {
            return false;
        }

        InListExpression expectedInList = (InListExpression) expected;
        return process(actual.getValues(), expectedInList.getValues());
    }

    @Override
    protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof DereferenceExpression)) {
            return false;
        }

        DereferenceExpression expected = (DereferenceExpression) expectedExpression;

        return actual.getField().equals(expected.getField())
                && process(actual.getBase(), expected.getBase());
    }

    @Override
    protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof SubscriptExpression)) {
            return false;
        }

        SubscriptExpression expected = (SubscriptExpression) expectedExpression;

        return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex());
    }

    @Override
    protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof SearchedCaseExpression)) {
            return false;
        }

        SearchedCaseExpression expected = (SearchedCaseExpression) expectedExpression;

        return process(actual.getDefaultValue(), expected.getDefaultValue())
                && process(actual.getWhenClauses(), expected.getWhenClauses());
    }

    @Override
    protected Boolean visitIfExpression(IfExpression actual, Node expectedExpression)
    {
        if (!(expectedExpression instanceof IfExpression)) {
            return false;
        }

        IfExpression expected = (IfExpression) expectedExpression;

        return process(actual.getCondition(), expected.getCondition())
                && process(actual.getTrueValue(), expected.getTrueValue())
                && process(actual.getFalseValue(), expected.getFalseValue());
    }

    /**
     * Compares two node lists element by element, in order.
     *
     * @return true only if both lists have the same size and every pair of elements at the same index matches
     */
    private <T extends Node> boolean process(List<T> actuals, List<T> expecteds)
    {
        if (actuals.size() != expecteds.size()) {
            return false;
        }
        for (int i = 0; i < actuals.size(); i++) {
            if (!process(actuals.get(i), expecteds.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares two optional nodes.
     *
     * @return true if both are absent, or both are present and their contents match
     */
    private <T extends Node> boolean process(Optional<T> actual, Optional<T> expected)
    {
        if (actual.isPresent() != expected.isPresent()) {
            return false;
        }
        if (actual.isPresent()) {
            return process(actual.get(), expected.get());
        }
        return true;
    }
}