RowExpressionVerifier.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;
import com.facebook.presto.Session;
import com.facebook.presto.common.block.IntArrayBlock;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TimestampLiteral;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;
import io.airlift.slice.Slice;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.function.OperatorType.DIVIDE;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.MODULUS;
import static com.facebook.presto.common.function.OperatorType.MULTIPLY;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.function.OperatorType.SUBTRACT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Operator.AND;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Operator.OR;
import static com.facebook.presto.type.JoniRegexpType.JONI_REGEXP;
import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
/**
 * AST visitor that checks whether an actual {@link RowExpression} matches an expected SQL AST expression.
 * <p>
 * The visited node is the expected expression (a {@code com.facebook.presto.sql.tree} node) and the
 * visitor context is the actual {@link RowExpression}. Every {@code visitXxx} method returns {@code true}
 * when the actual expression matches the expected node and {@code false} otherwise; expected node types
 * without a dedicated visit method cause an {@link IllegalStateException}.
 * <p>
 * Symbol references in the expected tree are resolved through {@link SymbolAliases}, except for names
 * bound by an enclosing lambda being verified, which are compared by name directly. Because that set of
 * bound names is mutable per-traversal state, an instance should not be shared across threads.
 */
public final class RowExpressionVerifier
        extends AstVisitor<Boolean, RowExpression>
{
    // either use variable or input reference for symbol mapping
    private final SymbolAliases symbolAliases;
    private final Metadata metadata;
    private final Session session;
    private final FunctionResolution functionResolution;
    // Lambda argument names bound by the lambdas currently being verified; symbol references to these
    // names are matched by name rather than through symbolAliases.
    private final Set<String> lambdaArguments;

    /**
     * @param symbolAliases aliases used to resolve expected symbol references to actual variable names
     * @param metadata metadata used to resolve the function handles of actual call expressions
     * @param session session used by some literal comparisons to interpret actual expressions
     * @throws NullPointerException if any argument is null
     */
    public RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session)
    {
        this.symbolAliases = requireNonNull(symbolAliases, "symbolAliases is null");
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.session = requireNonNull(session, "session is null");
        this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
        this.lambdaArguments = new HashSet<>();
    }

    /**
     * Rejects expected node types this verifier has no comparison for.
     *
     * @throws IllegalStateException always
     */
    @Override
    protected Boolean visitNode(Node node, RowExpression context)
    {
        throw new IllegalStateException(format("Node %s is not supported", node));
    }

    @Override
    protected Boolean visitTimestampLiteral(TimestampLiteral node, RowExpression context)
    {
        return compareLiteral(node, context);
    }

    /**
     * Matches {@code ARRAY[...]} against either a {@code presto.default.array_constructor} call or a constant
     * {@link IntArrayBlock}, whose elements are compared as BIGINT constants.
     */
    @Override
    protected Boolean visitArrayConstructor(ArrayConstructor node, RowExpression context)
    {
        if (context instanceof CallExpression) {
            CallExpression call = (CallExpression) context;
            if (!call.getFunctionHandle().getName().equals("presto.default.array_constructor")) {
                return false;
            }
            // The list overload also requires equal element counts, so a longer expected array no longer
            // throws IndexOutOfBoundsException and a shorter one no longer matches as a prefix.
            return process(node.getValues(), call.getArguments());
        }
        if (context instanceof ConstantExpression && ((ConstantExpression) context).getValue() instanceof IntArrayBlock) {
            IntArrayBlock block = (IntArrayBlock) ((ConstantExpression) context).getValue();
            if (block.getPositionCount() != node.getValues().size()) {
                return false;
            }
            for (int i = 0; i < node.getValues().size(); ++i) {
                if (!process(node.getValues().get(i), constant((long) block.getInt(i), BIGINT))) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Matches {@code TRY(x)} against a try function call whose first argument is a lambda with body {@code x}.
     */
    @Override
    protected Boolean visitTryExpression(TryExpression expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression) || !functionResolution.isTryFunction(((CallExpression) actual).getFunctionHandle())) {
            return false;
        }
        List<RowExpression> arguments = ((CallExpression) actual).getArguments();
        if (arguments.isEmpty() || !(arguments.get(0) instanceof LambdaDefinitionExpression)) {
            return false;
        }
        LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) arguments.get(0);
        return process(expected.getInnerExpression(), lambdaExpression.getBody());
    }

    /**
     * Matches {@code CAST(x AS t)} either against a constant of type {@code t} (when {@code x} is a literal) or
     * against a cast / try_cast call whose argument matches {@code x}.
     */
    @Override
    protected Boolean visitCast(Cast expected, RowExpression actual)
    {
        // TODO: clean up cast path
        // A cast over a literal may appear as a constant of the target type; compare its interpreted value.
        if (actual instanceof ConstantExpression && expected.getExpression() instanceof Literal && expected.getType().equalsIgnoreCase(actual.getType().toString())) {
            Literal literal = (Literal) expected.getExpression();
            Object value = LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actual);
            if (literal instanceof NullLiteral) {
                // getValueFromLiteral has no representation for NULL; a null literal matches a null constant
                return value == null;
            }
            if (literal instanceof StringLiteral) {
                String actualString = value instanceof Slice ? ((Slice) value).toStringUtf8() : String.valueOf(value);
                return ((StringLiteral) literal).getValue().equals(actualString);
            }
            return getValueFromLiteral(literal).equals(String.valueOf(value));
        }
        if (!(actual instanceof CallExpression)) {
            return false;
        }
        CallExpression call = (CallExpression) actual;
        if (!functionResolution.isCastFunction(call.getFunctionHandle()) && !functionResolution.isTryCastFunction(call.getFunctionHandle())) {
            return false;
        }
        boolean sameType = expected.getType().equalsIgnoreCase(actual.getType().toString());
        // An expected plain "varchar" matches any actual type whose base is varchar (e.g. one carrying a length)
        boolean bothVarchar = expected.getType().toLowerCase(ENGLISH).equals(VARCHAR) && actual.getType().getTypeSignature().getBase().equals(VARCHAR);
        if (!sameType && !bothVarchar) {
            return false;
        }
        return process(expected.getExpression(), call.getArguments().get(0));
    }

    /**
     * Matches {@code IF(c, t[, f])} against an IF special form; a missing ELSE is compared as NULL.
     */
    @Override
    protected Boolean visitIfExpression(IfExpression expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(IF)) {
            return false;
        }
        List<RowExpression> arguments = ((SpecialFormExpression) actual).getArguments();
        return process(expected.getCondition(), arguments.get(0)) &&
                process(expected.getTrueValue(), arguments.get(1)) &&
                process(expected.getFalseValue().orElseGet(NullLiteral::new), arguments.get(2));
    }

    @Override
    protected Boolean visitSubscriptExpression(SubscriptExpression expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression) || !functionResolution.isSubscriptFunction(((CallExpression) actual).getFunctionHandle())) {
            return false;
        }
        List<RowExpression> arguments = ((CallExpression) actual).getArguments();
        return process(expected.getBase(), arguments.get(0)) &&
                process(expected.getIndex(), arguments.get(1));
    }

    @Override
    protected Boolean visitIsNullPredicate(IsNullPredicate expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(IS_NULL)) {
            return false;
        }
        return process(expected.getValue(), ((SpecialFormExpression) actual).getArguments().get(0));
    }

    /**
     * Matches {@code x IS NOT NULL} against {@code not(IS_NULL(x))}.
     */
    @Override
    protected Boolean visitIsNotNullPredicate(IsNotNullPredicate expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression) || !functionResolution.notFunction().equals(((CallExpression) actual).getFunctionHandle())) {
            return false;
        }
        RowExpression argument = ((CallExpression) actual).getArguments().get(0);
        if (!(argument instanceof SpecialFormExpression) || !((SpecialFormExpression) argument).getForm().equals(IS_NULL)) {
            return false;
        }
        return process(expected.getValue(), ((SpecialFormExpression) argument).getArguments().get(0));
    }

    /**
     * Matches a searched CASE against a SWITCH special form laid out as
     * {@code [TRUE, WHEN..., default?]}.
     */
    @Override
    protected Boolean visitSearchedCaseExpression(SearchedCaseExpression node, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(SWITCH)) {
            return false;
        }
        SpecialFormExpression specialForm = (SpecialFormExpression) actual;
        int argumentSize = node.getWhenClauses().size() + 1;
        if (node.getDefaultValue().isPresent()) {
            ++argumentSize;
        }
        if (specialForm.getArguments().size() != argumentSize) {
            return false;
        }
        if (!specialForm.getArguments().get(0).equals(constant(true, BooleanType.BOOLEAN))) {
            return false;
        }
        for (int i = 0; i < node.getWhenClauses().size(); ++i) {
            if (!process(node.getWhenClauses().get(i), specialForm.getArguments().get(i + 1))) {
                return false;
            }
        }
        if (node.getDefaultValue().isPresent()) {
            return process(node.getDefaultValue().get(), specialForm.getArguments().get(argumentSize - 1));
        }
        return true;
    }

    /**
     * Matches {@code x IN (...)} against an IN special form whose first argument is the tested value and whose
     * remaining arguments are the candidate values.
     */
    @Override
    protected Boolean visitInPredicate(InPredicate expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !((SpecialFormExpression) actual).getForm().equals(IN)) {
            return false;
        }
        List<RowExpression> arguments = ((SpecialFormExpression) actual).getArguments();
        if (expected.getValueList() instanceof InListExpression) {
            return process(expected.getValue(), arguments.get(0)) && process(((InListExpression) expected.getValueList()).getValues(), arguments.subList(1, arguments.size()));
        }
        /*
         * If the actual value is a value list, but the expected is e.g. a SymbolReference,
         * we need to unpack the value from the list so that when we hit visitSymbolReference, the
         * actual.toString() call returns something that the symbolAliases expectedly contains.
         * For example, InListExpression.toString returns "(onlyitem)" rather than "onlyitem".
         *
         * This is required because expected passes through the analyzer, planner, and possibly optimizers,
         * one of which sometimes takes the liberty of unpacking the InListExpression.
         *
         * Since the actual value doesn't go through all of that, we have to deal with the case
         * of the expected value being unpacked, but the actual value being an InListExpression.
         */
        checkState(arguments.size() == 2, "Multiple expressions in actual value list %s, but expected value %s is not a list", arguments.subList(1, arguments.size()), expected.getValueList());
        return process(expected.getValue(), arguments.get(0)) && process(expected.getValueList(), arguments.get(1));
    }

    /**
     * Matches a lambda whose argument names equal the actual lambda's, then verifies the body with those names
     * bound so that symbol references to them are compared by name.
     */
    @Override
    protected Boolean visitLambdaExpression(LambdaExpression expected, RowExpression actual)
    {
        if (!(actual instanceof LambdaDefinitionExpression)) {
            return false;
        }
        LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) actual;
        List<String> actualArguments = lambda.getArguments();
        if (actualArguments.size() != expected.getArguments().size()) {
            return false;
        }
        // Compare all names before binding any, so a mismatch leaves lambdaArguments untouched.
        for (int i = 0; i < actualArguments.size(); ++i) {
            if (!actualArguments.get(i).equals(expected.getArguments().get(i).getName().getValue())) {
                return false;
            }
        }
        // Unbind only the names this lambda introduced; an enclosing lambda may have bound the same name.
        Set<String> newlyBound = new HashSet<>();
        for (String argument : actualArguments) {
            if (lambdaArguments.add(argument)) {
                newlyBound.add(argument);
            }
        }
        try {
            return process(expected.getBody(), lambda.getBody());
        }
        finally {
            lambdaArguments.removeAll(newlyBound);
        }
    }

    /**
     * Matches a comparison against a call to the corresponding comparison operator. EQUAL operands are also
     * accepted in swapped order.
     */
    @Override
    protected Boolean visitComparisonExpression(ComparisonExpression expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression)) {
            return false;
        }
        CallExpression call = (CallExpression) actual;
        FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle());
        Optional<OperatorType> actualOperatorType = functionMetadata.getOperatorType();
        if (!actualOperatorType.isPresent() || !actualOperatorType.get().isComparisonOperator()) {
            return false;
        }
        if (!getOperatorType(expected.getOperator()).equals(actualOperatorType.get())) {
            return false;
        }
        List<RowExpression> arguments = call.getArguments();
        boolean inOrder = process(expected.getLeft(), arguments.get(0)) && process(expected.getRight(), arguments.get(1));
        if (actualOperatorType.get() == EQUAL) {
            return inOrder || (process(expected.getLeft(), arguments.get(1)) && process(expected.getRight(), arguments.get(0)));
        }
        // TODO support other comparison operators
        return inOrder;
    }

    /**
     * Maps an AST comparison operator to the function-level {@link OperatorType}.
     *
     * @throws IllegalStateException for an operator without a mapping
     */
    private static OperatorType getOperatorType(ComparisonExpression.Operator operator)
    {
        switch (operator) {
            case EQUAL:
                return EQUAL;
            case NOT_EQUAL:
                return NOT_EQUAL;
            case LESS_THAN:
                return LESS_THAN;
            case LESS_THAN_OR_EQUAL:
                return LESS_THAN_OR_EQUAL;
            case GREATER_THAN:
                return GREATER_THAN;
            case GREATER_THAN_OR_EQUAL:
                return GREATER_THAN_OR_EQUAL;
            case IS_DISTINCT_FROM:
                return IS_DISTINCT_FROM;
            default:
                throw new IllegalStateException("Unsupported comparison operator type: " + operator);
        }
    }

    /**
     * Matches a binary arithmetic expression against a call to the corresponding arithmetic operator,
     * with operands in the same order.
     */
    @Override
    protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression)) {
            return false;
        }
        CallExpression call = (CallExpression) actual;
        FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle());
        Optional<OperatorType> actualOperatorType = functionMetadata.getOperatorType();
        if (!actualOperatorType.isPresent() || !actualOperatorType.get().isArithmeticOperator()) {
            return false;
        }
        if (!getOperatorType(expected.getOperator()).equals(actualOperatorType.get())) {
            return false;
        }
        return process(expected.getLeft(), call.getArguments().get(0)) && process(expected.getRight(), call.getArguments().get(1));
    }

    /**
     * Maps an AST arithmetic operator to the function-level {@link OperatorType}.
     *
     * @throws IllegalStateException for an operator without a mapping
     */
    private static OperatorType getOperatorType(ArithmeticBinaryExpression.Operator operator)
    {
        switch (operator) {
            case ADD:
                return ADD;
            case SUBTRACT:
                return SUBTRACT;
            case MULTIPLY:
                return MULTIPLY;
            case DIVIDE:
                return DIVIDE;
            case MODULUS:
                return MODULUS;
            default:
                throw new IllegalStateException("Unknown arithmetic operator: " + operator);
        }
    }

    @Override
    protected Boolean visitGenericLiteral(GenericLiteral expected, RowExpression actual)
    {
        return compareLiteral(expected, actual);
    }

    @Override
    protected Boolean visitLongLiteral(LongLiteral expected, RowExpression actual)
    {
        return compareLiteral(expected, actual);
    }

    @Override
    protected Boolean visitDoubleLiteral(DoubleLiteral expected, RowExpression actual)
    {
        return compareLiteral(expected, actual);
    }

    @Override
    protected Boolean visitDecimalLiteral(DecimalLiteral expected, RowExpression actual)
    {
        return compareLiteral(expected, actual);
    }

    @Override
    protected Boolean visitBooleanLiteral(BooleanLiteral expected, RowExpression actual)
    {
        return compareLiteral(expected, actual);
    }

    /**
     * Matches {@code base.field} against a DEREFERENCE special form whose second argument is a constant index
     * into the row type of the first argument; the indexed field's name must equal the expected field name.
     *
     * @throws IllegalStateException if the index is not a Long, is out of range, or selects an unnamed field
     */
    @Override
    protected Boolean visitDereferenceExpression(DereferenceExpression expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !(((SpecialFormExpression) actual).getForm().equals(DEREFERENCE))) {
            return false;
        }
        SpecialFormExpression actualDereference = (SpecialFormExpression) actual;
        if (actualDereference.getArguments().size() == 2 &&
                actualDereference.getArguments().get(0).getType() instanceof RowType &&
                actualDereference.getArguments().get(1) instanceof ConstantExpression) {
            RowType rowType = (RowType) actualDereference.getArguments().get(0).getType();
            Object value = LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actualDereference.getArguments().get(1));
            checkState(value instanceof Long, "Expected dereference index to be a Long, but was %s", value);
            long index = (Long) value;
            checkState(index >= 0 && index < rowType.getFields().size(), "Dereference index %s is out of bounds for %s", index, rowType);
            RowType.Field field = rowType.getFields().get(toIntExact(index));
            checkState(field.getName().isPresent(), "Dereferenced field %s of %s has no name", index, rowType);
            return expected.getField().getValue().equals(field.getName().get()) && process(expected.getBase(), actualDereference.getArguments().get(0));
        }
        return false;
    }

    /**
     * Renders a supported literal as the string it is compared by.
     *
     * @throws IllegalArgumentException for literal types without a string rendering here (e.g. NULL, strings)
     */
    private static String getValueFromLiteral(Node expression)
    {
        if (expression instanceof LongLiteral) {
            return String.valueOf(((LongLiteral) expression).getValue());
        }
        else if (expression instanceof BooleanLiteral) {
            return String.valueOf(((BooleanLiteral) expression).getValue());
        }
        else if (expression instanceof DoubleLiteral) {
            return String.valueOf(((DoubleLiteral) expression).getValue());
        }
        else if (expression instanceof DecimalLiteral) {
            return String.valueOf(((DecimalLiteral) expression).getValue());
        }
        else if (expression instanceof TimestampLiteral) {
            return ((TimestampLiteral) expression).getValue();
        }
        else if (expression instanceof GenericLiteral) {
            return ((GenericLiteral) expression).getValue();
        }
        else {
            throw new IllegalArgumentException("Unsupported literal expression type: " + expression.getClass().getName());
        }
    }

    /**
     * Compares a non-string literal with an actual constant, or with a cast call that is evaluated first, by
     * their string renderings.
     */
    private Boolean compareLiteral(Node expected, RowExpression actual)
    {
        if (actual instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle())) {
            return getValueFromLiteral(expected).equals(String.valueOf(rowExpressionInterpreter(actual, metadata.getFunctionAndTypeManager(), session.toConnectorSession()).evaluate()));
        }
        if (actual instanceof ConstantExpression) {
            return getValueFromLiteral(expected).equals(String.valueOf(LiteralInterpreter.evaluate(session.toConnectorSession(), (ConstantExpression) actual)));
        }
        return false;
    }

    /**
     * Matches a string literal against a cast call evaluating to a Slice, or against a Slice-backed constant
     * whose interpreted value is an equal String.
     */
    @Override
    protected Boolean visitStringLiteral(StringLiteral expected, RowExpression actual)
    {
        if (actual instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle())) {
            Object value = rowExpressionInterpreter(actual, metadata.getFunctionAndTypeManager(), session.toConnectorSession()).evaluate();
            if (value instanceof Slice) {
                return expected.getValue().equals(((Slice) value).toStringUtf8());
            }
        }
        if (actual instanceof ConstantExpression && actual.getType().getJavaType() == Slice.class) {
            Object value = LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actual);
            // Not every Slice-backed type interprets to a String; such values are a mismatch, not a ClassCastException.
            return value instanceof String && expected.getValue().equals(value);
        }
        return false;
    }

    /**
     * Matches AND / OR against the corresponding special form, accepting the two operands in either order.
     */
    @Override
    protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression)) {
            return false;
        }
        SpecialFormExpression actualLogicalBinary = (SpecialFormExpression) actual;
        boolean sameOperator = (expected.getOperator() == OR && actualLogicalBinary.getForm() == SpecialFormExpression.Form.OR) ||
                (expected.getOperator() == AND && actualLogicalBinary.getForm() == SpecialFormExpression.Form.AND);
        if (!sameOperator) {
            return false;
        }
        List<RowExpression> arguments = actualLogicalBinary.getArguments();
        // `Logical AND` and `Logical OR` both satisfy the commutative property: try the swapped order
        // whenever the in-order match fails, not only when the left operand fails.
        return (process(expected.getLeft(), arguments.get(0)) && process(expected.getRight(), arguments.get(1))) ||
                (process(expected.getLeft(), arguments.get(1)) && process(expected.getRight(), arguments.get(0)));
    }

    @Override
    protected Boolean visitBetweenPredicate(BetweenPredicate expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression) || !functionResolution.isBetweenFunction(((CallExpression) actual).getFunctionHandle())) {
            return false;
        }
        List<RowExpression> arguments = ((CallExpression) actual).getArguments();
        return process(expected.getValue(), arguments.get(0)) &&
                process(expected.getMin(), arguments.get(1)) &&
                process(expected.getMax(), arguments.get(2));
    }

    @Override
    protected Boolean visitNotExpression(NotExpression expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression) || !functionResolution.notFunction().equals(((CallExpression) actual).getFunctionHandle())) {
            return false;
        }
        return process(expected.getValue(), ((CallExpression) actual).getArguments().get(0));
    }

    /**
     * Matches a symbol reference against a variable reference: lambda-bound names are compared directly, all
     * other names are resolved through {@link SymbolAliases}.
     */
    @Override
    protected Boolean visitSymbolReference(SymbolReference expected, RowExpression actual)
    {
        // LIKE will add a cast from VARCHAR to LIKE_PATTERN. However, LIKE_PATTERN is not a data type and can not add a cast(varchar as like_pattern) in test
        // Hence match the cast argument here
        if (actual instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) actual).getFunctionHandle()) &&
                (actual.getType().equals(LIKE_PATTERN) || actual.getType().equals(JONI_REGEXP))) {
            actual = ((CallExpression) actual).getArguments().get(0);
        }
        if (!(actual instanceof VariableReferenceExpression)) {
            return false;
        }
        String actualName = ((VariableReferenceExpression) actual).getName();
        if (lambdaArguments.contains(expected.getName())) {
            return actualName.equals(expected.getName());
        }
        return symbolAliases.get(expected.getName()).getName().equals(actualName);
    }

    /**
     * Matches COALESCE against a COALESCE special form with the same number of operands, compared pairwise.
     */
    @Override
    protected Boolean visitCoalesceExpression(CoalesceExpression expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression) || !(((SpecialFormExpression) actual).getForm().equals(COALESCE))) {
            return false;
        }
        return process(expected.getOperands(), ((SpecialFormExpression) actual).getArguments());
    }

    /**
     * Matches a simple CASE against a SWITCH special form laid out as {@code [operand, WHEN..., else?]}; a
     * trailing argument that is not a WHEN form is treated as the ELSE value.
     */
    @Override
    protected Boolean visitSimpleCaseExpression(SimpleCaseExpression expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression && ((SpecialFormExpression) actual).getForm().equals(SWITCH))) {
            return false;
        }
        List<RowExpression> arguments = ((SpecialFormExpression) actual).getArguments();
        if (!process(expected.getOperand(), arguments.get(0))) {
            return false;
        }
        List<RowExpression> whenClauses;
        Optional<RowExpression> elseValue;
        RowExpression last = arguments.get(arguments.size() - 1);
        if (last instanceof SpecialFormExpression && ((SpecialFormExpression) last).getForm().equals(WHEN)) {
            whenClauses = arguments.subList(1, arguments.size());
            elseValue = Optional.empty();
        }
        else {
            whenClauses = arguments.subList(1, arguments.size() - 1);
            elseValue = Optional.of(last);
        }
        if (!process(expected.getWhenClauses(), whenClauses)) {
            return false;
        }
        return process(expected.getDefaultValue(), elseValue);
    }

    @Override
    protected Boolean visitWhenClause(WhenClause expected, RowExpression actual)
    {
        if (!(actual instanceof SpecialFormExpression && ((SpecialFormExpression) actual).getForm().equals(WHEN))) {
            return false;
        }
        SpecialFormExpression actualWhenClause = (SpecialFormExpression) actual;
        return process(expected.getOperand(), actualWhenClause.getArguments().get(0)) &&
                process(expected.getResult(), actualWhenClause.getArguments().get(1));
    }

    /**
     * Matches a function call by the unqualified function name and pairwise arguments.
     */
    @Override
    protected Boolean visitFunctionCall(FunctionCall expected, RowExpression actual)
    {
        if (!(actual instanceof CallExpression)) {
            return false;
        }
        CallExpression actualFunction = (CallExpression) actual;
        String actualName = metadata.getFunctionAndTypeManager().getFunctionMetadata(actualFunction.getFunctionHandle()).getName().getObjectName();
        if (!expected.getName().getSuffix().equals(actualName)) {
            return false;
        }
        return process(expected.getArguments(), actualFunction.getArguments());
    }

    @Override
    protected Boolean visitNullLiteral(NullLiteral node, RowExpression actual)
    {
        return actual instanceof ConstantExpression && ((ConstantExpression) actual).getValue() == null;
    }

    /**
     * Matches {@code value LIKE pattern} against a like function call; the escape clause is not compared.
     */
    @Override
    protected Boolean visitLikePredicate(LikePredicate node, RowExpression actual)
    {
        if (!(actual instanceof CallExpression)) {
            return false;
        }
        CallExpression callExpression = (CallExpression) actual;
        if (!functionResolution.isLikeFunction(callExpression.getFunctionHandle())) {
            return false;
        }
        return process(node.getValue(), callExpression.getArguments().get(0)) && process(node.getPattern(), callExpression.getArguments().get(1));
    }

    /**
     * Returns true when both lists have the same size and every expected node matches the actual expression at
     * the same position; stops at the first mismatch.
     */
    private <T extends Node> boolean process(List<T> expecteds, List<RowExpression> actuals)
    {
        if (expecteds.size() != actuals.size()) {
            return false;
        }
        for (int i = 0; i < expecteds.size(); i++) {
            if (!process(expecteds.get(i), actuals.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true when both are absent, or both are present and the expected node matches the actual expression.
     */
    private <T extends Node> boolean process(Optional<T> expected, Optional<RowExpression> actual)
    {
        if (expected.isPresent() != actual.isPresent()) {
            return false;
        }
        if (expected.isPresent()) {
            return process(expected.get(), actual.get());
        }
        return true;
    }
}