DynamicFilterMatcher.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.Session;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.DynamicFilters;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Joiner;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;

public class DynamicFilterMatcher
        implements Matcher
{
    // LEFT_SYMBOL -> RIGHT_SYMBOL
    private final Map<SymbolAlias, SymbolAlias> expectedDynamicFilters;
    private final Map<String, String> joinExpectedMappings;
    private final Map<String, String> filterExpectedMappings;
    private final Optional<Expression> expectedStaticFilter;

    private JoinNode joinNode;
    private SymbolAliases symbolAliases;
    private FilterNode filterNode;

    public DynamicFilterMatcher(Map<SymbolAlias, SymbolAlias> expectedDynamicFilters, Optional<Expression> expectedStaticFilter)
    {
        this.expectedDynamicFilters = requireNonNull(expectedDynamicFilters, "expectedDynamicFilters is null");
        this.joinExpectedMappings = expectedDynamicFilters.values().stream()
                .collect(toImmutableMap(rightSymbol -> rightSymbol.toString() + "_alias", SymbolAlias::toString));
        this.filterExpectedMappings = expectedDynamicFilters.entrySet().stream()
                .collect(toImmutableMap(entry -> entry.getKey().toString(), entry -> entry.getValue().toString() + "_alias"));
        this.expectedStaticFilter = requireNonNull(expectedStaticFilter, "expectedStaticFilter is null");
    }

    public MatchResult match(JoinNode joinNode, SymbolAliases symbolAliases)
    {
        checkState(this.joinNode == null, "joinNode must be null at this point");
        this.joinNode = joinNode;
        this.symbolAliases = symbolAliases;
        return new MatchResult(match());
    }

    private MatchResult match(FilterNode filterNode, Session session, Metadata metadata, SymbolAliases symbolAliases)
    {
        checkState(this.filterNode == null, "filterNode must be null at this point");
        this.filterNode = filterNode;
        this.symbolAliases = symbolAliases;

        LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(
                new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()),
                new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()),
                metadata.getFunctionAndTypeManager());
        boolean staticFilterMatches = expectedStaticFilter.map(filter -> {
            RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session);
            RowExpression staticFilter = logicalRowExpressions.combineConjuncts(extractDynamicFilters(filterNode.getPredicate()).getStaticConjuncts());
            return verifier.process(filter, staticFilter);
        }).orElse(true);

        return new MatchResult(match() && staticFilterMatches);
    }

    private boolean match()
    {
        checkState(symbolAliases != null, "symbolAliases is null");

        // both nodes must be provided to do the matching
        if (filterNode == null || joinNode == null) {
            return true;
        }

        Map<String, VariableReferenceExpression> idToProbeSymbolMap = extractDynamicFilters(filterNode.getPredicate())
                .getDynamicConjuncts().stream()
                .collect(toImmutableMap(DynamicFilters.DynamicFilterPlaceholder::getId, filter -> (VariableReferenceExpression) filter.getInput()));
        Map<String, VariableReferenceExpression> idToBuildSymbolMap = joinNode.getDynamicFilters();

        if (idToProbeSymbolMap == null) {
            return false;
        }

        if (idToProbeSymbolMap.size() != expectedDynamicFilters.size()) {
            return false;
        }

        Map<Symbol, Symbol> actual = new HashMap<>();
        for (Map.Entry<String, VariableReferenceExpression> idToProbeSymbol : idToProbeSymbolMap.entrySet()) {
            String id = idToProbeSymbol.getKey();
            VariableReferenceExpression probe = idToProbeSymbol.getValue();
            VariableReferenceExpression build = idToBuildSymbolMap.get(id);
            if (build == null) {
                return false;
            }
            actual.put(new Symbol(probe.getName()), new Symbol(build.getName()));
        }

        Map<Symbol, Symbol> expected = expectedDynamicFilters.entrySet().stream()
                .collect(toImmutableMap(entry -> entry.getKey().toSymbol(symbolAliases), entry -> entry.getValue().toSymbol(symbolAliases)));

        return expected.equals(actual);
    }

    @Override
    public boolean shapeMatches(PlanNode node)
    {
        return node instanceof FilterNode;
    }

    @Override
    public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
    {
        if (!(node instanceof FilterNode)) {
            return new MatchResult(false);
        }
        return match((FilterNode) node, session, metadata, symbolAliases);
    }

    public Map<String, String> getJoinExpectedMappings()
    {
        return joinExpectedMappings;
    }

    @Override
    public String toString()
    {
        String predicate = Joiner.on(" AND ")
                .join(filterExpectedMappings.entrySet().stream()
                        .map(entry -> entry.getKey() + " = " + entry.getValue())
                        .collect(toImmutableList()));
        return toStringHelper(this)
                .add("dynamicPredicate", predicate)
                .add("staticPredicate", expectedStaticFilter)
                .toString();
    }
}