ImplementIntersectAndExceptAsUnion.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SetOperationNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.spi.plan.AggregationNode.Step;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.relational.Expressions.comparisonExpression;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.concat;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

/**
 * Converts INTERSECT and EXCEPT queries into UNION ALL..GROUP BY...WHERE
 * Eg:  SELECT a FROM foo INTERSECT SELECT x FROM bar
 * <p/>
 * =>
 * <p/>
 * SELECT a
 * FROM
 * (SELECT a,
 * COUNT(foo_marker) AS foo_cnt,
 * COUNT(bar_marker) AS bar_cnt
 * FROM
 * (
 * SELECT a, true as foo_marker, null as bar_marker FROM foo
 * UNION ALL
 * SELECT x, null as foo_marker, true as bar_marker FROM bar
 * ) T1
 * GROUP BY a) T2
 * WHERE foo_cnt >= 1 AND bar_cnt >= 1;
 * <p>
 * Eg:  SELECT a FROM foo EXCEPT SELECT x FROM bar
 * <p/>
 * =>
 * <p/>
 * SELECT a
 * FROM
 * (SELECT a,
 * COUNT(foo_marker) AS foo_cnt,
 * COUNT(bar_marker) AS bar_cnt
 * FROM
 * (
 * SELECT a, true as foo_marker, null as bar_marker FROM foo
 * UNION ALL
 * SELECT x, null as foo_marker, true as bar_marker FROM bar
 * ) T1
 * GROUP BY a) T2
 * WHERE foo_cnt >= 1 AND bar_cnt = 0;
 */
