EliminateCrossJoins.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS;
import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
/**
 * Planner rule that rebuilds the joins below a matched {@link JoinNode} in an
 * order derived from the edges of its {@link JoinGraph}. Relations that share
 * an edge with an already-joined relation are combined through equi-join
 * criteria taken from those edges, while the original relative order of the
 * relations is kept wherever the graph allows it.
 */
public class EliminateCrossJoins
        implements Rule<JoinNode>
{
    private static final Pattern<JoinNode> PATTERN = join();

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * The rule is active for the {@code ELIMINATE_CROSS_JOINS} and
     * {@code AUTOMATIC} join reordering strategies.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        // Cost-based (AUTOMATIC) reordering is included as well, to cover the cases
        // where some of the tables do not have statistics.
        JoinReorderingStrategy strategy = getJoinReorderingStrategy(session);
        if (strategy == ELIMINATE_CROSS_JOINS) {
            return true;
        }
        return strategy == AUTOMATIC;
    }

    /**
     * Builds a shallow join graph from {@code node} and, when the graph has at
     * least three nodes and the computed join order differs from the original
     * one, returns a replacement plan built in the new order.
     */
    @Override
    public Result apply(JoinNode node, Captures captures, Context context)
    {
        JoinGraph graph = JoinGraph.buildShallowFrom(node, context.getLookup());
        if (graph.size() >= 3) {
            List<Integer> order = getJoinOrder(graph);
            if (!isOriginalOrder(order)) {
                return Result.ofPlanNode(buildJoinTree(node.getOutputVariables(), graph, order, context.getIdAllocator()));
            }
        }
        return Result.empty();
    }

    /**
     * Tells whether {@code joinOrder} is the identity ordering, i.e. every
     * position {@code i} holds the value {@code i}.
     *
     * @param joinOrder node indexes in the order they are to be joined
     * @return {@code true} if no element is out of its original position
     */
    public static boolean isOriginalOrder(List<Integer> joinOrder)
    {
        int position = 0;
        for (Integer index : joinOrder) {
            if (index != position) {
                return false;
            }
            position++;
        }
        return true;
    }

    /**
     * Determines the order in which the nodes of the given JoinGraph should be
     * joined by traversing the graph along its edges. Any traversal algorithm
     * (such as BFS or DFS) would work here; a PriorityQueue is used so that the
     * original join order is preserved as much as possible: the queue hands out
     * the next node to join according to its position in the original plan.
     *
     * @param graph join graph to traverse
     * @return indexes of the graph nodes, in the order they were visited
     */
    public static List<Integer> getJoinOrder(JoinGraph graph)
    {
        // A node's priority is its index in the graph
        Map<PlanNodeId, Integer> priorities = new HashMap<>();
        for (int index = 0; index < graph.size(); index++) {
            priorities.put(graph.getNode(index).getId(), index);
        }

        PriorityQueue<PlanNode> pending = new PriorityQueue<>(
                graph.size(),
                comparing((PlanNode candidate) -> priorities.get(candidate.getId())));
        Set<PlanNode> visited = new HashSet<>();
        ImmutableList.Builder<PlanNode> visitOrder = ImmutableList.builder();

        pending.add(graph.getNode(0));
        while (!pending.isEmpty()) {
            PlanNode current = pending.poll();
            // Set.add reports whether the node was seen for the first time
            if (visited.add(current)) {
                visitOrder.add(current);
                for (JoinGraph.Edge edge : graph.getEdges(current)) {
                    pending.add(edge.getTargetNode());
                }
            }

            boolean unvisitedNodesRemain = visited.size() < graph.size();
            if (pending.isEmpty() && unvisitedNodesRemain) {
                // The graph is disconnected: restart from the first node not yet visited
                graph.getNodes().stream()
                        .filter(candidate -> !visited.contains(candidate))
                        .findFirst()
                        .ifPresent(pending::add);
            }
        }

        checkState(visited.size() == graph.size());
        return visitOrder.build().stream()
                .map(PlanNode::getId)
                .map(priorities::get)
                .collect(toImmutableList());
    }

    /**
     * Builds a left-deep tree of INNER joins over the graph nodes in the given
     * order, then stacks one FilterNode per graph filter on top, then a
     * ProjectNode when the graph carries assignments, and finally restricts the
     * outputs to {@code expectedOutputVariables} when needed.
     *
     * @param expectedOutputVariables variables the resulting plan must produce
     * @param graph join graph supplying nodes, edges, filters and assignments
     * @param joinOrder indexes of the graph nodes in the order they are to be joined
     * @param idAllocator source of ids for the newly created plan nodes
     * @return root of the rebuilt plan
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if {@code joinOrder} has fewer than two entries
     */
    public static PlanNode buildJoinTree(List<VariableReferenceExpression> expectedOutputVariables, JoinGraph graph, List<Integer> joinOrder, PlanNodeIdAllocator idAllocator)
    {
        requireNonNull(expectedOutputVariables, "expectedOutputVariables is null");
        requireNonNull(idAllocator, "idAllocator is null");
        requireNonNull(graph, "graph is null");
        List<Integer> order = ImmutableList.copyOf(requireNonNull(joinOrder, "joinOrder is null"));
        checkArgument(order.size() >= 2);

        PlanNode tree = graph.getNode(order.get(0));
        Set<PlanNodeId> joinedIds = new HashSet<>();
        joinedIds.add(tree.getId());

        for (Integer index : order.subList(1, order.size())) {
            PlanNode right = graph.getNode(index);
            // The right node is registered before its edges are inspected
            joinedIds.add(right.getId());
            tree = innerJoin(tree, right, criteriaTowards(graph, right, joinedIds), idAllocator);
        }

        for (RowExpression predicate : graph.getFilters()) {
            tree = new FilterNode(
                    tree.getSourceLocation(),
                    idAllocator.getNextId(),
                    tree,
                    predicate);
        }

        if (graph.getAssignments().isPresent()) {
            tree = new ProjectNode(
                    idAllocator.getNextId(),
                    tree,
                    Assignments.copyOf(graph.getAssignments().get()));
        }

        // Add a projection, if one is required, so that the outputs match what was originally expected.
        // Some nodes are sensitive to what is produced (e.g., DistinctLimit node).
        return restrictOutputs(idAllocator, tree, ImmutableSet.copyOf(expectedOutputVariables)).orElse(tree);
    }

    /**
     * Collects one equi-join clause for every edge of {@code rightNode} whose
     * target node id is in {@code joinedIds}; each clause is created from the
     * edge's target variable and its source variable, in that order.
     */
    private static ImmutableList<EquiJoinClause> criteriaTowards(JoinGraph graph, PlanNode rightNode, Set<PlanNodeId> joinedIds)
    {
        ImmutableList.Builder<EquiJoinClause> clauses = ImmutableList.builder();
        for (JoinGraph.Edge edge : graph.getEdges(rightNode)) {
            if (!joinedIds.contains(edge.getTargetNode().getId())) {
                continue;
            }
            clauses.add(new EquiJoinClause(
                    edge.getTargetVariable(),
                    edge.getSourceVariable()));
        }
        return clauses.build();
    }

    /**
     * Creates an INNER join of {@code left} and {@code right} that outputs the
     * variables of the left side followed by those of the right side.
     */
    private static PlanNode innerJoin(PlanNode left, PlanNode right, ImmutableList<EquiJoinClause> criteria, PlanNodeIdAllocator idAllocator)
    {
        ImmutableList<VariableReferenceExpression> outputs = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(left.getOutputVariables())
                .addAll(right.getOutputVariables())
                .build();
        return new JoinNode(
                left.getSourceLocation(),
                idAllocator.getNextId(),
                JoinType.INNER,
                left,
                right,
                criteria,
                outputs,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }
}