JoinGraph.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations.joins;

import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * JoinGraph represents sequence of Joins, where nodes in the graph
 * are PlanNodes that are being joined and edges are all equality join
 * conditions between pair of nodes.
 */
public class JoinGraph
{
    private final Optional<Map<VariableReferenceExpression, RowExpression>> assignments;
    private final List<RowExpression> filters;
    private final List<PlanNode> nodes; // nodes in order of their appearance in tree plan (left, right, parent)
    private final Multimap<PlanNodeId, Edge> edges;
    private final PlanNodeId rootId;

    /**
     * Builds all (distinct) {@link JoinGraph}-es whole plan tree.
     */
    public static List<JoinGraph> buildFrom(PlanNode plan)
    {
        return buildFrom(plan, Lookup.noLookup());
    }

    /**
     * Builds {@link JoinGraph} containing {@code plan} node.
     */
    public static JoinGraph buildShallowFrom(PlanNode plan, Lookup lookup)
    {
        JoinGraph graph = plan.accept(new Builder(true, lookup), new Context());
        return graph;
    }

    private static List<JoinGraph> buildFrom(PlanNode plan, Lookup lookup)
    {
        Context context = new Context();
        JoinGraph graph = plan.accept(new Builder(false, lookup), context);
        if (graph.size() > 1) {
            context.addSubGraph(graph);
        }
        return context.getGraphs();
    }

