PruneUnreferencedOutputs.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.DeleteNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IndexSourceNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SetOperationNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.StatisticAggregations;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.StatisticsWriterNode;
import com.facebook.presto.sql.planner.plan.TableWriterMergeNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.UpdateNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables;
import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Sets.intersection;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * Removes all computation that is not referenced transitively from the root of the plan
 * <p>
 * E.g.,
 * <p>
 * {@code Output[$0] -> Project[$0 := $1 + $2, $3 := $4 / $5] -> ...}
 * <p>
 * gets rewritten as
 * <p>
 * {@code Output[$0] -> Project[$0 := $1 + $2] -> ...}
 */
public class PruneUnreferencedOutputs
        implements PlanOptimizer
{
    /**
     * Runs the pruning rewriter over the given plan.
     *
     * @param plan the plan to prune; must not be null
     * @return the rewritten plan together with the rewriter's plan-changed flag
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");

        // The traversal starts with an empty set of required variables.
        Set<VariableReferenceExpression> requiredAtRoot = ImmutableSet.of();
        Rewriter pruner = new Rewriter();
        PlanNode prunedPlan = SimplePlanRewriter.rewriteWith(pruner, plan, requiredAtRoot);
        boolean changed = pruner.isPlanChanged();
        return PlanOptimizerResult.optimizerResult(prunedPlan, changed);
    }

    private static class Rewriter
            extends SimplePlanRewriter<Set<VariableReferenceExpression>>
    {
        // Written by the visit methods to report that the rewrite altered the plan.
        private boolean planChanged;

        /**
         * @return the current value of the plan-changed flag maintained by the visit methods
         */
        public boolean isPlanChanged()
        {
            return this.planChanged;
        }

        /**
         * Rewrites the children of an explain-analyze node, requiring every variable its source
         * currently outputs (the parent's requirements in {@code context} are not consulted).
         */
        @Override
        public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> allSourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputVariables());
            return context.defaultRewrite(node, allSourceOutputs);
        }

        /**
         * Prunes exchange output columns that are not required, keeping each source's input list
         * positionally aligned with the surviving outputs.
         * <p>
         * Besides the variables requested through {@code context}, the partitioning hash column, the
         * partitioning's variable references and the ordering scheme's order-by variables are treated
         * as required.
         *
         * @param node the exchange to prune
         * @param context carries the set of variables required by the parent
         * @return a new exchange with the pruned output layout and rewritten sources
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> expectedOutputVariables = Sets.newHashSet(context.get());
            node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputVariables::add);
            node.getPartitioningScheme().getPartitioning().getVariableReferences()
                    .forEach(expectedOutputVariables::add);
            node.getOrderingScheme().ifPresent(orderingScheme -> expectedOutputVariables.addAll(orderingScheme.getOrderByVariables()));

            // One (initially empty) input list per source, filled in lock-step with the kept outputs.
            List<List<VariableReferenceExpression>> inputsBySource = new ArrayList<>(node.getInputs().size());
            for (int i = 0; i < node.getInputs().size(); i++) {
                inputsBySource.add(new ArrayList<>());
            }

            List<VariableReferenceExpression> newOutputVariables = new ArrayList<>(node.getOutputVariables().size());
            for (int i = 0; i < node.getOutputVariables().size(); i++) {
                VariableReferenceExpression outputVariable = node.getOutputVariables().get(i);
                if (expectedOutputVariables.contains(outputVariable)) {
                    newOutputVariables.add(outputVariable);
                    for (int source = 0; source < node.getInputs().size(); source++) {
                        inputsBySource.get(source).add(node.getInputs().get(source).get(i));
                    }
                }
            }

            // Only ever raise the flag: a plain assignment would reset a change already
            // reported by a node visited earlier in the traversal.
            if (newOutputVariables.size() != node.getOutputVariables().size()) {
                planChanged = true;
            }

            // newOutputVariables contains all partition, sort and hash variables so simply swap the output layout
            PartitioningScheme partitioningScheme = new PartitioningScheme(
                    node.getPartitioningScheme().getPartitioning(),
                    newOutputVariables,
                    node.getPartitioningScheme().getHashColumn(),
                    node.getPartitioningScheme().isReplicateNullsAndAny(),
                    node.getPartitioningScheme().isScaleWriters(),
                    node.getPartitioningScheme().getEncoding(),
                    node.getPartitioningScheme().getBucketToPartition());

            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); i++) {
                // Each source only has to produce the inputs that feed a surviving output column.
                Set<VariableReferenceExpression> expectedInputs = ImmutableSet.copyOf(inputsBySource.get(i));
                rewrittenSources.add(context.rewrite(node.getSources().get(i), expectedInputs));
            }

            return new ExchangeNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getType(),
                    node.getScope(),
                    partitioningScheme,
                    rewrittenSources.build(),
                    inputsBySource,
                    node.isEnsureSourceOrdering(),
                    node.getOrderingScheme());
        }

        /**
         * Prunes a join: each side is asked for the parent's required variables plus that side's
         * equi-join keys, its hash variable (if any) and every variable used by the join filter.
         * <p>
         * For cross joins the output is not pruned; it is rebuilt from the rewritten children's outputs.
         *
         * @param node the join to prune
         * @param context carries the set of variables required by the parent
         * @return a new join over the rewritten children with the pruned output list
         */
        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> expectedFilterInputs = new HashSet<>();
            if (node.getFilter().isPresent()) {
                expectedFilterInputs = ImmutableSet.<VariableReferenceExpression>builder()
                        .addAll(VariablesExtractor.extractUnique(node.getFilter().get()))
                        .addAll(context.get())
                        .build();
            }

            ImmutableSet.Builder<VariableReferenceExpression> leftInputsBuilder = ImmutableSet.builder();
            leftInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), EquiJoinClause::getLeft));
            node.getLeftHashVariable().ifPresent(leftInputsBuilder::add);
            leftInputsBuilder.addAll(expectedFilterInputs);
            Set<VariableReferenceExpression> leftInputs = leftInputsBuilder.build();

            ImmutableSet.Builder<VariableReferenceExpression> rightInputsBuilder = ImmutableSet.builder();
            rightInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), EquiJoinClause::getRight));
            node.getRightHashVariable().ifPresent(rightInputsBuilder::add);
            rightInputsBuilder.addAll(expectedFilterInputs);
            Set<VariableReferenceExpression> rightInputs = rightInputsBuilder.build();

            PlanNode left = context.rewrite(node.getLeft(), leftInputs);
            PlanNode right = context.rewrite(node.getRight(), rightInputs);

            List<VariableReferenceExpression> outputVariables;
            if (node.isCrossJoin()) {
                // do not prune nested joins output since it is not supported
                // TODO: remove this "if" branch when output symbols selection is supported by nested loop join
                outputVariables = ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(left.getOutputVariables())
                        .addAll(right.getOutputVariables())
                        .build();
            }
            else {
                outputVariables = node.getOutputVariables().stream()
                        .filter(variable -> context.get().contains(variable))
                        .distinct()
                        .collect(toImmutableList());
            }

            // Only ever raise the flag: the children rewritten above may already have set it,
            // and a plain assignment here would discard that.
            if (node.getOutputVariables().size() != outputVariables.size()) {
                planChanged = true;
            }

            return new JoinNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    node.getType(),
                    left,
                    right,
                    node.getCriteria(),
                    outputVariables,
                    node.getFilter(),
                    node.getLeftHashVariable(),
                    node.getRightHashVariable(),
                    node.getDistributionType(),
                    node.getDynamicFilters());
        }

        /**
         * Rewrites both inputs of a semi join.
         * <p>
         * The source side must provide the parent's required variables, the source join variable and
         * the source hash variable (when present). The filtering side only has to provide its join
         * variable and its hash variable (when present). The semi join node itself is rebuilt with
         * all of its original attributes.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredFromSource = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(context.get())
                    .add(node.getSourceJoinVariable());
            node.getSourceHashVariable().ifPresent(requiredFromSource::add);

            ImmutableSet.Builder<VariableReferenceExpression> requiredFromFilteringSource = ImmutableSet.<VariableReferenceExpression>builder()
                    .add(node.getFilteringSourceJoinVariable());
            node.getFilteringSourceHashVariable().ifPresent(requiredFromFilteringSource::add);

            // Rewrite the source side first, then the filtering side.
            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
            PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), requiredFromFilteringSource.build());

            return new SemiJoinNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    rewrittenFilteringSource,
                    node.getSourceJoinVariable(),
                    node.getFilteringSourceJoinVariable(),
                    node.getSemiJoinOutput(),
                    node.getSourceHashVariable(),
                    node.getFilteringSourceHashVariable(),
                    node.getDistributionType(),
                    node.getDynamicFilters());
        }

        /**
         * Prunes a spatial join: both sides are asked for the parent's required variables and the
         * variables used by the join filter, plus that side's partition variable when present.
         *
         * @param node the spatial join to prune
         * @param context carries the set of variables required by the parent
         * @return a new spatial join whose outputs are restricted to the required variables
         */
        @Override
        public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> filterSymbols = VariablesExtractor.extractUnique(node.getFilter());
            Set<VariableReferenceExpression> requiredInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(filterSymbols)
                    .addAll(context.get())
                    .build();

            // ifPresent (not map) because the call is made purely for its side effect on the builder.
            ImmutableSet.Builder<VariableReferenceExpression> leftInputs = ImmutableSet.builder();
            node.getLeftPartitionVariable().ifPresent(leftInputs::add);

            ImmutableSet.Builder<VariableReferenceExpression> rightInputs = ImmutableSet.builder();
            node.getRightPartitionVariable().ifPresent(rightInputs::add);

            PlanNode left = context.rewrite(node.getLeft(), leftInputs.addAll(requiredInputs).build());
            PlanNode right = context.rewrite(node.getRight(), rightInputs.addAll(requiredInputs).build());

            List<VariableReferenceExpression> outputVariables = node.getOutputVariables().stream()
                    .filter(context.get()::contains)
                    .distinct()
                    .collect(toImmutableList());

            // Only ever raise the flag: the children rewritten above may already have set it,
            // and a plain assignment here would discard that.
            if (outputVariables.size() != node.getOutputVariables().size()) {
                planChanged = true;
            }

            return new SpatialJoinNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), node.getType(), left, right, outputVariables, node.getFilter(), node.getLeftPartitionVariable(), node.getRightPartitionVariable(), node.getKdbTree());
        }

        /**
         * Rewrites both inputs of an index join.
         * <p>
         * Each side must provide the parent's required variables, its half of every equi-join clause,
         * its hash variable (when present) and all variables referenced by the join filter (when
         * present). The index join node itself is rebuilt with its original attributes.
         */
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // Variables referenced by the optional filter; empty when there is no filter.
            Set<VariableReferenceExpression> filterVariables;
            if (node.getFilter().isPresent()) {
                filterVariables = ImmutableSet.copyOf(VariablesExtractor.extractUnique(node.getFilter().get()));
            }
            else {
                filterVariables = ImmutableSet.of();
            }

            ImmutableSet.Builder<VariableReferenceExpression> requiredFromProbe = ImmutableSet.builder();
            requiredFromProbe.addAll(context.get());
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                requiredFromProbe.add(clause.getProbe());
            }
            node.getProbeHashVariable().ifPresent(requiredFromProbe::add);
            requiredFromProbe.addAll(filterVariables);

            ImmutableSet.Builder<VariableReferenceExpression> requiredFromIndex = ImmutableSet.builder();
            requiredFromIndex.addAll(context.get());
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                requiredFromIndex.add(clause.getIndex());
            }
            node.getIndexHashVariable().ifPresent(requiredFromIndex::add);
            requiredFromIndex.addAll(filterVariables);

            // Rewrite the probe side first, then the index side.
            PlanNode rewrittenProbe = context.rewrite(node.getProbeSource(), requiredFromProbe.build());
            PlanNode rewrittenIndex = context.rewrite(node.getIndexSource(), requiredFromIndex.build());

            return new IndexJoinNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    node.getType(),
                    rewrittenProbe,
                    rewrittenIndex,
                    node.getCriteria(),
                    node.getFilter(),
                    node.getProbeHashVariable(),
                    node.getIndexHashVariable());
        }

        /**
         * Restricts an index source to the required variables: output variables, lookup variables and
         * column assignments are all filtered down to what {@code context} requests.
         *
         * @param node the index source to prune
         * @param context carries the set of variables required by the parent
         * @return a new index source exposing only the required variables
         */
        @Override
        public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            List<VariableReferenceExpression> newOutputVariables = node.getOutputVariables().stream()
                    .filter(context.get()::contains)
                    .collect(toImmutableList());

            Set<VariableReferenceExpression> newLookupVariables = node.getLookupVariables().stream()
                    .filter(context.get()::contains)
                    .collect(toImmutableSet());

            // Keep only the column handles of the surviving outputs.
            Map<VariableReferenceExpression, ColumnHandle> newAssignments = newOutputVariables.stream()
                    .collect(toImmutableMap(identity(), node.getAssignments()::get));

            // Only ever raise the flag (a plain assignment would reset a change reported by a node
            // visited earlier), and report pruned outputs as well as pruned lookup variables.
            if (newLookupVariables.size() != node.getLookupVariables().size()
                    || newOutputVariables.size() != node.getOutputVariables().size()) {
                planChanged = true;
            }

            return new IndexSourceNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), node.getIndexHandle(), node.getTableHandle(), newLookupVariables, newOutputVariables, newAssignments, node.getCurrentConstraint());
        }

        /**
         * Drops aggregations whose output variable is not required and asks the source only for the
         * grouping keys, the hash variable (when present) and the inputs and masks of the surviving
         * aggregations.
         *
         * @param node the aggregation to prune
         * @param context carries the set of variables required by the parent
         * @return a new aggregation node containing only the required aggregations
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(node.getGroupingKeys());
            node.getHashVariable().ifPresent(expectedInputs::add);

            ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
                VariableReferenceExpression variable = entry.getKey();

                if (context.get().contains(variable)) {
                    Aggregation aggregation = entry.getValue();
                    expectedInputs.addAll(extractAggregationUniqueVariables(aggregation));
                    aggregation.getMask().ifPresent(expectedInputs::add);
                    aggregations.put(variable, aggregation);
                }
                else {
                    // An aggregation is being dropped, so the plan is changing; record it the same
                    // way visitProject records a dropped assignment.
                    planChanged = true;
                }
            }

            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());

            // NOTE(review): the argument following the grouping sets is passed as an empty list
            // instead of being copied from the original node - confirm this reset is intentional.
            return new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    source,
                    aggregations.build(),
                    node.getGroupingSets(),
                    ImmutableList.of(),
                    node.getStep(),
                    node.getHashVariable(),
                    node.getGroupIdVariable(),
                    node.getAggregationId());
        }

        /**
         * Drops window functions whose output variable is not required. When no function survives,
         * the window node is removed entirely and its rewritten source is returned in its place.
         * <p>
         * Otherwise the source must provide the parent's required variables, the partition-by and
         * order-by variables, the frame bound and coerced sort-key variables, the hash variable (when
         * present) and the inputs of the surviving functions.
         *
         * @param node the window node to prune
         * @param context carries the set of variables required by the parent
         * @return the pruned window node, or the rewritten source if every function was pruned
         */
        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(context.get())
                    .addAll(node.getPartitionBy());

            node.getOrderingScheme().ifPresent(orderingScheme ->
                    orderingScheme.getOrderByVariables()
                            .forEach(expectedInputs::add));

            for (WindowNode.Frame frame : node.getFrames()) {
                frame.getStartValue().ifPresent(expectedInputs::add);
                frame.getEndValue().ifPresent(expectedInputs::add);
                frame.getSortKeyCoercedForFrameStartComparison().ifPresent(expectedInputs::add);
                frame.getSortKeyCoercedForFrameEndComparison().ifPresent(expectedInputs::add);
            }

            node.getHashVariable().ifPresent(expectedInputs::add);

            ImmutableMap.Builder<VariableReferenceExpression, WindowNode.Function> functionsBuilder = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) {
                VariableReferenceExpression variable = entry.getKey();
                WindowNode.Function function = entry.getValue();
                if (context.get().contains(variable)) {
                    expectedInputs.addAll(WindowNodeUtil.extractWindowFunctionUniqueVariables(function));
                    functionsBuilder.put(variable, function);
                }
            }

            Map<VariableReferenceExpression, WindowNode.Function> functions = functionsBuilder.build();

            // Dropping functions, or the node as a whole, changes the plan; record it the same way
            // visitMarkDistinct records the removal of its node.
            if (functions.isEmpty() || functions.size() != node.getWindowFunctions().size()) {
                planChanged = true;
            }

            if (functions.isEmpty()) {
                // As the window plan node is getting skipped, use the inputs needed by the parent of the Window plan node
                return context.rewrite(node.getSource(), context.get());
            }

            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new WindowNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    source,
                    node.getSpecification(),
                    functions,
                    node.getHashVariable(),
                    node.getPrePartitionedInputs(),
                    node.getPreSortedOrderPrefix());
        }

        /**
         * Restricts a table scan to the required variables, keeping only the matching column
         * assignments.
         *
         * @param node the table scan to prune
         * @param context carries the set of variables required by the parent
         * @return a new table scan that outputs only the required variables
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            List<VariableReferenceExpression> newOutputs = node.getOutputVariables().stream()
                    .filter(context.get()::contains)
                    .collect(toImmutableList());

            // Keep only the column handles of the surviving outputs.
            Map<VariableReferenceExpression, ColumnHandle> newAssignments = newOutputs.stream()
                    .collect(Collectors.toMap(identity(), node.getAssignments()::get));

            // Only ever raise the flag: a plain assignment would reset a change already
            // reported by a node visited earlier in the traversal.
            if (newOutputs.size() != node.getOutputVariables().size()) {
                planChanged = true;
            }

            return new TableScanNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    node.getTable(),
                    newOutputs,
                    newAssignments,
                    node.getTableConstraints(),
                    node.getCurrentConstraint(),
                    node.getEnforcedConstraint(),
                    node.getCteMaterializationInfo());
        }

        /**
         * Rewrites the source of a filter, requiring the variables referenced by the predicate in
         * addition to those required by the parent. The predicate itself is kept unchanged.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> required = ImmutableSet.builder();
            required.addAll(VariablesExtractor.extractUnique(node.getPredicate()));
            required.addAll(context.get());

            PlanNode prunedSource = context.rewrite(node.getSource(), required.build());

            return new FilterNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    prunedSource,
                    node.getPredicate());
        }

        /**
         * Leaves a CTE consumer untouched: the node is returned as-is and the variables requested
         * through {@code context} are ignored.
         */
        @Override
        public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // Some output can be pruned but current implementation of PhysicalCteProducer does not allow cteconsumer pruning
            return node;
        }

        /**
         * Rewrites the source of a CTE producer, requiring every variable the producer outputs
         * (the parent's requirements in {@code context} are not consulted). The producer's own
         * output list is kept unchanged.
         */
        @Override
        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            PlanNode rewrittenSource = context.rewrite(
                    node.getSource(),
                    ImmutableSet.copyOf(node.getOutputVariables()));

            return new CteProducerNode(
                    node.getSourceLocation(),
                    node.getId(),
                    rewrittenSource,
                    node.getCteId(),
                    node.getRowCountVariable(),
                    node.getOutputVariables());
        }

        /**
         * Rewrites the children of a sequence node.
         * <p>
         * Every CTE producer is rewritten against the union of all producers' output variables, and
         * the primary source is rewritten against its own current outputs; the parent's requirements
         * in {@code context} are not consulted.
         */
        @Override
        public PlanNode visitSequence(SequenceNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // Union of the outputs of all producers, shared as the requirement for each of them.
            ImmutableSet.Builder<VariableReferenceExpression> producerOutputs = ImmutableSet.builder();
            for (PlanNode producer : node.getCteProducers()) {
                producerOutputs.addAll(producer.getOutputVariables());
            }
            Set<VariableReferenceExpression> requiredFromProducers = producerOutputs.build();

            ImmutableList.Builder<PlanNode> rewrittenProducers = ImmutableList.builder();
            for (PlanNode producer : node.getCteProducers()) {
                rewrittenProducers.add(context.rewrite(producer, requiredFromProducers));
            }

            // Producers are rewritten before the primary source.
            PlanNode rewrittenPrimary = context.rewrite(
                    node.getPrimarySource(),
                    ImmutableSet.copyOf(node.getPrimarySource().getOutputVariables()));

            return new SequenceNode(
                    node.getSourceLocation(),
                    node.getId(),
                    rewrittenProducers.build(),
                    rewrittenPrimary,
                    node.getCteDependencyGraph());
        }

        /**
         * Prunes a group-id node: aggregation arguments and grouping-set members that are not required
         * are dropped, and the source is asked only for the surviving aggregation arguments and the
         * input columns behind the surviving grouping outputs.
         *
         * @param node the group-id node to prune
         * @param context carries the set of variables required by the parent
         * @return a new group-id node with pruned grouping sets, grouping mapping and arguments
         */
        @Override
        public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> expectedInputs = ImmutableSet.builder();

            List<VariableReferenceExpression> newAggregationArguments = node.getAggregationArguments().stream()
                    .filter(context.get()::contains)
                    .collect(Collectors.toList());
            expectedInputs.addAll(newAggregationArguments);
            if (newAggregationArguments.size() != node.getAggregationArguments().size()) {
                planChanged = true;
            }

            ImmutableList.Builder<List<VariableReferenceExpression>> newGroupingSets = ImmutableList.builder();
            Map<VariableReferenceExpression, VariableReferenceExpression> newGroupingMapping = new HashMap<>();

            for (List<VariableReferenceExpression> groupingSet : node.getGroupingSets()) {
                ImmutableList.Builder<VariableReferenceExpression> newGroupingSetBuilder = ImmutableList.builder();

                for (VariableReferenceExpression output : groupingSet) {
                    if (context.get().contains(output)) {
                        VariableReferenceExpression input = node.getGroupingColumns().get(output);
                        newGroupingSetBuilder.add(output);
                        newGroupingMapping.putIfAbsent(output, input);
                        expectedInputs.add(input);
                    }
                }

                // Build once and reuse the result for both the new layout and the change check.
                List<VariableReferenceExpression> newGroupingSet = newGroupingSetBuilder.build();
                newGroupingSets.add(newGroupingSet);

                // Only ever raise the flag: a plain assignment would let the last grouping set
                // decide alone and would reset changes reported for earlier sets or other nodes.
                if (newGroupingSet.size() != groupingSet.size()) {
                    planChanged = true;
                }
            }

            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new GroupIdNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), source, newGroupingSets.build(), newGroupingMapping, newAggregationArguments, node.getGroupIdVariable());
        }

        /**
         * Removes a mark-distinct node whose marker variable is not required, returning its rewritten
         * source instead and flagging the plan as changed.
         * <p>
         * When the marker is required, the source must provide the distinct variables, every required
         * variable other than the marker itself, and the hash variable (when present).
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            VariableReferenceExpression marker = node.getMarkerVariable();
            Set<VariableReferenceExpression> requiredByParent = context.get();

            if (!requiredByParent.contains(marker)) {
                // Nobody reads the marker, so the node is skipped altogether.
                planChanged = true;
                return context.rewrite(node.getSource(), requiredByParent);
            }

            ImmutableSet.Builder<VariableReferenceExpression> requiredFromSource = ImmutableSet.builder();
            requiredFromSource.addAll(node.getDistinctVariables());
            for (VariableReferenceExpression variable : requiredByParent) {
                // The marker is produced by this node, so it is not requested from the source.
                if (!variable.equals(marker)) {
                    requiredFromSource.add(variable);
                }
            }
            node.getHashVariable().ifPresent(requiredFromSource::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());

            return new MarkDistinctNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getMarkerVariable(),
                    node.getDistinctVariables(),
                    node.getHashVariable());
        }

        /**
         * Prunes an unnest node: replicate variables that are not required are dropped, and the
         * ordinality variable is removed when it is present but not required. All unnest variables
         * are kept, and the source is asked for the surviving replicate variables plus the unnest
         * input variables.
         *
         * @param node the unnest node to prune
         * @param context carries the set of variables required by the parent
         * @return a new unnest node with the pruned replicate and ordinality variables
         */
        @Override
        public PlanNode visitUnnest(UnnestNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            List<VariableReferenceExpression> replicateVariables = node.getReplicateVariables().stream()
                    .filter(context.get()::contains)
                    .collect(toImmutableList());

            // Only ever raise the flag: a plain assignment would reset a change already
            // reported by a node visited earlier in the traversal.
            if (replicateVariables.size() != node.getReplicateVariables().size()) {
                planChanged = true;
            }

            Optional<VariableReferenceExpression> ordinalityVariable = node.getOrdinalityVariable();
            if (ordinalityVariable.isPresent() && !context.get().contains(ordinalityVariable.get())) {
                planChanged = true;
                ordinalityVariable = Optional.empty();
            }

            Map<VariableReferenceExpression, List<VariableReferenceExpression>> unnestVariables = node.getUnnestVariables();
            ImmutableSet.Builder<VariableReferenceExpression> expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(replicateVariables)
                    .addAll(unnestVariables.keySet());

            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new UnnestNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), source, replicateVariables, unnestVariables, ordinalityVariable);
        }

        /**
         * Keeps only the assignments whose output variable is requested by the parent and asks the
         * source for exactly the variables referenced by the surviving expressions.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> requiredOutputs = context.get();
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
            Assignments.Builder keptAssignments = Assignments.builder();

            for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : node.getAssignments().getMap().entrySet()) {
                VariableReferenceExpression output = assignment.getKey();
                if (!requiredOutputs.contains(output)) {
                    // this projection is dropped, which alters the plan
                    planChanged = true;
                    continue;
                }
                RowExpression expression = assignment.getValue();
                requiredInputs.addAll(VariablesExtractor.extractUnique(expression));
                keptAssignments.put(output, expression);
            }

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

            return new ProjectNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    keptAssignments.build(),
                    node.getLocality());
        }

        /**
         * Rewrites the source requesting every variable the output node emits; the node's column
         * names and output variables are carried over unchanged.
         */
        @Override
        public PlanNode visitOutput(OutputNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getOutputVariables()));
            return new OutputNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getColumnNames(),
                    node.getOutputVariables());
        }

        /**
         * A limit references no variables of its own, so the parent's requirements are forwarded
         * to the source as-is.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> requiredInputs = ImmutableSet.copyOf(context.get());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs);
            return new LimitNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getCount(),
                    node.getStep());
        }

        /**
         * Requests the distinct variables, plus the hash variable when one is present, from the
         * source. The parent's requirements are not consulted here.
         */
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(node.getDistinctVariables());
            node.getHashVariable().ifPresent(requiredInputs::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());
            return new DistinctLimitNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getLimit(),
                    node.isPartial(),
                    node.getDistinctVariables(),
                    node.getHashVariable(),
                    node.getTimeoutMillis());
        }

        /**
         * Requests from the source whatever the parent needs together with the variables the
         * ordering scheme sorts by.
         */
        @Override
        public PlanNode visitTopN(TopNNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> requiredInputs = ImmutableSet.copyOf(
                    concat(context.get(), node.getOrderingScheme().getOrderByVariables()));

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs);

            return new TopNNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getCount(),
                    node.getOrderingScheme(),
                    node.getStep());
        }

        /**
         * Requests from the source the parent's variables, the partition-by variables and, when
         * present, the hash variable.
         */
        @Override
        public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // a single builder accumulates everything the source has to provide
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
            requiredInputs.addAll(context.get());
            requiredInputs.addAll(node.getPartitionBy());
            node.getHashVariable().ifPresent(requiredInputs::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

            return new RowNumberNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getPartitionBy(),
                    node.getRowNumberVariable(),
                    node.getMaxRowCountPerPartition(),
                    node.isPartial(),
                    node.getHashVariable());
        }

        /**
         * Requests from the source the parent's variables, the partition-by variables, the
         * order-by variables and, when present, the hash variable.
         */
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
            requiredInputs.addAll(context.get());
            requiredInputs.addAll(node.getPartitionBy());
            requiredInputs.addAll(node.getOrderingScheme().getOrderByVariables());
            node.getHashVariable().ifPresent(requiredInputs::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

            return new TopNRowNumberNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getSpecification(),
                    node.getRowNumberVariable(),
                    node.getMaxRowCountPerPartition(),
                    node.isPartial(),
                    node.getHashVariable());
        }

        /**
         * Requests from the source whatever the parent needs together with the variables the
         * ordering scheme sorts by.
         */
        @Override
        public PlanNode visitSort(SortNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(context.get())
                    .addAll(node.getOrderingScheme().getOrderByVariables());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

            return new SortNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getOrderingScheme(),
                    node.isPartial(),
                    node.getPartitionBy());
        }

        /**
         * Requests from the source the written columns, the variables of the table partitioning
         * scheme (including its hash column) when one is set, and the grouping and aggregation
         * inputs of the statistics aggregation when one is set.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
            requiredInputs.addAll(node.getColumns());

            node.getTablePartitioningScheme().ifPresent(scheme -> {
                scheme.getPartitioning().getVariableReferences().forEach(requiredInputs::add);
                scheme.getHashColumn().ifPresent(requiredInputs::add);
            });

            node.getStatisticsAggregation().ifPresent(statistics -> {
                requiredInputs.addAll(statistics.getGroupingVariables());
                for (Aggregation aggregation : statistics.getAggregations().values()) {
                    requiredInputs.addAll(extractAggregationUniqueVariables(aggregation));
                }
            });

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());
            return new TableWriterNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getTarget(),
                    node.getRowCountVariable(),
                    node.getFragmentVariable(),
                    node.getTableCommitContextVariable(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getNotNullColumnVariables(),
                    node.getTablePartitioningScheme(),
                    node.getStatisticsAggregation(),
                    node.getTaskCountIfScaledWriter(),
                    node.getIsTemporaryTableWriter());
        }

        /**
         * Rewrites the source while requesting all of its current output variables, so nothing is
         * pruned at this node itself.
         */
        @Override
        public PlanNode visitTableWriteMerge(TableWriterMergeNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> allSourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputVariables());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), allSourceOutputs);
            return new TableWriterMergeNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getRowCountVariable(),
                    node.getFragmentVariable(),
                    node.getTableCommitContextVariable(),
                    node.getStatisticsAggregation());
        }

        /**
         * Rewrites the source while requesting all of its current output variables, so nothing is
         * pruned at this node itself.
         */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> allSourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputVariables());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), allSourceOutputs);
            return new StatisticsWriterNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getTableHandle(),
                    node.getRowCountVariable(),
                    node.isRowCountEnabled(),
                    node.getDescriptor());
        }

        /**
         * Rewrites the source while requesting all of its current output variables, so nothing is
         * pruned at this node itself.
         */
        @Override
        public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            Set<VariableReferenceExpression> allSourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputVariables());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), allSourceOutputs);
            return new TableFinishNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getTarget(),
                    node.getRowCountVariable(),
                    node.getStatisticsAggregation(),
                    node.getStatisticsAggregationDescriptor(),
                    node.getCteMaterializationInfo());
        }

        /**
         * Requests from the source the row id (when present) and the input variables of the input
         * distribution (when present).
         */
        @Override
        public PlanNode visitDelete(DeleteNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
            node.getRowId().ifPresent(requiredInputs::add);
            node.getInputDistribution().ifPresent(distribution -> requiredInputs.addAll(distribution.getInputVariables()));

            PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());
            return new DeleteNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSource,
                    node.getRowId(),
                    node.getOutputVariables(),
                    node.getInputDistribution());
        }

        /**
         * Rebuilds the update node from its existing fields. The source is reused as-is and is not
         * passed through {@code context.rewrite}.
         */
        @Override
        public PlanNode visitUpdate(UpdateNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // NOTE(review): the source subtree is not rewritten here, so no pruning happens below
            // an update node - confirm this is intentional.
            PlanNode untouchedSource = node.getSource();
            return new UpdateNode(
                    node.getSourceLocation(),
                    node.getId(),
                    untouchedSource,
                    node.getRowId(),
                    node.getColumnValueAndRowIdSymbols(),
                    node.getOutputVariables());
        }

        /**
         * Prunes union outputs that the parent does not reference and rewrites every source to
         * produce only the inputs feeding the remaining outputs.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // 'true': unreferenced outputs are dropped from the mapping
            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> keptMapping = rewriteSetOperationVariableMapping(node, context, true);
            ImmutableList<PlanNode> prunedSources = rewriteSetOperationSubPlans(node, context, keptMapping);
            return new UnionNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    prunedSources,
                    ImmutableList.copyOf(keptMapping.keySet()),
                    fromListMultimap(keptMapping));
        }

        /**
         * Rewrites every source of the intersect while keeping all of its output variables,
         * whether or not the parent references them.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // 'false': every output is kept regardless of what the parent references
            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> keptMapping = rewriteSetOperationVariableMapping(node, context, false);
            ImmutableList<PlanNode> rewrittenSources = rewriteSetOperationSubPlans(node, context, keptMapping);
            return new IntersectNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSources,
                    ImmutableList.copyOf(keptMapping.keySet()),
                    fromListMultimap(keptMapping));
        }

        /**
         * Rewrites every source of the except while keeping all of its output variables, whether
         * or not the parent references them.
         */
        @Override
        public PlanNode visitExcept(ExceptNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // 'false': every output is kept regardless of what the parent references
            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> keptMapping = rewriteSetOperationVariableMapping(node, context, false);
            ImmutableList<PlanNode> rewrittenSources = rewriteSetOperationSubPlans(node, context, keptMapping);
            return new ExceptNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    rewrittenSources,
                    ImmutableList.copyOf(keptMapping.keySet()),
                    fromListMultimap(keptMapping));
        }

        /**
         * Builds the output-to-input variable mapping to retain for a set operation.
         *
         * @param node the set operation whose mapping is filtered
         * @param context supplies the set of variables the parent references
         * @param pruneUnreferencedOutput when {@code true}, outputs absent from the context are
         *        left out; when {@code false}, every output is retained
         * @return the retained mapping, in the order of the node's output variables
         */
        private ListMultimap<VariableReferenceExpression, VariableReferenceExpression> rewriteSetOperationVariableMapping(SetOperationNode node, RewriteContext<Set<VariableReferenceExpression>> context, boolean pruneUnreferencedOutput)
        {
            ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> keptMapping = ImmutableListMultimap.builder();
            for (VariableReferenceExpression output : node.getOutputVariables()) {
                // skip an output only when pruning is enabled and nobody asks for it
                if (pruneUnreferencedOutput && !context.get().contains(output)) {
                    continue;
                }
                keptMapping.putAll(output, node.getVariableMapping().get(output));
            }
            return keptMapping.build();
        }

        /**
         * Rewrites each source of a set operation, requesting from the i-th source the i-th input
         * variable of every retained output.
         *
         * @param node the set operation whose sources are rewritten
         * @param context the rewrite context used to recurse into each source
         * @param rewrittenVariableMapping the retained output-to-input mapping
         * @return the rewritten sources, in their original order
         */
        private ImmutableList<PlanNode> rewriteSetOperationSubPlans(SetOperationNode node, RewriteContext<Set<VariableReferenceExpression>> context, ListMultimap<VariableReferenceExpression, VariableReferenceExpression> rewrittenVariableMapping)
        {
            Collection<Collection<VariableReferenceExpression>> inputsPerOutput = rewrittenVariableMapping.asMap().values();
            List<PlanNode> originalSources = node.getSources();

            ImmutableList.Builder<PlanNode> prunedSources = ImmutableList.builder();
            for (int sourceIndex = 0; sourceIndex < originalSources.size(); sourceIndex++) {
                ImmutableSet.Builder<VariableReferenceExpression> requiredInputs = ImmutableSet.builder();
                for (Collection<VariableReferenceExpression> inputs : inputsPerOutput) {
                    // position within each list corresponds to the source index
                    requiredInputs.add(Iterables.get(inputs, sourceIndex));
                }
                prunedSources.add(context.rewrite(originalSources.get(sourceIndex), requiredInputs.build()));
            }
            return prunedSources.build();
        }

        /**
         * Removes the output columns the parent does not reference, dropping the matching
         * expression from every row so that rows and output variables stay aligned.
         */
        @Override
        public PlanNode visitValues(ValuesNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            ImmutableList.Builder<VariableReferenceExpression> rewrittenOutputVariablesBuilder = ImmutableList.builder();
            ImmutableList.Builder<ImmutableList.Builder<RowExpression>> rowBuildersBuilder = ImmutableList.builder();
            // Initialize builder for each row
            for (int i = 0; i < node.getRows().size(); i++) {
                rowBuildersBuilder.add(ImmutableList.builder());
            }
            ImmutableList<ImmutableList.Builder<RowExpression>> rowBuilders = rowBuildersBuilder.build();
            for (int i = 0; i < node.getOutputVariables().size(); i++) {
                VariableReferenceExpression outputVariable = node.getOutputVariables().get(i);
                // If output symbol is used
                if (context.get().contains(outputVariable)) {
                    rewrittenOutputVariablesBuilder.add(outputVariable);
                    // Add the value of the output symbol for each row
                    for (int j = 0; j < node.getRows().size(); j++) {
                        rowBuilders.get(j).add(node.getRows().get(j).get(i));
                    }
                }
            }
            List<List<RowExpression>> rewrittenRows = rowBuilders.stream()
                    .map(ImmutableList.Builder::build)
                    .collect(toImmutableList());
            List<VariableReferenceExpression> rewrittenOutputVariables = rewrittenOutputVariablesBuilder.build();
            // Only ever raise the flag: assigning the comparison result directly would erase a
            // change that was already recorded by a node visited earlier in this rewrite.
            if (rewrittenOutputVariables.size() != node.getOutputVariables().size()) {
                planChanged = true;
            }
            return new ValuesNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), rewrittenOutputVariables, rewrittenRows, node.getValuesNodeLabel());
        }

        /**
         * Removes the apply node altogether when none of its subquery assignments is referenced;
         * otherwise keeps only the referenced assignments, rewrites the subquery for the variables
         * those assignments use, drops correlation variables the rewritten subquery no longer
         * mentions, and rewrites the input.
         */
        @Override
        public PlanNode visitApply(ApplyNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            // remove unused apply nodes
            if (intersection(node.getSubqueryAssignments().getVariables(), context.get()).isEmpty()) {
                planChanged = true;
                return context.rewrite(node.getInput(), context.get());
            }

            // extract symbols required subquery plan
            ImmutableSet.Builder<VariableReferenceExpression> subqueryAssignmentsVariablesBuilder = ImmutableSet.builder();
            Assignments.Builder subqueryAssignments = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getSubqueryAssignments().getMap().entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                RowExpression expression = entry.getValue();
                if (context.get().contains(output)) {
                    subqueryAssignmentsVariablesBuilder.addAll(VariablesExtractor.extractUnique(expression));
                    subqueryAssignments.put(output, expression);
                }
                else {
                    // an unreferenced assignment is dropped, which alters the plan
                    planChanged = true;
                }
            }

            Set<VariableReferenceExpression> subqueryAssignmentsVariables = subqueryAssignmentsVariablesBuilder.build();
            PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsVariables);

            // prune not used correlation symbols
            Set<VariableReferenceExpression> subquerySymbols = VariablesExtractor.extractUnique(subquery);
            List<VariableReferenceExpression> newCorrelation = node.getCorrelation().stream()
                    .filter(subquerySymbols::contains)
                    .collect(toImmutableList());
            // Only ever raise the flag: assigning the comparison result directly would erase any
            // change recorded while the subquery was being rewritten above.
            if (newCorrelation.size() != node.getCorrelation().size()) {
                planChanged = true;
            }

            Set<VariableReferenceExpression> inputContext = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(context.get())
                    .addAll(newCorrelation)
                    .addAll(subqueryAssignmentsVariables) // need to include those: e.g: "expr" from "expr IN (SELECT 1)"
                    .build();
            PlanNode input = context.rewrite(node.getInput(), inputContext);
            Assignments assignments = subqueryAssignments.build();
            verifySubquerySupported(assignments);
            return new ApplyNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), input, subquery, assignments, newCorrelation, node.getOriginSubqueryError(), node.getMayParticipateInAntiJoin());
        }

        /**
         * Removes the node when the parent does not reference the generated id variable;
         * otherwise falls back to the default rewrite with the parent's requirements.
         */
        @Override
        public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            boolean idIsReferenced = context.get().contains(node.getIdVariable());
            if (idIsReferenced) {
                return context.defaultRewrite(node, context.get());
            }

            // the id is never read, so replace this node by its rewritten source
            planChanged = true;
            return context.rewrite(node.getSource(), context.get());
        }

        /**
         * Rewrites the subquery and the input of a lateral join, drops correlation variables the
         * rewritten subquery no longer mentions, and removes the join altogether when one side is
         * scalar and contributes none of the required variables.
         */
        @Override
        public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            PlanNode subquery = context.rewrite(node.getSubquery(), context.get());

            // remove unused lateral nodes
            if (intersection(ImmutableSet.copyOf(subquery.getOutputVariables()), context.get()).isEmpty() && isScalar(subquery)) {
                planChanged = true;
                return context.rewrite(node.getInput(), context.get());
            }

            // prune not used correlation symbols
            Set<VariableReferenceExpression> subqueryVariables = VariablesExtractor.extractUnique(subquery);
            List<VariableReferenceExpression> newCorrelation = node.getCorrelation().stream()
                    .filter(subqueryVariables::contains)
                    .collect(toImmutableList());
            // Only ever raise the flag: assigning the comparison result directly would erase any
            // change recorded while the subquery was being rewritten above.
            if (newCorrelation.size() != node.getCorrelation().size()) {
                planChanged = true;
            }

            Set<VariableReferenceExpression> inputContext = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(context.get())
                    .addAll(newCorrelation)
                    .build();
            PlanNode input = context.rewrite(node.getInput(), inputContext);

            // remove unused lateral nodes
            if (intersection(ImmutableSet.copyOf(input.getOutputVariables()), inputContext).isEmpty() && isScalar(input)) {
                planChanged = true;
                return subquery;
            }

            return new LateralJoinNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), input, subquery, newCorrelation, node.getType(), node.getOriginSubqueryError());
        }
    }
}