CanonicalPlanGenerator.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.common.plan.PlanCanonicalizationStrategy;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.AggregationNode.GroupingSetDescriptor;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.CteReferenceNode;
import com.facebook.presto.spi.plan.DataOrganizationSpecification;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.TreeMultimap;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.usePerfectlyConsistentHistories;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.DEFAULT;
import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.IGNORE_SAFE_CONSTANTS;
import static com.facebook.presto.common.plan.PlanCanonicalizationStrategy.IGNORE_SCAN_CONSTANTS;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.expressions.CanonicalRowExpressionRewriter.canonicalizeRowExpression;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.spi.StandardErrorCode.PLAN_SERIALIZATION_ERROR;
import static com.facebook.presto.sql.planner.CanonicalPartitioningScheme.getCanonicalPartitioningScheme;
import static com.facebook.presto.sql.planner.CanonicalTableScanNode.CanonicalTableHandle.getCanonicalTableHandle;
import static com.facebook.presto.sql.planner.RowExpressionVariableInliner.inlineVariables;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.graph.Traverser.forTree;
import static java.lang.String.format;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toCollection;

public class CanonicalPlanGenerator
        extends InternalPlanVisitor<Optional<PlanNode>, CanonicalPlanGenerator.Context>
{
    // The single id handed out for every plan node when the strategy is not DEFAULT (see createPlanNodeIdAllocator),
    // so node ids never distinguish otherwise identical canonical plans.
    private static final String CANONICAL_STRING = "CANONICAL";

    // Not using a new override to objectMapper because PlanNodeId has a JsonValue annotation which cannot be directly overridden in a serializer;
    // node ids are canonicalized through this allocator instead.
    private final PlanNodeIdAllocator planNodeidAllocator;
    // Allocates the fresh variables that replace original output variables in the canonical plan.
    private final VariableAllocator variableAllocator = new VariableAllocator();
    // TODO: DEFAULT strategy has a very different canonicalization implementation, refactor it into a separate class.
    private final PlanCanonicalizationStrategy strategy;
    private final ObjectMapper objectMapper;
    private final Session session;

    /**
     * Creates a generator for a single canonicalization pass.
     *
     * @param strategy how aggressively the plan is canonicalized; also selects the plan node id allocator
     * @param objectMapper mapper used when serializing plan fragments
     * @param session session the plan belongs to
     * @throws NullPointerException if any argument is null
     */
    public CanonicalPlanGenerator(PlanCanonicalizationStrategy strategy, ObjectMapper objectMapper, Session session)
    {
        // Arguments are validated in declaration order, so the first null one is the one reported.
        this.strategy = requireNonNull(strategy, "strategy is null");
        this.objectMapper = requireNonNull(objectMapper, "objectMapper is null");
        this.session = requireNonNull(session, "session is null");
        // The id allocator is chosen from the already-validated strategy.
        this.planNodeidAllocator = createPlanNodeIdAllocator(this.strategy);
    }

    // Picks the plan node id allocator for the given strategy: DEFAULT keeps ordinary, unique ids,
    // every other strategy hands out the same constant id for every node.
    // TODO: For HBO we always want planNodeId to be canonicalized, but fragment result caching currently uses
    // this class with the DEFAULT strategy. Refactor the DEFAULT strategy into a different class.
    private PlanNodeIdAllocator createPlanNodeIdAllocator(PlanCanonicalizationStrategy strategy)
    {
        if (!strategy.equals(DEFAULT)) {
            return new PlanNodeIdAllocator()
            {
                @Override
                public PlanNodeId getNextId()
                {
                    // Same id every time, so ids carry no information into the canonical plan.
                    return new PlanNodeId(CANONICAL_STRING);
                }
            };
        }
        return new PlanNodeIdAllocator();
    }

    /**
     * Canonicalizes a plan fragment with the DEFAULT strategy.
     *
     * @return the canonical fragment, or empty if the plan cannot be canonicalized or if some variable in the
     *         partitioning scheme's output layout has no canonical mapping
     */
    public static Optional<CanonicalPlanFragment> generateCanonicalPlanFragment(PlanNode root, PartitioningScheme partitioningScheme, ObjectMapper objectMapper, Session session)
    {
        Context context = new Context();
        CanonicalPlanGenerator generator = new CanonicalPlanGenerator(DEFAULT, objectMapper, session);
        Optional<PlanNode> canonicalRoot = root.accept(generator, context);

        // The partitioning scheme can only be canonicalized if every output it lays out was mapped.
        boolean outputLayoutMapped = context.getExpressions().keySet().containsAll(partitioningScheme.getOutputLayout());
        if (!outputLayoutMapped) {
            return Optional.empty();
        }
        return canonicalRoot.map(planNode -> new CanonicalPlanFragment(
                new CanonicalPlan(planNode, DEFAULT),
                getCanonicalPartitioningScheme(partitioningScheme, context.getExpressions())));
    }

    /**
     * Canonicalizes the plan rooted at {@code root} using {@code strategy}.
     *
     * @return the {@code CanonicalPlan}, or {@code Optional.empty()} if a {@code PlanNode} with unimplemented
     *         canonicalization is encountered
     */
    public static Optional<CanonicalPlan> generateCanonicalPlan(PlanNode root, PlanCanonicalizationStrategy strategy, ObjectMapper objectMapper, Session session)
    {
        CanonicalPlanGenerator generator = new CanonicalPlanGenerator(strategy, objectMapper, session);
        Context context = new Context();
        return root.accept(generator, context)
                .map(planNode -> new CanonicalPlan(planNode, strategy));
    }

    /**
     * Fallback for plan node types without a dedicated visit method. Their canonicalization is not
     * implemented, so the empty result makes every enclosing visitor give up as well.
     */
    @Override
    public Optional<PlanNode> visitPlan(PlanNode node, Context context)
    {
        // TODO: Support canonicalization for more plan node types
        Optional<PlanNode> unsupported = Optional.empty();
        return unsupported;
    }

    /**
     * Canonicalizes the limit part and then the plan part, and rebuilds the wrapper node using the
     * canonical plan's id. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitStatsEquivalentPlanNodeWithLimit(StatsEquivalentPlanNodeWithLimit node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        // Order matters: the limit is visited first, then the plan, both against the same context.
        PlanNode canonicalLimit = node.getLimit().accept(this, context).orElse(null);
        if (canonicalLimit == null) {
            return Optional.empty();
        }
        PlanNode canonicalPlan = node.getPlan().accept(this, context).orElse(null);
        if (canonicalPlan == null) {
            return Optional.empty();
        }

        PlanNode wrapped = new StatsEquivalentPlanNodeWithLimit(canonicalPlan.getId(), canonicalPlan, canonicalLimit);
        context.addPlan(node, new CanonicalPlan(wrapped, strategy));
        return Optional.of(wrapped);
    }

    /**
     * Canonicalizes a table writer: the source is canonicalized, the written columns are resolved through
     * the canonical expressions recorded in the context, and the target is replaced by its canonical form.
     * Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitTableWriter(TableWriterNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        Optional<PlanNode> source = node.getSource().accept(this, context);
        if (!source.isPresent()) {
            return Optional.empty();
        }

        // The written columns are inputs read from the source, so resolve them through the source's canonical
        // expressions, as every other visitor does for input variables. rename() is reserved for variables a
        // node newly produces; using it here discarded the link to the canonicalized source.
        List<VariableReferenceExpression> columns = node.getColumns().stream()
                .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable))
                .sorted()
                .collect(toImmutableList());
        // NOTE(review): columns and columnNames are sorted independently, so their positional pairing is not
        // preserved in the canonical node.
        List<String> columnNames = node.getColumnNames().stream().sorted().collect(toImmutableList());

        PlanNode result = new TableWriterNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                source.get(),
                node.getTarget().map(target -> CanonicalWriterTarget.from(target)),
                node.getRowCountVariable(),
                node.getFragmentVariable(),
                node.getTableCommitContextVariable(),
                columns,
                columnNames,
                ImmutableSet.of(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        context.addPlan(node, new CanonicalPlan(result, strategy));
        return Optional.of(result);
    }

    /**
     * Canonicalizes a table finish node over its canonicalized source, replacing the target with its
     * canonical form. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitTableFinish(TableFinishNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        // An uncanonicalizable source propagates as an empty result.
        return node.getSource().accept(this, context).map(canonicalSource -> {
            PlanNode canonicalFinish = new TableFinishNode(
                    Optional.empty(),
                    planNodeidAllocator.getNextId(),
                    canonicalSource,
                    node.getTarget().map(target -> CanonicalWriterTarget.from(target)),
                    node.getRowCountVariable(),
                    Optional.empty(),
                    Optional.empty(),
                    node.getCteMaterializationInfo());
            context.addPlan(node, new CanonicalPlan(canonicalFinish, strategy));
            return canonicalFinish;
        });
    }

    /**
     * Canonicalizes a limit over its canonicalized source, keeping count and step. The result is recorded
     * via addLimitingNodePlan rather than addPlan. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitLimit(LimitNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        return node.getSource().accept(this, context).map(canonicalSource -> {
            PlanNode canonicalLimit = new LimitNode(
                    Optional.empty(),
                    planNodeidAllocator.getNextId(),
                    canonicalSource,
                    node.getCount(),
                    node.getStep());
            context.addLimitingNodePlan(node, new CanonicalPlan(canonicalLimit, strategy));
            return canonicalLimit;
        });
    }

    /**
     * Canonicalizes a TopN over its canonicalized source, rewriting its ordering scheme against the
     * canonical expressions. Recorded via addLimitingNodePlan. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitTopN(TopNNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        return node.getSource().accept(this, context).map(canonicalSource -> {
            PlanNode canonicalTopN = new TopNNode(
                    Optional.empty(),
                    planNodeidAllocator.getNextId(),
                    canonicalSource,
                    node.getCount(),
                    getCanonicalOrderingScheme(node.getOrderingScheme(), context.getExpressions()),
                    node.getStep());
            context.addLimitingNodePlan(node, new CanonicalPlan(canonicalTopN, strategy));
            return canonicalTopN;
        });
    }

    /**
     * Canonicalizes a join (or a tree of mergeable joins of the same type) into a {@code CanonicalJoinNode}.
     * Sources are put in a canonical order when the join type allows it; criteria, filters and output
     * variables are canonicalized and sorted. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitJoin(JoinNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }
        // RIGHT joins are canonicalized through flipChildren(), which presumably yields the mirrored join.
        if (node.getType().equals(JoinType.RIGHT)) {
            return visitJoin(node.flipChildren(), context);
        }
        List<PlanNode> sources = new ArrayList<>();
        ImmutableList.Builder<RowExpression> allFilters = ImmutableList.builder();
        ImmutableList.Builder<EquiJoinClause> criterias = ImmutableList.builder();
        Stack<JoinNode> stack = new Stack<>();

        // Flatten nested joins of the same type (when shouldMergeJoinNodes allows it) into one list of
        // non-join sources, collecting every flattened join's criteria and filter conjuncts.
        stack.push(node);
        while (!stack.empty()) {
            JoinNode top = stack.pop();
            top.getCriteria().forEach(criterias::add);
            // ReorderJoins can move predicates between `criteria` and `filters`, so we put all equalities
            // in `criteria` to make it consistent. An equality moved into `criteria` must not also stay in
            // `filters`, otherwise the same predicate canonicalizes differently depending on where it was.
            if (top.getFilter().isPresent()) {
                for (RowExpression filter : extractConjuncts(top.getFilter().get())) {
                    Optional<EquiJoinClause> criteria = toEquiJoinClause(filter);
                    if (criteria.isPresent()) {
                        criterias.add(criteria.get());
                    }
                    else {
                        allFilters.add(filter);
                    }
                }
            }
            for (PlanNode source : top.getSources()) {
                if (source instanceof JoinNode
                        && ((JoinNode) source).getType().equals(node.getType())
                        && shouldMergeJoinNodes(node.getType())) {
                    stack.push((JoinNode) source);
                }
                else {
                    sources.add(source);
                }
            }
        }

        // Sort sources if all are INNER, or full outer join of 2 nodes
        if (shouldMergeJoinNodes(node.getType()) || (node.getType().equals(JoinType.FULL) && sources.size() == 2)) {
            Optional<List<Integer>> sourceIndexes = orderSources(sources);
            if (!sourceIndexes.isPresent()) {
                return Optional.empty();
            }
            sources = sourceIndexes.get().stream().map(sources::get).collect(toImmutableList());
        }

        ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
        for (PlanNode source : sources) {
            Optional<PlanNode> newSource = source.accept(this, context);
            if (!newSource.isPresent()) {
                return Optional.empty();
            }
            newSources.add(newSource.get());
        }
        // Sorted, de-duplicated sets make the result independent of the order predicates were found in.
        Set<EquiJoinClause> newCriterias = criterias.build().stream()
                .map(criteria -> canonicalize(criteria, context))
                .sorted(comparing(EquiJoinClause::toString))
                .collect(toCollection(LinkedHashSet::new));
        Set<RowExpression> newFilters = allFilters.build().stream()
                .map(filter -> inlineAndCanonicalize(context.getExpressions(), filter))
                .sorted(comparing(this::writeValueAsString))
                .collect(toCollection(LinkedHashSet::new));
        List<VariableReferenceExpression> outputVariables = node.getOutputVariables().stream()
                .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable))
                .sorted()
                .collect(toImmutableList());

        PlanNode result = new CanonicalJoinNode(
                planNodeidAllocator.getNextId(),
                newSources.build(),
                node.getType(),
                newCriterias,
                newFilters,
                outputVariables);
        context.addPlan(node, new CanonicalPlan(result, strategy));
        return Optional.of(result);
    }

    /**
     * Canonicalizes a semi join: both sources are canonicalized (source first, then filtering source),
     * the join keys are resolved through the canonical expressions, and the semi join output is renamed
     * to a fresh variable. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitSemiJoin(SemiJoinNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalProbe = node.getSource().accept(this, context).orElse(null);
        if (canonicalProbe == null) {
            return Optional.empty();
        }
        PlanNode canonicalBuild = node.getFilteringSource().accept(this, context).orElse(null);
        if (canonicalBuild == null) {
            return Optional.empty();
        }

        // Join keys are inputs; the output marker is newly produced by this node.
        VariableReferenceExpression probeKey = inlineAndCanonicalize(context.getExpressions(), node.getSourceJoinVariable());
        VariableReferenceExpression buildKey = inlineAndCanonicalize(context.getExpressions(), node.getFilteringSourceJoinVariable());
        VariableReferenceExpression matchOutput = rename(node.getSemiJoinOutput(), "semijoinoutput", context);

        PlanNode canonicalSemiJoin = new SemiJoinNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalProbe,
                canonicalBuild,
                probeKey,
                buildKey,
                matchOutput,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
        context.addPlan(node, new CanonicalPlan(canonicalSemiJoin, strategy));
        return Optional.of(canonicalSemiJoin);
    }

    /**
     * Canonicalizes a union: branches are put in the order chosen by orderSources, each union output gets a
     * fresh variable (recorded in the context), and its inputs are reordered to match the branch order.
     * Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitUnion(UnionNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        Optional<List<Integer>> branchOrder = orderSources(node.getSources());
        if (!branchOrder.isPresent()) {
            return Optional.empty();
        }
        List<Integer> order = branchOrder.get();

        ImmutableList.Builder<PlanNode> canonicalBranches = ImmutableList.builder();
        for (Integer branchIndex : order) {
            Optional<PlanNode> canonicalBranch = node.getSources().get(branchIndex).accept(this, context);
            if (!canonicalBranch.isPresent()) {
                return Optional.empty();
            }
            canonicalBranches.add(canonicalBranch.get());
        }

        ImmutableList.Builder<VariableReferenceExpression> canonicalOutputs = ImmutableList.builder();
        ImmutableMap.Builder<VariableReferenceExpression, List<VariableReferenceExpression>> canonicalMapping = ImmutableMap.builder();
        node.getVariableMapping().forEach((originalOutput, originalInputs) -> {
            // Inputs follow the canonical branch order; the fresh output is derived from the first input.
            List<VariableReferenceExpression> canonicalInputs = order.stream()
                    .map(branchIndex -> inlineAndCanonicalize(context.getExpressions(), originalInputs.get(branchIndex)))
                    .collect(toImmutableList());
            VariableReferenceExpression canonicalOutput = variableAllocator.newVariable(canonicalInputs.get(0));
            canonicalOutputs.add(canonicalOutput);
            context.mapExpression(originalOutput, canonicalOutput);
            canonicalMapping.put(canonicalOutput, canonicalInputs);
        });

        PlanNode canonicalUnion = new UnionNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalBranches.build(),
                canonicalOutputs.build().stream().sorted().collect(toImmutableList()),
                ImmutableSortedMap.copyOf(canonicalMapping.build()));
        context.addPlan(node, new CanonicalPlan(canonicalUnion, strategy));
        return Optional.of(canonicalUnion);
    }

    /**
     * Canonicalizes a window node: pre-partitioned inputs and partition keys are canonicalized and sorted,
     * every window function is rewritten against the canonical expressions, and the functions' output
     * variables are renamed in the order of the functions' serialized form. Not supported under DEFAULT.
     */
    @Override
    public Optional<PlanNode> visitWindow(WindowNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalSource = node.getSource().accept(this, context).orElse(null);
        if (canonicalSource == null) {
            return Optional.empty();
        }

        Set<VariableReferenceExpression> canonicalPrePartitioned = node.getPrePartitionedInputs().stream()
                .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableSet());

        List<VariableReferenceExpression> canonicalPartitionBy = node.getSpecification().getPartitionBy().stream()
                .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());
        DataOrganizationSpecification canonicalSpecification = new DataOrganizationSpecification(
                canonicalPartitionBy,
                node.getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions())));

        // Rewrite every function first, then order them by serialized form so the fresh output variables
        // below are allocated in an order independent of the original output variable names.
        List<Entry<VariableReferenceExpression, WindowNode.Function>> orderedFunctions = node.getWindowFunctions()
                .entrySet().stream()
                .map(entry -> Maps.immutableEntry(entry.getKey(), toCanonicalWindowFunction(entry.getValue(), context)))
                .sorted(comparing(entry -> writeValueAsString(entry.getValue())))
                .collect(toImmutableList());

        ImmutableMap.Builder<VariableReferenceExpression, WindowNode.Function> canonicalFunctions = ImmutableMap.builder();
        for (Entry<VariableReferenceExpression, WindowNode.Function> entry : orderedFunctions) {
            WindowNode.Function function = entry.getValue();
            VariableReferenceExpression output = rename(entry.getKey(), function.getFunctionCall().getDisplayName(), context);
            canonicalFunctions.put(output, function);
        }

        PlanNode canonicalWindow = new WindowNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource,
                canonicalSpecification,
                canonicalFunctions.build(),
                Optional.empty(),
                canonicalPrePartitioned,
                node.getPreSortedOrderPrefix());
        context.addPlan(node, new CanonicalPlan(canonicalWindow, strategy));
        return Optional.of(canonicalWindow);
    }

    // Rewrites a window function's call arguments and frame bound variables against the canonical expressions
    // in the context; the frame's original start/end value strings are replaced by empty strings.
    private WindowNode.Function toCanonicalWindowFunction(WindowNode.Function function, Context context)
    {
        CallExpression call = function.getFunctionCall();
        CallExpression canonicalCall = new CallExpression(
                Optional.empty(),
                call.getDisplayName(),
                call.getFunctionHandle(),
                call.getType(),
                call.getArguments().stream()
                        .map(argument -> inlineAndCanonicalize(context.getExpressions(), argument))
                        .collect(toImmutableList()));

        WindowNode.Frame originalFrame = function.getFrame();
        Optional<VariableReferenceExpression> startValue = originalFrame.getStartValue()
                .map(value -> inlineAndCanonicalize(context.getExpressions(), value));
        Optional<VariableReferenceExpression> endValue = originalFrame.getEndValue()
                .map(value -> inlineAndCanonicalize(context.getExpressions(), value));
        Optional<VariableReferenceExpression> startComparisonKey = originalFrame.getSortKeyCoercedForFrameStartComparison()
                .map(value -> inlineAndCanonicalize(context.getExpressions(), value));
        Optional<VariableReferenceExpression> endComparisonKey = originalFrame.getSortKeyCoercedForFrameEndComparison()
                .map(value -> inlineAndCanonicalize(context.getExpressions(), value));

        WindowNode.Frame canonicalFrame = new WindowNode.Frame(
                originalFrame.getType(),
                originalFrame.getStartType(),
                startValue,
                startComparisonKey,
                originalFrame.getEndType(),
                endValue,
                endComparisonKey,
                startValue.map(ignored -> ""),
                endValue.map(ignored -> ""));
        return new WindowNode.Function(canonicalCall, canonicalFrame, function.isIgnoreNulls());
    }

    /**
     * Canonicalizes a values node: every cell is rewritten against the canonical expressions and every
     * output variable is renamed (in original order) to a fresh variable. Not supported under DEFAULT.
     */
    @Override
    public Optional<PlanNode> visitValues(ValuesNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        ImmutableList.Builder<List<RowExpression>> canonicalRows = ImmutableList.builder();
        for (List<RowExpression> row : node.getRows()) {
            ImmutableList.Builder<RowExpression> canonicalRow = ImmutableList.builder();
            for (RowExpression cell : row) {
                canonicalRow.add(inlineAndCanonicalize(context.getExpressions(), cell));
            }
            canonicalRows.add(canonicalRow.build());
        }

        ImmutableList.Builder<VariableReferenceExpression> canonicalOutputs = ImmutableList.builder();
        for (VariableReferenceExpression output : node.getOutputVariables()) {
            canonicalOutputs.add(rename(output, "", context));
        }

        PlanNode canonicalValues = new ValuesNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalOutputs.build(),
                canonicalRows.build(),
                Optional.empty());
        context.addPlan(node, new CanonicalPlan(canonicalValues, strategy));
        return Optional.of(canonicalValues);
    }

    /**
     * Canonicalizes a mark-distinct node: distinct keys are resolved and sorted, and the marker output is
     * renamed to a fresh "is_distinct" variable. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitMarkDistinct(MarkDistinctNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalSource = node.getSource().accept(this, context).orElse(null);
        if (canonicalSource == null) {
            return Optional.empty();
        }

        List<VariableReferenceExpression> canonicalKeys = node.getDistinctVariables().stream()
                .map(key -> inlineAndCanonicalize(context.getExpressions(), key))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());
        VariableReferenceExpression canonicalMarker = rename(node.getMarkerVariable(), "is_distinct", context);

        PlanNode canonicalMarkDistinct = new MarkDistinctNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource,
                canonicalMarker,
                canonicalKeys,
                Optional.empty());
        context.addPlan(node, new CanonicalPlan(canonicalMarkDistinct, strategy));
        return Optional.of(canonicalMarkDistinct);
    }

    /**
     * Canonicalizes an AssignUniqueId node, renaming its id output to a fresh "unique" variable.
     * Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitAssignUniqueId(AssignUniqueId node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        return node.getSource().accept(this, context).map(canonicalSource -> {
            VariableReferenceExpression uniqueId = rename(node.getIdVariable(), "unique", context);
            PlanNode canonicalAssign = new AssignUniqueId(
                    Optional.empty(),
                    planNodeidAllocator.getNextId(),
                    canonicalSource,
                    uniqueId);
            context.addPlan(node, new CanonicalPlan(canonicalAssign, strategy));
            return canonicalAssign;
        });
    }

    /**
     * Canonicalizes an EnforceSingleRow node over its canonicalized source.
     * Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitEnforceSingleRow(EnforceSingleRowNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        return node.getSource().accept(this, context).map(canonicalSource -> {
            PlanNode canonicalEnforce = new EnforceSingleRowNode(
                    Optional.empty(),
                    planNodeidAllocator.getNextId(),
                    canonicalSource);
            context.addPlan(node, new CanonicalPlan(canonicalEnforce, strategy));
            return canonicalEnforce;
        });
    }

    /**
     * Canonicalizes a row-number node: partition keys are resolved and sorted, and the row number output is
     * renamed to a fresh "row_number" variable. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitRowNumber(RowNumberNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalSource = node.getSource().accept(this, context).orElse(null);
        if (canonicalSource == null) {
            return Optional.empty();
        }

        List<VariableReferenceExpression> canonicalPartitionKeys = node.getPartitionBy().stream()
                .map(key -> inlineAndCanonicalize(context.getExpressions(), key))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());
        VariableReferenceExpression canonicalRowNumber = rename(node.getRowNumberVariable(), "row_number", context);

        PlanNode canonicalRowNumberNode = new RowNumberNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource,
                canonicalPartitionKeys,
                canonicalRowNumber,
                node.getMaxRowCountPerPartition(),
                node.isPartial(),
                Optional.empty());
        context.addPlan(node, new CanonicalPlan(canonicalRowNumberNode, strategy));
        return Optional.of(canonicalRowNumberNode);
    }

    /**
     * Canonicalizes a TopN-row-number node: partition keys are resolved and sorted, the ordering scheme is
     * canonicalized, and the row number output is renamed. Recorded via addLimitingNodePlan.
     * Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitTopNRowNumber(TopNRowNumberNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalSource = node.getSource().accept(this, context).orElse(null);
        if (canonicalSource == null) {
            return Optional.empty();
        }

        List<VariableReferenceExpression> canonicalPartitionKeys = node.getPartitionBy().stream()
                .map(key -> inlineAndCanonicalize(context.getExpressions(), key))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());
        VariableReferenceExpression canonicalRowNumber = rename(node.getRowNumberVariable(), "row_number", context);

        PlanNode canonicalTopNRowNumber = new TopNRowNumberNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource,
                new DataOrganizationSpecification(
                        canonicalPartitionKeys,
                        node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))),
                canonicalRowNumber,
                node.getMaxRowCountPerPartition(),
                node.isPartial(),
                Optional.empty());
        context.addLimitingNodePlan(node, new CanonicalPlan(canonicalTopNRowNumber, strategy));
        return Optional.of(canonicalTopNRowNumber);
    }

    /**
     * Canonicalizes a distinct-limit node: distinct keys are resolved and sorted; limit and partial flag are
     * kept. Recorded via addLimitingNodePlan. Not supported under the DEFAULT strategy.
     */
    @Override
    public Optional<PlanNode> visitDistinctLimit(DistinctLimitNode node, Context context)
    {
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        PlanNode canonicalSource = node.getSource().accept(this, context).orElse(null);
        if (canonicalSource == null) {
            return Optional.empty();
        }

        List<VariableReferenceExpression> canonicalKeys = node.getDistinctVariables().stream()
                .map(key -> inlineAndCanonicalize(context.getExpressions(), key))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());

        PlanNode canonicalDistinctLimit = new DistinctLimitNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource,
                node.getLimit(),
                node.isPartial(),
                canonicalKeys,
                Optional.empty(),
                0);
        context.addLimitingNodePlan(node, new CanonicalPlan(canonicalDistinctLimit, strategy));
        return Optional.of(canonicalDistinctLimit);
    }

    @Override
    public Optional<PlanNode> visitSort(SortNode node, Context context)
    {
        // The DEFAULT strategy does not canonicalize sort nodes.
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        Optional<PlanNode> source = node.getSource().accept(this, context);
        if (!source.isPresent()) {
            return Optional.empty();
        }

        // Partition keys must be rewritten to canonical variables, just like the ordering keys.
        // Passing node.getPartitionBy() through unchanged would leak the original variable names
        // into the canonical plan, so its serialized form (and hash) would depend on them.
        // Keys are sorted by their serialized canonical form, consistent with
        // visitTopNRowNumber / visitRowNumber.
        List<VariableReferenceExpression> partitionBy = node.getPartitionBy().stream()
                .map(variable -> inlineAndCanonicalize(context.getExpressions(), variable))
                .sorted(comparing(this::writeValueAsString))
                .collect(toImmutableList());

        PlanNode canonicalPlan = new SortNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                source.get(),
                getCanonicalOrderingScheme(node.getOrderingScheme(), context.getExpressions()),
                node.isPartial(),
                partitionBy);
        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitOutput(OutputNode node, Context context)
    {
        // The DEFAULT strategy does not canonicalize output nodes.
        if (strategy == DEFAULT) {
            return Optional.empty();
        }

        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        // Constants are only removed from output expressions under IGNORE_SAFE_CONSTANTS.
        boolean removeConstants = strategy == IGNORE_SAFE_CONSTANTS;
        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();
        List<RowExpressionReference> sortedOutputs = node.getOutputVariables().stream()
                .map(variable -> new RowExpressionReference(inlineAndCanonicalize(mappings, variable, removeConstants), variable))
                .sorted(comparing(output -> writeValueAsString(output.getRowExpression())))
                .collect(toImmutableList());

        // Allocate a fresh variable per output, in sorted order, and remember original -> fresh.
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> projections = ImmutableMap.builder();
        for (RowExpressionReference output : sortedOutputs) {
            RowExpression expression = output.getRowExpression();
            VariableReferenceExpression freshVariable = variableAllocator.newVariable(expression);
            context.mapExpression(output.getVariableReferenceExpression(), freshVariable);
            projections.put(freshVariable, expression);
        }

        // The OutputNode is represented in the canonical plan as a LOCAL ProjectNode.
        PlanNode canonicalPlan = new ProjectNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                new Assignments(projections.build()),
                ProjectNode.Locality.LOCAL);
        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));

        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitAggregation(AggregationNode node, Context context)
    {
        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();

        // Canonicalize each aggregation, then order the results by the serialized canonical call so
        // the canonical plan does not depend on the iteration order of the original aggregations.
        List<AggregationReference> sortedAggregations = node.getAggregations().entrySet().stream()
                .map(entry -> new AggregationReference(getCanonicalAggregation(entry.getValue(), mappings), entry.getKey()))
                .sorted(comparing(reference -> writeValueAsString(reference.getAggregation().getCall())))
                .collect(toImmutableList());

        // Give every canonical aggregation a fresh output variable (allocated in sorted order) and
        // record original output -> fresh output so downstream nodes are rewritten.
        ImmutableMap.Builder<VariableReferenceExpression, Aggregation> canonicalAggregations = ImmutableMap.builder();
        for (AggregationReference reference : sortedAggregations) {
            Aggregation aggregation = reference.getAggregation();
            VariableReferenceExpression output = variableAllocator.newVariable(aggregation.getCall());
            context.mapExpression(reference.getVariableReferenceExpression(), output);
            canonicalAggregations.put(output, aggregation);
        }

        GroupingSetDescriptor canonicalGroupingSets = getCanonicalGroupingSetDescriptor(node.getGroupingSets(), mappings);
        List<VariableReferenceExpression> canonicalPreGrouped = node.getPreGroupedVariables().stream()
                .map(mappings::get)
                .collect(toImmutableList());

        PlanNode canonicalPlan = new AggregationNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                canonicalAggregations.build(),
                canonicalGroupingSets,
                canonicalPreGrouped,
                node.getStep(),
                // A present hash variable is replaced by a freshly allocated one.
                node.getHashVariable().map(ignored -> variableAllocator.newHashVariable()),
                node.getGroupIdVariable().map(mappings::get),
                // ignore aggregationId when creating the canonical plan
                Optional.empty());

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitSequence(SequenceNode node, Context context)
    {
        // Visit every CTE producer only for its effect on the context; the results are discarded.
        for (PlanNode cteProducer : node.getCteProducers()) {
            cteProducer.accept(this, context);
        }
        // The sequence's canonical plan is the canonical plan of its primary source.
        PlanNode primarySource = node.getPrimarySource();
        return primarySource.accept(this, context);
    }

    @Override
    public Optional<PlanNode> visitCteProducer(CteProducerNode node, Context context)
    {
        // A CTE producer adds no node of its own; it is canonicalized as its source.
        PlanNode producedPlan = node.getSource();
        return producedPlan.accept(this, context);
    }

    @Override
    public Optional<PlanNode> visitCteConsumer(CteConsumerNode node, Context context)
    {
        // A CTE consumer is canonicalized through getOriginalSource(), so the consumer node
        // itself does not appear in the canonical plan.
        PlanNode originalSource = node.getOriginalSource();
        return originalSource.accept(this, context);
    }

    @Override
    public Optional<PlanNode> visitCteReference(CteReferenceNode node, Context context)
    {
        // A CTE reference adds no node of its own; it is canonicalized as its source.
        PlanNode referencedPlan = node.getSource();
        return referencedPlan.accept(this, context);
    }

    private Aggregation getCanonicalAggregation(Aggregation aggregation, Map<VariableReferenceExpression, VariableReferenceExpression> context)
    {
        // Every expression-bearing component (call, filter, ordering, mask) is inlined against the
        // variable mapping and canonicalized; the DISTINCT flag is carried over unchanged.
        CallExpression canonicalCall = inlineAndCanonicalize(context, aggregation.getCall());
        Optional<RowExpression> canonicalFilter = aggregation.getFilter()
                .map(filter -> inlineAndCanonicalize(context, filter));
        Optional<OrderingScheme> canonicalOrderBy = aggregation.getOrderBy()
                .map(orderBy -> getCanonicalOrderingScheme(orderBy, context));
        Optional<VariableReferenceExpression> canonicalMask = aggregation.getMask()
                .map(mask -> inlineAndCanonicalize(context, mask));

        return new Aggregation(canonicalCall, canonicalFilter, canonicalOrderBy, aggregation.isDistinct(), canonicalMask);
    }

    private static OrderingScheme getCanonicalOrderingScheme(OrderingScheme orderingScheme, Map<VariableReferenceExpression, VariableReferenceExpression> context)
    {
        // Rewrite each ordering key through the variable mapping, keeping key order and sort order.
        ImmutableList.Builder<Ordering> canonicalOrderings = ImmutableList.builder();
        for (Ordering ordering : orderingScheme.getOrderBy()) {
            VariableReferenceExpression canonicalVariable = inlineAndCanonicalize(context, ordering.getVariable());
            canonicalOrderings.add(new Ordering(canonicalVariable, ordering.getSortOrder()));
        }
        return new OrderingScheme(canonicalOrderings.build());
    }

    private static GroupingSetDescriptor getCanonicalGroupingSetDescriptor(GroupingSetDescriptor groupingSetDescriptor, Map<VariableReferenceExpression, VariableReferenceExpression> context)
    {
        // Only the grouping keys are rewritten (in their original order); the grouping set count
        // and the global grouping sets are copied as-is.
        ImmutableList.Builder<VariableReferenceExpression> canonicalKeys = ImmutableList.builder();
        for (VariableReferenceExpression groupingKey : groupingSetDescriptor.getGroupingKeys()) {
            canonicalKeys.add(inlineAndCanonicalize(context, groupingKey));
        }
        return new GroupingSetDescriptor(
                canonicalKeys.build(),
                groupingSetDescriptor.getGroupingSetCount(),
                groupingSetDescriptor.getGlobalGroupingSets());
    }

    /**
     * Pairs a canonicalized aggregation with the output variable the original plan assigned to it.
     */
    private static class AggregationReference
    {
        private final Aggregation aggregation;
        private final VariableReferenceExpression variableReferenceExpression;

        public AggregationReference(Aggregation canonicalAggregation, VariableReferenceExpression originalVariable)
        {
            aggregation = requireNonNull(canonicalAggregation, "aggregation is null");
            variableReferenceExpression = requireNonNull(originalVariable, "variableReferenceExpression is null");
        }

        public Aggregation getAggregation()
        {
            return this.aggregation;
        }

        public VariableReferenceExpression getVariableReferenceExpression()
        {
            return this.variableReferenceExpression;
        }
    }

    @Override
    public Optional<PlanNode> visitGroupId(GroupIdNode node, Context context)
    {
        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();

        // Every grouping column output gets a fresh "gid" variable derived from its canonical input,
        // and the original output is mapped to it.
        ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> canonicalGroupingColumns = ImmutableMap.builder();
        for (Entry<VariableReferenceExpression, VariableReferenceExpression> groupingColumn : node.getGroupingColumns().entrySet()) {
            VariableReferenceExpression canonicalInput = mappings.get(groupingColumn.getValue());
            VariableReferenceExpression canonicalOutput = variableAllocator.newVariable(canonicalInput, "gid");
            context.mapExpression(groupingColumn.getKey(), canonicalOutput);
            canonicalGroupingColumns.put(canonicalOutput, canonicalInput);
        }

        // Grouping sets are rewritten only after the loop above, so they can see the new mappings.
        ImmutableList.Builder<List<VariableReferenceExpression>> canonicalGroupingSets = ImmutableList.builder();
        for (List<VariableReferenceExpression> groupingSet : node.getGroupingSets()) {
            List<VariableReferenceExpression> canonicalSet = groupingSet.stream()
                    .map(mappings::get)
                    .collect(toImmutableList());
            canonicalGroupingSets.add(canonicalSet);
        }

        // The group id output is replaced by a fresh INTEGER variable.
        VariableReferenceExpression canonicalGroupId = variableAllocator.newVariable("groupid", INTEGER);
        context.mapExpression(node.getGroupIdVariable(), canonicalGroupId);

        List<VariableReferenceExpression> canonicalAggregationArguments = node.getAggregationArguments().stream()
                .map(mappings::get)
                .collect(toImmutableList());

        PlanNode canonicalPlan = new GroupIdNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                canonicalGroupingSets.build(),
                canonicalGroupingColumns.build(),
                canonicalAggregationArguments,
                canonicalGroupId);

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitUnnest(UnnestNode node, Context context)
    {
        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();

        // Each unnested input is rewritten through the mapping, and every field it produces is
        // renamed to a fresh "unnest_field" variable (rename also records the mapping).
        ImmutableMap.Builder<VariableReferenceExpression, List<VariableReferenceExpression>> canonicalUnnestVariables = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> unnest : node.getUnnestVariables().entrySet()) {
            VariableReferenceExpression canonicalInput = inlineAndCanonicalize(mappings, unnest.getKey());
            ImmutableList.Builder<VariableReferenceExpression> canonicalFields = ImmutableList.builder();
            for (VariableReferenceExpression field : unnest.getValue()) {
                canonicalFields.add(rename(field, "unnest_field", context));
            }
            canonicalUnnestVariables.put(canonicalInput, canonicalFields.build());
        }

        // A present ordinality output is renamed to a fresh "unnest_ordinality" variable.
        Optional<VariableReferenceExpression> canonicalOrdinality = node.getOrdinalityVariable()
                .map(ordinality -> rename(ordinality, "unnest_ordinality", context));

        List<VariableReferenceExpression> canonicalReplicateVariables = node.getReplicateVariables().stream()
                .map(variable -> inlineAndCanonicalize(mappings, variable))
                .collect(toImmutableList());

        PlanNode canonicalPlan = new UnnestNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                canonicalReplicateVariables,
                canonicalUnnestVariables.build(),
                canonicalOrdinality);

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitProject(ProjectNode node, Context context)
    {
        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        // Constant removal is enabled under either constant-ignoring strategy.
        boolean removeConstants = strategy == IGNORE_SAFE_CONSTANTS || strategy == IGNORE_SCAN_CONSTANTS;
        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();

        // Canonicalize each assignment, then order the results by their serialized expression so
        // the result does not depend on the original assignment order.
        List<RowExpressionReference> sortedAssignments = node.getAssignments().entrySet().stream()
                .map(assignment -> new RowExpressionReference(inlineAndCanonicalize(mappings, assignment.getValue(), removeConstants), assignment.getKey()))
                .sorted(comparing(assignment -> writeValueAsString(assignment.getRowExpression())))
                .collect(toImmutableList());

        // Allocate a fresh output variable per assignment (in sorted order) and map original -> fresh.
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> canonicalAssignments = ImmutableMap.builder();
        for (RowExpressionReference assignment : sortedAssignments) {
            RowExpression expression = assignment.getRowExpression();
            VariableReferenceExpression freshVariable = variableAllocator.newVariable(expression);
            context.mapExpression(assignment.getVariableReferenceExpression(), freshVariable);
            canonicalAssignments.put(freshVariable, expression);
        }

        PlanNode canonicalPlan = new ProjectNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                new Assignments(canonicalAssignments.build()),
                node.getLocality());

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    // Canonicalization allocates variable names and plan node ids statefully, through
    // `variableAllocator` and `planNodeidAllocator`, so they depend on the order in which nodes are
    // processed. Source order also affects the plan hash, so sources must be ordered consistently.
    // Returns indices into the input `sources` list, arranged in canonical order.
    private Optional<List<Integer>> orderSources(List<PlanNode> sources)
    {
        // First try the heuristic that orders sources by the tables they scan.
        Optional<List<Integer>> byTables = orderSourcesByTables(sources);
        if (byTables.isPresent()) {
            return byTables;
        }

        if (!usePerfectlyConsistentHistories(session)) {
            // Keep the sources in their original order.
            List<Integer> originalOrder = IntStream.range(0, sources.size()).boxed().collect(toImmutableList());
            return Optional.of(originalOrder);
        }

        // Canonicalize every source on its own and order sources by the resulting string form.
        // If any source cannot be canonicalized, no canonical order is produced.
        Multimap<String, Integer> positionsByCanonicalForm = TreeMultimap.create();
        int position = 0;
        for (PlanNode candidate : sources) {
            Optional<CanonicalPlan> canonicalCandidate = generateCanonicalPlan(candidate, strategy, objectMapper, session);
            if (!canonicalCandidate.isPresent()) {
                return Optional.empty();
            }
            positionsByCanonicalForm.put(canonicalCandidate.get().toString(objectMapper), position);
            position++;
        }
        return Optional.of(ImmutableList.copyOf(positionsByCanonicalForm.values()));
    }

    // Order sources by the set of tables each one scans. If any two sources scan the same set of
    // tables, the order would be ambiguous, so give up and return Optional.empty().
    // Returns a list of indices into the input sources list, in canonical order.
    private Optional<List<Integer>> orderSourcesByTables(List<PlanNode> sources)
    {
        // TreeMultimap keeps identifiers sorted, and positions sorted within an identifier.
        Multimap<String, Integer> sourceToPosition = TreeMultimap.create();
        for (int i = 0; i < sources.size(); ++i) {
            // Identifier of a source: sorted, comma-joined string forms of its scanned table handles.
            String tables = PlanNodeSearcher.searchFrom(sources.get(i))
                    .where(node -> node instanceof TableScanNode)
                    .findAll()
                    .stream()
                    .map(node -> ((TableScanNode) node).getTable().getConnectorHandle().toString())
                    .sorted()
                    .collect(Collectors.joining(","));
            sourceToPosition.put(tables, i);
        }
        // Two sources share an identifier exactly when there are fewer distinct identifiers than
        // sources. Do not compare against a sentinel value: any sentinel (e.g. ",") can itself be a
        // valid identifier and would cause a spurious "ambiguous" result.
        if (sourceToPosition.keySet().size() != sources.size()) {
            return Optional.empty();
        }
        return Optional.of(ImmutableList.copyOf(sourceToPosition.values()));
    }

    /**
     * A WriterTarget stand-in used in canonical plans. It keeps only the connector id and the
     * simple class name of the original target, and reports a fixed placeholder table name.
     */
    private static class CanonicalWriterTarget
            extends TableWriterNode.WriterTarget
    {
        private final ConnectorId connectorId;
        // Simple class name of the original WriterTarget; it identifies the kind of table operation.
        private final String writerTargetType;

        @JsonCreator
        public CanonicalWriterTarget(
                @JsonProperty("connectorId") ConnectorId connector,
                @JsonProperty("writerTargetType") String targetType)
        {
            connectorId = connector;
            writerTargetType = targetType;
        }

        // Builds the canonical stand-in for an arbitrary writer target.
        private static CanonicalWriterTarget from(TableWriterNode.WriterTarget target)
        {
            String targetType = target.getClass().getSimpleName();
            return new CanonicalWriterTarget(target.getConnectorId(), targetType);
        }

        @JsonProperty
        public ConnectorId getConnectorId()
        {
            return this.connectorId;
        }

        @JsonProperty
        public String getWriterTargetType()
        {
            return this.writerTargetType;
        }

        @Override
        public SchemaTableName getSchemaTableName()
        {
            // Always the same placeholder name, whatever the original target was.
            return new SchemaTableName("schema", "table");
        }

        @Override
        public String toString()
        {
            return format("WriterTarget{connectorId: %s, type: %s}", this.connectorId, this.writerTargetType);
        }
    }

    /**
     * Pairs a canonicalized row expression with the variable the original plan assigned it to.
     */
    private static class RowExpressionReference
    {
        private final RowExpression rowExpression;
        private final VariableReferenceExpression variableReferenceExpression;

        public RowExpressionReference(RowExpression canonicalExpression, VariableReferenceExpression originalVariable)
        {
            rowExpression = requireNonNull(canonicalExpression, "rowExpression is null");
            variableReferenceExpression = requireNonNull(originalVariable, "variableReferenceExpression is null");
        }

        public RowExpression getRowExpression()
        {
            return this.rowExpression;
        }

        public VariableReferenceExpression getVariableReferenceExpression()
        {
            return this.variableReferenceExpression;
        }
    }

    @Override
    public Optional<PlanNode> visitFilter(FilterNode node, Context context)
    {
        Optional<PlanNode> canonicalSource = node.getSource().accept(this, context);
        if (!canonicalSource.isPresent()) {
            return Optional.empty();
        }

        // The predicate is inlined against the variable mapping and canonicalized (constants kept).
        RowExpression canonicalPredicate = inlineAndCanonicalize(context.getExpressions(), node.getPredicate());
        PlanNode canonicalPlan = new FilterNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                canonicalSource.get(),
                canonicalPredicate);

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    @Override
    public Optional<PlanNode> visitTableScan(TableScanNode node, Context context)
    {
        // Columns are ordered by the string form of their handles, so fresh variables are
        // allocated in an order independent of the original assignment order.
        List<ColumnReference> sortedColumns = node.getAssignments().entrySet().stream()
                .map(assignment -> new ColumnReference(assignment.getValue(), assignment.getKey()))
                .sorted(comparing(column -> column.getColumnHandle().toString()))
                .collect(toImmutableList());

        ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> assignmentsBuilder = ImmutableMap.builder();
        for (ColumnReference column : sortedColumns) {
            ColumnHandle handle = column.getColumnHandle();
            VariableReferenceExpression originalVariable = column.getVariableReferenceExpression();
            // The handle's string form is used as the name hint for the fresh variable.
            VariableReferenceExpression canonicalVariable = variableAllocator.newVariable(Optional.empty(), handle.toString(), originalVariable.getType());
            context.mapExpression(originalVariable, canonicalVariable);
            assignmentsBuilder.put(canonicalVariable, handle);
        }
        ImmutableMap<VariableReferenceExpression, ColumnHandle> canonicalAssignments = assignmentsBuilder.build();

        PlanNode canonicalPlan = new CanonicalTableScanNode(
                Optional.empty(),
                planNodeidAllocator.getNextId(),
                getCanonicalTableHandle(node.getTable(), strategy),
                // ImmutableMap preserves insertion order, so its keys are the outputs in allocation order.
                ImmutableList.copyOf(canonicalAssignments.keySet()),
                canonicalAssignments);

        context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy));
        return Optional.of(canonicalPlan);
    }

    // Only INNER joins are merged; every other join type is kept as-is.
    private boolean shouldMergeJoinNodes(JoinType type)
    {
        switch (type) {
            case INNER:
                return true;
            default:
                return false;
        }
    }

    // Allocates a fresh variable of the same type, named after nameHint, and records
    // variable -> fresh in the context so later references are rewritten. Returns the fresh variable.
    private VariableReferenceExpression rename(VariableReferenceExpression variable, String nameHint, Context context)
    {
        VariableReferenceExpression renamed = variableAllocator.newVariable(Optional.empty(), nameHint, variable.getType());
        context.mapExpression(variable, renamed);
        return renamed;
    }

    // Serializes the object to JSON with this generator's objectMapper. Serialization failures
    // are rethrown as a PrestoException (PLAN_SERIALIZATION_ERROR) that keeps the original cause.
    private String writeValueAsString(Object object)
    {
        String json;
        try {
            json = objectMapper.writeValueAsString(object);
        }
        catch (JsonProcessingException cause) {
            throw new PrestoException(PLAN_SERIALIZATION_ERROR, "Cannot serialize plan to JSON", cause);
        }
        return json;
    }

    // Rewrites both sides of the clause through the variable mapping, then places the side that
    // compares greater first, so that `a = b` and `b = a` canonicalize identically.
    private static EquiJoinClause canonicalize(EquiJoinClause criteria, Context context)
    {
        Map<VariableReferenceExpression, VariableReferenceExpression> mappings = context.getExpressions();
        VariableReferenceExpression first = inlineAndCanonicalize(mappings, criteria.getLeft());
        VariableReferenceExpression second = inlineAndCanonicalize(mappings, criteria.getRight());
        if (first.compareTo(second) > 0) {
            return new EquiJoinClause(first, second);
        }
        return new EquiJoinClause(second, first);
    }

    // Converts `variable = variable` into an EquiJoinClause. A call is recognized as equality by its
    // display name; any other expression shape yields Optional.empty().
    private static Optional<EquiJoinClause> toEquiJoinClause(RowExpression expression)
    {
        if (!(expression instanceof CallExpression)) {
            return Optional.empty();
        }
        CallExpression call = (CallExpression) expression;
        if (!call.getDisplayName().equals(EQUAL.getFunctionName().getObjectName())) {
            return Optional.empty();
        }
        List<RowExpression> arguments = call.getArguments();
        if (arguments.size() != 2) {
            return Optional.empty();
        }
        RowExpression leftArgument = arguments.get(0);
        RowExpression rightArgument = arguments.get(1);
        if (!(leftArgument instanceof VariableReferenceExpression) || !(rightArgument instanceof VariableReferenceExpression)) {
            return Optional.empty();
        }
        return Optional.of(new EquiJoinClause(
                (VariableReferenceExpression) leftArgument,
                (VariableReferenceExpression) rightArgument));
    }

    // Convenience overload: inline and canonicalize while keeping constants.
    private static <T extends RowExpression> T inlineAndCanonicalize(
            Map<VariableReferenceExpression, VariableReferenceExpression> context,
            T expression)
    {
        boolean keepConstants = false;
        return inlineAndCanonicalize(context, expression, keepConstants);
    }

    // Replaces every variable found in `context` by its mapped value (unmapped variables stay as
    // they are), then canonicalizes the result, removing constants when removeConstants is true.
    @SuppressWarnings("unchecked")
    private static <T extends RowExpression> T inlineAndCanonicalize(
            Map<VariableReferenceExpression, VariableReferenceExpression> context,
            T expression,
            boolean removeConstants)
    {
        RowExpression inlined = inlineVariables(variable -> context.getOrDefault(variable, variable), expression);
        RowExpression canonical = canonicalizeRowExpression(inlined, removeConstants);
        // NOTE(review): unchecked cast; relies on canonicalization returning an instance of the input's type.
        return (T) canonical;
    }

    /**
     * Pairs a table column handle with the variable the original table scan assigned to it.
     */
    private static class ColumnReference
    {
        private final ColumnHandle columnHandle;
        private final VariableReferenceExpression variableReferenceExpression;

        public ColumnReference(ColumnHandle handle, VariableReferenceExpression originalVariable)
        {
            columnHandle = requireNonNull(handle, "columnHandle is null");
            variableReferenceExpression = requireNonNull(originalVariable, "variableReferenceExpression is null");
        }

        public ColumnHandle getColumnHandle()
        {
            return this.columnHandle;
        }

        public VariableReferenceExpression getVariableReferenceExpression()
        {
            return this.variableReferenceExpression;
        }
    }

    /**
     * Mutable state threaded through the visitor: the mapping from original to canonical
     * variables, plus the canonical plans (and their original input table scans) recorded for
     * visited nodes.
     */
    public static class Context
    {
        // Original variable -> canonical variable. Callers pass this map to inlineAndCanonicalize.
        private final Map<VariableReferenceExpression, VariableReferenceExpression> expressions = new HashMap<>();
        // Keyed by node identity: either the visited node itself or its stats-equivalent node
        // (see addPlan / addLimitingNodePlan).
        private final Map<PlanNode, CanonicalPlan> canonicalPlans = new IdentityHashMap<>();
        // Reverse lookup: canonical plan root -> the key it was recorded under. Entries are never removed.
        private final Map<PlanNode, PlanNode> canonicalPlanToPlan = new IdentityHashMap<>();
        // Same keys as canonicalPlans -> original TableScanNodes read by that canonical plan.
        private final Map<PlanNode, List<TableScanNode>> inputTables = new IdentityHashMap<>();

        // Returns the live, mutable mapping (not a copy).
        public Map<VariableReferenceExpression, VariableReferenceExpression> getExpressions()
        {
            return expressions;
        }

        // Returns the live, mutable map of recorded canonical plans (not a copy).
        public Map<PlanNode, CanonicalPlan> getCanonicalPlans()
        {
            return canonicalPlans;
        }

        // Returns the live, mutable map of input table scans per recorded key (not a copy).
        public Map<PlanNode, List<TableScanNode>> getInputTables()
        {
            return inputTables;
        }

        // Records that references to `from` are rewritten to `to`; replaces any earlier mapping for `from`.
        public void mapExpression(VariableReferenceExpression from, VariableReferenceExpression to)
        {
            expressions.put(from, to);
        }

        // Registers the canonical plan of a limiting node (e.g. TopNRowNumber, DistinctLimit).
        // Without a stats-equivalent node, the plan is recorded under the node itself.
        private void addLimitingNodePlan(PlanNode limit, CanonicalPlan canonicalPlan)
        {
            if (!limit.getStatsEquivalentPlanNode().isPresent()) {
                addPlanInternal(limit, canonicalPlan);
                return;
            }
            // When limits are involved, we can only know canonicalized plans after topmost limit has been canonicalized.
            // Once we are at topmost limit, we cache canonicalized plans for all sub-plans.
            PlanNode statsEquivalentPlanNode = limit.getStatsEquivalentPlanNode().get();
            // Assumes a limiting node's stats-equivalent node is always a StatsEquivalentPlanNodeWithLimit
            // (an unchecked cast; a different type would throw ClassCastException).
            StatsEquivalentPlanNodeWithLimit statsEquivalentPlanNodeWithLimit = (StatsEquivalentPlanNodeWithLimit) statsEquivalentPlanNode;
            // If the limit part and the plan part have different subtree sizes, fall back to
            // recording under the limit node itself.
            if (childrenCount(statsEquivalentPlanNodeWithLimit.getLimit()) != childrenCount(statsEquivalentPlanNodeWithLimit.getPlan())) {
                addPlanInternal(limit, canonicalPlan);
                return;
            }
            // Pre-order walk over the limit's subtree (the limit node first). Every node that has a
            // stats-equivalent node and a canonical plan has its entry re-keyed under that
            // stats-equivalent node, wrapped together with the limit's canonical plan.
            // NOTE(review): canonicalPlans.get(child) only finds entries keyed by the child itself;
            // entries that addPlan stored under a stats-equivalent key are skipped here — confirm intended.
            forTree(PlanNode::getSources)
                    .depthFirstPreOrder(limit)
                    .forEach(child -> {
                        CanonicalPlan childCanonicalPlan = child == limit ? canonicalPlan : canonicalPlans.get(child);
                        if (childCanonicalPlan == null || !child.getStatsEquivalentPlanNode().isPresent()) {
                            return;
                        }
                        // Only save canonicalized plans for stats equivalent plan nodes.
                        canonicalPlans.remove(child);
                        inputTables.remove(child);
                        addPlanInternal(
                                child.getStatsEquivalentPlanNode().get(),
                                new CanonicalPlan(
                                        new StatsEquivalentPlanNodeWithLimit(childCanonicalPlan.getPlan().getId(), childCanonicalPlan.getPlan(), canonicalPlan.getPlan()),
                                        canonicalPlan.getStrategy()));
                    });
        }

        // Registers the canonical plan of a non-limiting node. The plan is recorded under the node's
        // stats-equivalent node when one exists and has the same subtree size; otherwise under the node itself.
        private void addPlan(PlanNode plan, CanonicalPlan canonicalPlan)
        {
            if (!plan.getStatsEquivalentPlanNode().isPresent()) {
                addPlanInternal(plan, canonicalPlan);
                return;
            }
            PlanNode statsEquivalentPlanNode = plan.getStatsEquivalentPlanNode().get();
            if (childrenCount(plan) == childrenCount(statsEquivalentPlanNode)) {
                addPlanInternal(statsEquivalentPlanNode, canonicalPlan);
            }
            else {
                addPlanInternal(plan, canonicalPlan);
            }
        }

        // Despite the name, returns the number of nodes in the subtree rooted at `root`,
        // including `root` itself (the pre-order traversal starts at the root).
        private int childrenCount(PlanNode root)
        {
            return Iterables.size(forTree(PlanNode::getSources).depthFirstPreOrder(root));
        }

        // Records `plan` -> canonicalPlan, the reverse mapping, and the original TableScanNodes that
        // back every CanonicalTableScanNode inside the canonical plan.
        private void addPlanInternal(PlanNode plan, CanonicalPlan canonicalPlan)
        {
            ImmutableList.Builder<TableScanNode> inputs = ImmutableList.builder();
            canonicalPlans.put(plan, canonicalPlan);
            canonicalPlanToPlan.put(canonicalPlan.getPlan(), plan);
            for (PlanNode node : forTree(PlanNode::getSources).depthFirstPreOrder(canonicalPlan.getPlan())) {
                if (node instanceof CanonicalTableScanNode) {
                    // Only scans whose original is known (recorded earlier via canonicalPlanToPlan) are
                    // included; the cast assumes that original key is a TableScanNode.
                    if (canonicalPlanToPlan.containsKey(node)) {
                        inputs.add((TableScanNode) canonicalPlanToPlan.get(node));
                    }
                }
            }
            inputTables.put(plan, inputs.build());
        }
    }
}