    public JoinGraph(PlanNode node)
    {
        this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), Optional.empty());
    }

    public JoinGraph(
            List<PlanNode> nodes,
            Multimap<PlanNodeId, Edge> edges,
            PlanNodeId rootId,
            List<RowExpression> filters,
            Optional<Map<VariableReferenceExpression, RowExpression>> assignments)
    {
        this.nodes = nodes;
        this.edges = edges;
        this.rootId = rootId;
        this.filters = filters;
        this.assignments = assignments;
    }

    public JoinGraph withAssignments(Map<VariableReferenceExpression, RowExpression> assignments)
    {
        return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments));
    }

    public Optional<Map<VariableReferenceExpression, RowExpression>> getAssignments()
    {
        return assignments;
    }

    public JoinGraph withFilter(RowExpression expression)
    {
        ImmutableList.Builder<RowExpression> filters = ImmutableList.builder();
        filters.addAll(this.filters);
        filters.add(expression);

        return new JoinGraph(nodes, edges, rootId, filters.build(), assignments);
    }

    public List<RowExpression> getFilters()
    {
        return filters;
    }

    public PlanNodeId getRootId()
    {
        return rootId;
    }

    public JoinGraph withRootId(PlanNodeId rootId)
    {
        return new JoinGraph(nodes, edges, rootId, filters, assignments);
    }

    public boolean isEmpty()
    {
        return nodes.isEmpty();
    }

    public int size()
    {
        return nodes.size();
    }

    public PlanNode getNode(int index)
    {
        return nodes.get(index);
    }

    public List<PlanNode> getNodes()
    {
        return nodes;
    }

    public Collection<Edge> getEdges(PlanNode node)
    {
        return ImmutableList.copyOf(edges.get(node.getId()));
    }

    @Override
    public String toString()
    {
        StringBuilder builder = new StringBuilder();

        for (PlanNode nodeFrom : nodes) {
            builder.append(nodeFrom.getId())
                    .append(" = ")
                    .append(nodeFrom.toString())
                    .append("\n");
        }
        for (PlanNode nodeFrom : nodes) {
            builder.append(nodeFrom.getId())
                    .append(":");
            for (Edge nodeTo : edges.get(nodeFrom.getId())) {
                builder.append(" ").append(nodeTo.getTargetNode().getId());
            }
            builder.append("\n");
        }

        return builder.toString();
    }

    private JoinGraph joinWith(JoinGraph other, List<EquiJoinClause> joinClauses, Context context, PlanNodeId newRoot)
    {
        for (PlanNode node : other.nodes) {
            checkState(!edges.containsKey(node.getId()), format("Node [%s] appeared in two JoinGraphs", node));
        }

        List<PlanNode> nodes = ImmutableList.<PlanNode>builder()
                .addAll(this.nodes)
                .addAll(other.nodes)
                .build();

        ImmutableMultimap.Builder<PlanNodeId, Edge> edges = ImmutableMultimap.<PlanNodeId, Edge>builder()
                .putAll(this.edges)
                .putAll(other.edges);

        List<RowExpression> joinedFilters = ImmutableList.<RowExpression>builder()
                .addAll(this.filters)
                .addAll(other.filters)
                .build();

        for (EquiJoinClause edge : joinClauses) {
            VariableReferenceExpression leftVariable = edge.getLeft();
            VariableReferenceExpression rightVariable = edge.getRight();
            checkState(context.containsVariable(leftVariable));
            checkState(context.containsVariable(rightVariable));

            PlanNode left = context.getVariableSource(leftVariable);
            PlanNode right = context.getVariableSource(rightVariable);
            edges.put(left.getId(), new Edge(right, leftVariable, rightVariable));
            edges.put(right.getId(), new Edge(left, rightVariable, leftVariable));
        }

        return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty());
    }

    private static class Builder
            extends InternalPlanVisitor<JoinGraph, Context>
    {
        // TODO When com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'shallow' flag
        private final boolean shallow;
        private final Lookup lookup;

        private Builder(boolean shallow, Lookup lookup)
        {
            this.shallow = shallow;
            this.lookup = requireNonNull(lookup, "lookup cannot be null");
        }

        @Override
        public JoinGraph visitPlan(PlanNode node, Context context)
        {
            if (!shallow) {
                for (PlanNode child : node.getSources()) {
                    JoinGraph graph = child.accept(this, context);
                    if (graph.size() < 2) {
                        continue;
                    }
                    context.addSubGraph(graph.withRootId(child.getId()));
                }
            }

            for (VariableReferenceExpression variable : node.getOutputVariables()) {
                context.setVariableSource(variable, node);
            }
            return new JoinGraph(node);
        }

        @Override
        public JoinGraph visitFilter(FilterNode node, Context context)
        {
            JoinGraph graph = node.getSource().accept(this, context);
            return graph.withFilter(node.getPredicate());
        }

        @Override
        public JoinGraph visitJoin(JoinNode node, Context context)
        {
            //TODO: add support for non inner joins
            if (node.getType() != INNER) {
                return visitPlan(node, context);
            }

            JoinGraph left = node.getLeft().accept(this, context);
            JoinGraph right = node.getRight().accept(this, context);

            JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId());

            if (node.getFilter().isPresent()) {
                return graph.withFilter(node.getFilter().get());
            }
            return graph;
        }

        @Override
        public JoinGraph visitProject(ProjectNode node, Context context)
        {
            if (isIdentity(node)) {
                JoinGraph graph = node.getSource().accept(this, context);
                return graph.withAssignments(node.getAssignments().getMap());
            }
            return visitPlan(node, context);
        }

        @Override
        public JoinGraph visitGroupReference(GroupReference node, Context context)
        {
            PlanNode dereferenced = lookup.resolve(node);
            JoinGraph graph = dereferenced.accept(this, context);
            if (isTrivialGraph(graph)) {
                return replacementGraph(dereferenced, node, context);
            }
            return graph;
        }

        private boolean isTrivialGraph(JoinGraph graph)
        {
            return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty() && !graph.assignments.isPresent();
        }

        private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context context)
        {
            // TODO optimize when idea is generally approved
            List<VariableReferenceExpression> variables = context.variableSources.entrySet().stream()
                    .filter(entry -> entry.getValue() == oldNode)
                    .map(Map.Entry::getKey)
                    .collect(toImmutableList());
            variables.forEach(variable -> context.variableSources.put(variable, newNode));

            return new JoinGraph(newNode);
        }
    }

    public static class Edge
    {
        private final PlanNode targetNode;
        private final VariableReferenceExpression sourceVariable;
        private final VariableReferenceExpression targetVariable;

        public Edge(PlanNode targetNode, VariableReferenceExpression sourceVariable, VariableReferenceExpression targetVariable)
        {
            this.targetNode = requireNonNull(targetNode, "targetNode is null");
            this.sourceVariable = requireNonNull(sourceVariable, "sourceVariable is null");
            this.targetVariable = requireNonNull(targetVariable, "targetVariable is null");
        }

        public PlanNode getTargetNode()
        {
            return targetNode;
        }

        public VariableReferenceExpression getSourceVariable()
        {
            return sourceVariable;
        }

        public VariableReferenceExpression getTargetVariable()
        {
            return targetVariable;
        }
    }

    private static class Context
    {
        private final Map<VariableReferenceExpression, PlanNode> variableSources = new HashMap<>();

        // TODO When com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'joinGraphs'
        private final List<JoinGraph> joinGraphs = new ArrayList<>();

        public void setVariableSource(VariableReferenceExpression variable, PlanNode node)
        {
            variableSources.put(variable, node);
        }

        public void addSubGraph(JoinGraph graph)
        {
            joinGraphs.add(graph);
        }

        public boolean containsVariable(VariableReferenceExpression variable)
        {
            return variableSources.containsKey(variable);
        }

        public PlanNode getVariableSource(VariableReferenceExpression variable)
        {
            checkState(containsVariable(variable));
            return variableSources.get(variable);
        }

        public List<JoinGraph> getGraphs()
        {
            return joinGraphs;
        }
    }
}