public class ImplementIntersectAndExceptAsUnion
        implements PlanOptimizer
{
    private final FunctionAndTypeManager functionAndTypeManager;

    public ImplementIntersectAndExceptAsUnion(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null");
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(variableAllocator, "variableAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(session, functionAndTypeManager, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan);
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    private static class Rewriter
            extends SimplePlanRewriter<Void>
    {
        private static final String MARKER = "marker";

        private final Session session;
        private final StandardFunctionResolution functionResolution;
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private boolean planChanged;

        private Rewriter(Session session, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
        {
            requireNonNull(functionAndTypeManager, "functionManager is null");
            this.session = requireNonNull(session, "session is null");
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
        }

        @Override
        public PlanNode visitIntersect(IntersectNode node, RewriteContext<Void> rewriteContext)
        {
            List<PlanNode> sources = node.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(toList());

            List<VariableReferenceExpression> markers = allocateVariables(sources.size(), MARKER, BOOLEAN);

            // identity projection for all the fields in each of the sources plus marker columns
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);

            // add a union over all the rewritten sources. The outputs of the union have the same name as the
            // original intersect node
            List<VariableReferenceExpression> outputs = node.getOutputVariables();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers)));

            // add count aggregations and filter rows where any of the counts is >= 1
            List<VariableReferenceExpression> aggregationOutputs = allocateVariables(markers.size(), "count", BIGINT);
            AggregationNode aggregation = computeCounts(union, node.getOutputVariables(), markers, aggregationOutputs);
            FilterNode filterNode = addFilterForIntersect(aggregation);

            planChanged = true;
            return project(filterNode, outputs);
        }

        @Override
        public PlanNode visitExcept(ExceptNode node, RewriteContext<Void> rewriteContext)
        {
            List<PlanNode> sources = node.getSources().stream()
                    .map(rewriteContext::rewrite)
                    .collect(toList());

            List<VariableReferenceExpression> markers = allocateVariables(sources.size(), MARKER, BOOLEAN);

            // identity projection for all the fields in each of the sources plus marker columns
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);

            // add a union over all the rewritten sources. The outputs of the union have the same name as the
            // original except node
            List<VariableReferenceExpression> outputs = node.getOutputVariables();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers)));

            // add count aggregations and filter rows where count for the first source is >= 1 and all others are 0
            List<VariableReferenceExpression> aggregationOutputs = allocateVariables(markers.size(), "count", BIGINT);
            AggregationNode aggregation = computeCounts(union, node.getOutputVariables(), markers, aggregationOutputs);
            FilterNode filterNode = addFilterForExcept(aggregation, aggregationOutputs.get(0), aggregationOutputs.subList(1, aggregationOutputs.size()));

            planChanged = true;
            return project(filterNode, outputs);
        }

        private List<VariableReferenceExpression> allocateVariables(int count, String nameHint, Type type)
        {
            ImmutableList.Builder<VariableReferenceExpression> variablesBuilder = ImmutableList.builder();
            for (int i = 0; i < count; i++) {
                variablesBuilder.add(variableAllocator.newVariable(nameHint, type));
            }
            return variablesBuilder.build();
        }

        private List<PlanNode> appendMarkers(List<VariableReferenceExpression> markers, List<PlanNode> nodes, SetOperationNode node)
        {
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < nodes.size(); i++) {
                result.add(appendMarkers(nodes.get(i), i, markers, Maps.transformValues(node.sourceVariableMap(i), variable -> variable)));
            }
            return result.build();
        }

        private PlanNode appendMarkers(PlanNode source, int markerIndex, List<VariableReferenceExpression> markers, Map<VariableReferenceExpression, VariableReferenceExpression> projections)
        {
            Assignments.Builder assignments = Assignments.builder();
            // add existing intersect symbols to projection
            for (Map.Entry<VariableReferenceExpression, VariableReferenceExpression> entry : projections.entrySet()) {
                VariableReferenceExpression variable = variableAllocator.newVariable(entry.getKey().getSourceLocation(), entry.getKey().getName(), entry.getKey().getType());
                assignments.put(variable, entry.getValue());
            }

            // add extra marker fields to the projection
            for (int i = 0; i < markers.size(); ++i) {
                RowExpression expression = (i == markerIndex) ? TRUE_CONSTANT : new ConstantExpression(null, BOOLEAN);
                assignments.put(variableAllocator.newVariable(markers.get(i).getSourceLocation(), markers.get(i).getName(), BOOLEAN), expression);
            }

            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        private UnionNode union(List<PlanNode> nodes, List<VariableReferenceExpression> outputs)
        {
            ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                for (int i = 0; i < source.getOutputVariables().size(); i++) {
                    outputsToInputs.put(outputs.get(i), source.getOutputVariables().get(i));
                }
            }

            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> mapping = outputsToInputs.build();
            return new UnionNode(nodes.get(0).getSourceLocation(), idAllocator.getNextId(), nodes, ImmutableList.copyOf(mapping.keySet()), fromListMultimap(mapping));
        }

        private AggregationNode computeCounts(UnionNode sourceNode, List<VariableReferenceExpression> originalColumns, List<VariableReferenceExpression> markers, List<VariableReferenceExpression> aggregationOutputs)
        {
            ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();

            for (int i = 0; i < markers.size(); i++) {
                VariableReferenceExpression output = aggregationOutputs.get(i);
                aggregations.put(output, new Aggregation(
                        new CallExpression(
                                output.getSourceLocation(),
                                "count",
                                functionResolution.countFunction(markers.get(i).getType()),
                                BIGINT,
                                ImmutableList.of(markers.get(i))),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()));
            }

            return new AggregationNode(
                    sourceNode.getSourceLocation(),
                    idAllocator.getNextId(),
                    sourceNode,
                    aggregations.build(),
                    singleGroupingSet(originalColumns),
                    ImmutableList.of(),
                    Step.SINGLE,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty());
        }

        private FilterNode addFilterForIntersect(AggregationNode aggregation)
        {
            ImmutableList<RowExpression> predicates = aggregation.getAggregations().keySet().stream()
                    .map(column -> comparisonExpression(functionResolution, GREATER_THAN_OR_EQUAL, column, new ConstantExpression(1L, BIGINT)))
                    .collect(toImmutableList());
            return new FilterNode(aggregation.getSourceLocation(), idAllocator.getNextId(), aggregation, and(predicates));
        }

        private FilterNode addFilterForExcept(AggregationNode aggregation, VariableReferenceExpression firstSource, List<VariableReferenceExpression> remainingSources)
        {
            ImmutableList.Builder<RowExpression> predicatesBuilder = ImmutableList.builder();
            predicatesBuilder.add(comparisonExpression(functionResolution, GREATER_THAN_OR_EQUAL, firstSource, new ConstantExpression(1L, BIGINT)));
            for (VariableReferenceExpression variable : remainingSources) {
                predicatesBuilder.add(comparisonExpression(functionResolution, EQUAL, variable, new ConstantExpression(0L, BIGINT)));
            }

            return new FilterNode(aggregation.getSourceLocation(), idAllocator.getNextId(), aggregation, and(predicatesBuilder.build()));
        }

        private ProjectNode project(PlanNode node, List<VariableReferenceExpression> columns)
        {
            return new ProjectNode(
                    idAllocator.getNextId(),
                    node,
                    identityAssignments(columns));
        }

        public boolean isPlanChanged()
        {
            return planChanged;
        }
    }
}