ValuesMatcher.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ShortDecimalType;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.StringLiteral;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import io.airlift.slice.Slice;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference;
import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH;
import static com.facebook.presto.sql.planner.assertions.MatchResult.match;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
 * Plan-assertion matcher for {@link ValuesNode}.
 * <p>
 * The shape check accepts any {@link ValuesNode} whose number of output variables equals
 * {@code expectedOutputSymbolCount}, when one is given. The detail check, when
 * {@code expectedRows} is given, converts every constant in the node's rows into an AST
 * literal and compares the result with the expected rows <em>as a set</em>, so row order
 * and duplicate rows are not significant. On success, each alias in
 * {@code outputSymbolAliases} is bound to the output variable at its mapped index.
 */
public class ValuesMatcher
        implements Matcher
{
    // alias name -> index into ValuesNode#getOutputVariables()
    private final Map<String, Integer> outputSymbolAliases;
    // when present, the exact number of output variables the node must have
    private final Optional<Integer> expectedOutputSymbolCount;
    // when present, the rows the node must contain (set semantics: order/duplicates ignored)
    private final Optional<Set<List<Expression>>> expectedRows;

    /**
     * @param outputSymbolAliases alias name to output-variable index bindings exported on a match
     * @param expectedOutputSymbolCount required number of output variables, or empty to skip that check
     * @param expectedRows required rows as literal expressions, or empty to skip the row check;
     *         stored as a set, so row order and duplicate rows are not significant
     */
    public ValuesMatcher(
            Map<String, Integer> outputSymbolAliases,
            Optional<Integer> expectedOutputSymbolCount,
            Optional<List<List<Expression>>> expectedRows)
    {
        this.outputSymbolAliases = ImmutableMap.copyOf(requireNonNull(outputSymbolAliases, "outputSymbolAliases is null"));
        this.expectedOutputSymbolCount = requireNonNull(expectedOutputSymbolCount, "expectedOutputSymbolCount is null");
        this.expectedRows = requireNonNull(expectedRows, "expectedRows is null").map(ImmutableSet::copyOf);
    }

    /**
     * Returns true when {@code node} is a {@link ValuesNode} and, if an expected output count
     * was given, its number of output variables equals that count.
     */
    @Override
    public boolean shapeMatches(PlanNode node)
    {
        return (node instanceof ValuesNode) &&
                expectedOutputSymbolCount.map(Integer.valueOf(node.getOutputVariables().size())::equals).orElse(true);
    }

    /**
     * Compares the node's rows with {@code expectedRows} (if present) and, on success, binds the
     * configured aliases to the node's output variables.
     *
     * @return {@code NO_MATCH} if the rows differ; otherwise a match exporting the aliases
     * @throws IllegalStateException if called on a node that fails {@link #shapeMatches}, or if a
     *         row contains a non-constant expression while a row check is requested
     */
    @Override
    public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
    {
        checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
        ValuesNode valuesNode = (ValuesNode) node;

        // Rows are only converted when a row check was requested.
        if (!expectedRows.map(rows -> rows.equals(toLiteralRows(valuesNode))).orElse(true)) {
            return NO_MATCH;
        }

        return match(SymbolAliases.builder()
                .putAll(Maps.transformValues(outputSymbolAliases, index -> createSymbolReference(valuesNode.getOutputVariables().get(index))))
                .build());
    }

    /**
     * Converts every row of {@code valuesNode} into a list of AST literals and collects the rows
     * into a set, so they can be compared with {@code expectedRows}.
     */
    private static Set<List<Expression>> toLiteralRows(ValuesNode valuesNode)
    {
        return valuesNode.getRows().stream()
                .map(row -> row.stream()
                        .map(rowExpression -> {
                            // Only constants can be turned into literals; fail with a readable message otherwise.
                            checkState(rowExpression instanceof ConstantExpression, "ValuesMatcher can only compare constant row expressions, but found: %s", rowExpression);
                            return toLiteral((ConstantExpression) rowExpression);
                        })
                        .collect(toImmutableList()))
                .collect(toImmutableSet());
    }

    /**
     * Converts a single constant into the AST literal used by expected rows.
     * <p>
     * A null value becomes a {@link NullLiteral}. Without this check, the typed literal
     * constructors would receive the string {@code "null"}.
     * <p>
     * The short-decimal check must precede the {@code long} check, because that branch
     * would otherwise also match it.
     * <p>
     * NOTE(review): short decimals are rendered straight from {@code getValue()} without
     * applying the type's scale, and every {@link Slice}-backed type is rendered via
     * {@code toStringUtf8()}. Confirm that expected rows are written the same way.
     */
    private static Expression toLiteral(ConstantExpression constant)
    {
        Object value = constant.getValue();
        if (value == null) {
            return new NullLiteral();
        }
        Class<?> javaType = constant.getType().getJavaType();
        if (javaType == boolean.class) {
            return new BooleanLiteral(String.valueOf(value));
        }
        if (constant.getType() instanceof ShortDecimalType) {
            return new DecimalLiteral(String.valueOf(value));
        }
        if (javaType == long.class) {
            return new LongLiteral(String.valueOf(value));
        }
        if (javaType == double.class) {
            return new DoubleLiteral(String.valueOf(value));
        }
        if (javaType == Slice.class) {
            return new StringLiteral(((Slice) value).toStringUtf8());
        }
        return new GenericLiteral(constant.getType().toString(), String.valueOf(value));
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .omitNullValues()
                .add("outputSymbolAliases", outputSymbolAliases)
                .add("expectedOutputSymbolCount", expectedOutputSymbolCount)
                .add("expectedRows", expectedRows)
                .toString();
    }
}