PushDownDereferences.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.Rule.Context;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import static com.facebook.presto.SystemSessionProperties.isPushdownDereferenceEnabled;
import static com.facebook.presto.expressions.RowExpressionTreeRewriter.rewriteWith;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.planner.plan.Patterns.unnest;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
/**
* Push down dereferences as follows:
* <p>
* Extract dereferences from PlanNode which has expressions
* and push them down to a new ProjectNode right below the PlanNode.
* After this step, All dereferences will be in ProjectNode.
* <p>
* Pushdown dereferences in ProjectNode down through other types of PlanNode,
* e.g, Filter, Join etc.
*/
public class PushDownDereferences
{
private final Metadata metadata;
public PushDownDereferences(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}
public Set<Rule<?>> rules()
{
return ImmutableSet.of(
new ExtractFromFilter(),
new ExtractFromJoin(),
new PushDownDereferenceThrough<>(AssignUniqueId.class),
new PushDownDereferenceThrough<>(WindowNode.class),
new PushDownDereferenceThrough<>(TopNNode.class),
new PushDownDereferenceThrough<>(RowNumberNode.class),
new PushDownDereferenceThrough<>(TopNRowNumberNode.class),
new PushDownDereferenceThrough<>(SortNode.class),
new PushDownDereferenceThrough<>(FilterNode.class),
new PushDownDereferenceThrough<>(LimitNode.class),
new PushDownDereferenceThroughProject(),
new PushDownDereferenceThroughUnnest(),
new PushDownDereferenceThroughSemiJoin(),
new PushDownDereferenceThroughJoin());
}
/**
* Extract dereferences and push them down to new ProjectNode below
* Transforms:
* <pre>
* TargetNode(expression(a.x))
* </pre>
* to:
* <pre>
* ProjectNode(original symbols)
* TargetNode(expression(symbol))
* Project(symbol := a.x)
* </pre>
*/
abstract class ExtractProjectDereferences<N extends PlanNode>
implements Rule<N>
{
private final Class<N> planNodeClass;
ExtractProjectDereferences(Class<N> planNodeClass)
{
this.planNodeClass = planNodeClass;
}
@Override
public boolean isEnabled(Session session)
{
return isPushdownDereferenceEnabled(session);
}
@Override
public Pattern<N> getPattern()
{
return Pattern.typeOf(planNodeClass);
}
@Override
public Result apply(N node, Captures captures, Context context)
{
Map<SpecialFormExpression, VariableReferenceExpression> expressions =
getDereferenceSymbolMap(extractExpressionsNonRecursive(node), context, metadata);
if (expressions.isEmpty()) {
return Result.empty();
}
return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), rewrite(context, node, HashBiMap.create(expressions)), identityAssignments(node.getOutputVariables())));
}
protected abstract N rewrite(Context context, N node, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions);
}
class ExtractFromFilter
extends ExtractProjectDereferences<FilterNode>
{
ExtractFromFilter()
{
super(FilterNode.class);
}
@Override
protected FilterNode rewrite(Context context, FilterNode node, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
PlanNode source = node.getSource();
Map<VariableReferenceExpression, RowExpression> dereferencesMap = expressions.inverse().entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
Assignments assignments = Assignments.builder()
.putAll(identityAssignments(source.getOutputVariables()))
.putAll(dereferencesMap)
.build();
ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), source, assignments);
return new FilterNode(
projectNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
projectNode,
replaceDereferences(node.getPredicate(), expressions));
}
}
class ExtractFromJoin
extends ExtractProjectDereferences<JoinNode>
{
ExtractFromJoin()
{
super(JoinNode.class);
}
@Override
protected JoinNode rewrite(Context context, JoinNode joinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
Assignments.Builder leftSideDereferences = Assignments.builder();
Assignments.Builder rightSideDereferences = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, SpecialFormExpression> entry : expressions.inverse().entrySet()) {
VariableReferenceExpression baseVariable = getBase(entry.getValue());
if (joinNode.getLeft().getOutputVariables().contains(baseVariable)) {
leftSideDereferences.put(entry.getKey(), entry.getValue());
}
else {
rightSideDereferences.put(entry.getKey(), entry.getValue());
}
}
PlanNode leftNode = createProject(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator());
PlanNode rightNode = createProject(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator());
return new JoinNode(
joinNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
joinNode.getType(),
leftNode,
rightNode,
joinNode.getCriteria(),
ImmutableList.<VariableReferenceExpression>builder()
.addAll(leftNode.getOutputVariables())
.addAll(rightNode.getOutputVariables())
.build(),
joinNode.getFilter().map(expression -> replaceDereferences(expression, expressions)),
joinNode.getLeftHashVariable(),
joinNode.getRightHashVariable(),
joinNode.getDistributionType(),
joinNode.getDynamicFilters());
}
}
/**
* Push down dereferences from ProjectNode to child nodes if possible
*/
private abstract class PushdownDereferencesInProject<N extends PlanNode>
implements Rule<ProjectNode>
{
private final Capture<N> targetCapture = newCapture();
private final Pattern<N> targetPattern;
protected PushdownDereferencesInProject(Pattern<N> targetPattern)
{
this.targetPattern = requireNonNull(targetPattern, "targetPattern is null");
}
@Override
public boolean isEnabled(Session session)
{
return isPushdownDereferenceEnabled(session);
}
@Override
public Pattern<ProjectNode> getPattern()
{
return project().with(source().matching(targetPattern.capturedAs(targetCapture)));
}
@Override
public Result apply(ProjectNode node, Captures captures, Context context)
{
N child = captures.get(targetCapture);
Map<SpecialFormExpression, VariableReferenceExpression> allDereferencesInProject = getDereferenceSymbolMap(node.getAssignments().getExpressions(), context, metadata);
Set<VariableReferenceExpression> childSourceVariables = child.getSources().stream()
.map(PlanNode::getOutputVariables).flatMap(Collection::stream)
.collect(toImmutableSet());
Map<SpecialFormExpression, VariableReferenceExpression> pushdownDereferences = allDereferencesInProject.entrySet().stream()
.filter(entry -> childSourceVariables.contains(getBase(entry.getKey())))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
if (pushdownDereferences.isEmpty()) {
return Result.empty();
}
Result result = pushDownDereferences(context, child, HashBiMap.create(pushdownDereferences));
if (result.isEmpty()) {
return Result.empty();
}
Assignments.Builder builder = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getAssignments().entrySet()) {
builder.put(entry.getKey(), replaceDereferences(entry.getValue(), pushdownDereferences));
}
return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), result.getTransformedPlan().get(), builder.build()));
}
protected abstract Result pushDownDereferences(Context context, N targetNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions);
}
/**
* Transforms:
* <pre>
* Project(a_x := a.x)
* TargetNode(a)
* </pre>
* to:
* <pre>
* Project(a_x := symbol)
* TargetNode(symbol)
* Project(symbol := a.x)
* </pre>
*/
public class PushDownDereferenceThrough<N extends PlanNode>
extends PushdownDereferencesInProject<N>
{
public PushDownDereferenceThrough(Class<N> planNodeClass)
{
super(Pattern.typeOf(planNodeClass));
}
@Override
protected Result pushDownDereferences(Context context, N targetNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
PlanNode source = getOnlyElement(targetNode.getSources());
Map<VariableReferenceExpression, RowExpression> dereferencesMap = expressions.inverse().entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
ProjectNode projectNode = new ProjectNode(
context.getIdAllocator().getNextId(),
source,
Assignments.builder()
.putAll(identityAssignments(source.getOutputVariables()))
.putAll(dereferencesMap)
.build());
return Result.ofPlanNode(targetNode.replaceChildren(ImmutableList.of(projectNode)));
}
}
/**
* Transforms:
* <pre>
* Project(a_x := a.msg.x)
* Join(a_y = b_y) => [a]
* Project(a_y := a.msg.y)
* Source(a)
* Project(b_y := b.msg.y)
* Source(b)
* </pre>
* to:
* <pre>
* Project(a_x := symbol)
* Join(a_y = b_y) => [symbol]
* Project(symbol := a.msg.x, a_y := a.msg.y)
* Source(a)
* Project(b_y := b.msg.y)
* Source(b)
* </pre>
*/
public class PushDownDereferenceThroughJoin
extends PushdownDereferencesInProject<JoinNode>
{
PushDownDereferenceThroughJoin()
{
super(join());
}
@Override
protected Result pushDownDereferences(Context context, JoinNode joinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
Assignments.Builder leftSideDereferences = Assignments.builder();
Assignments.Builder rightSideDereferences = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, SpecialFormExpression> entry : expressions.inverse().entrySet()) {
VariableReferenceExpression baseVariable = getBase(entry.getValue());
if (joinNode.getLeft().getOutputVariables().contains(baseVariable)) {
leftSideDereferences.put(entry.getKey(), entry.getValue());
}
else {
rightSideDereferences.put(entry.getKey(), entry.getValue());
}
}
PlanNode leftNode = createProject(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator());
PlanNode rightNode = createProject(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator());
return Result.ofPlanNode(new JoinNode(
joinNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
joinNode.getType(),
leftNode,
rightNode,
joinNode.getCriteria(),
ImmutableList.<VariableReferenceExpression>builder()
.addAll(leftNode.getOutputVariables())
.addAll(rightNode.getOutputVariables())
.build(),
joinNode.getFilter().map(expression -> replaceDereferences(expression, expressions)),
joinNode.getLeftHashVariable(),
joinNode.getRightHashVariable(),
joinNode.getDistributionType(),
joinNode.getDynamicFilters()));
}
}
public class PushDownDereferenceThroughSemiJoin
extends PushdownDereferencesInProject<SemiJoinNode>
{
PushDownDereferenceThroughSemiJoin()
{
super(semiJoin());
}
@Override
protected Result pushDownDereferences(Context context, SemiJoinNode semiJoinNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
Assignments.Builder filteringSourceDereferences = Assignments.builder();
Assignments.Builder sourceDereferences = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, SpecialFormExpression> entry : expressions.inverse().entrySet()) {
VariableReferenceExpression baseVariable = getBase(entry.getValue());
if (semiJoinNode.getFilteringSource().getOutputVariables().contains(baseVariable)) {
filteringSourceDereferences.put(entry.getKey(), entry.getValue());
}
else {
sourceDereferences.put(entry.getKey(), entry.getValue());
}
}
PlanNode filteringSource = createProject(semiJoinNode.getFilteringSource(), filteringSourceDereferences.build(), context.getIdAllocator());
PlanNode source = createProject(semiJoinNode.getSource(), sourceDereferences.build(), context.getIdAllocator());
return Result.ofPlanNode(semiJoinNode.replaceChildren(ImmutableList.of(source, filteringSource)));
}
}
public class PushDownDereferenceThroughProject
extends PushdownDereferencesInProject<ProjectNode>
{
PushDownDereferenceThroughProject()
{
super(project());
}
@Override
protected Result pushDownDereferences(Context context, ProjectNode projectNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
Map<VariableReferenceExpression, RowExpression> dereferencesMap = expressions.inverse().entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
return Result.ofPlanNode(
new ProjectNode(context.getIdAllocator().getNextId(),
projectNode.getSource(),
Assignments.builder()
.putAll(projectNode.getAssignments())
.putAll(dereferencesMap)
.build()));
}
}
public class PushDownDereferenceThroughUnnest
extends PushdownDereferencesInProject<UnnestNode>
{
PushDownDereferenceThroughUnnest()
{
super(unnest());
}
@Override
protected Result pushDownDereferences(Context context, UnnestNode unnestNode, BiMap<SpecialFormExpression, VariableReferenceExpression> expressions)
{
Map<VariableReferenceExpression, RowExpression> dereferencesMap = expressions.inverse().entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
Assignments assignments = Assignments.builder()
.putAll(identityAssignments(unnestNode.getSource().getOutputVariables()))
.putAll(dereferencesMap)
.build();
ProjectNode source = new ProjectNode(context.getIdAllocator().getNextId(), unnestNode.getSource(), assignments);
return Result.ofPlanNode(
new UnnestNode(
unnestNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
source,
ImmutableList.<VariableReferenceExpression>builder()
.addAll(unnestNode.getReplicateVariables())
.addAll(expressions.values())
.build(),
unnestNode.getUnnestVariables(),
unnestNode.getOrdinalityVariable()));
}
}
private RowExpression replaceDereferences(RowExpression rowExpression, Map<SpecialFormExpression, VariableReferenceExpression> dereferences)
{
return rewriteWith(new DereferenceReplacer(dereferences), rowExpression);
}
private static PlanNode createProject(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator)
{
if (dereferences.isEmpty()) {
return planNode;
}
Assignments assignments = Assignments.builder()
.putAll(identityAssignments(planNode.getOutputVariables()))
.putAll(dereferences)
.build();
return new ProjectNode(idAllocator.getNextId(), planNode, assignments);
}
private static class DereferenceReplacer
extends RowExpressionRewriter<Void>
{
private final Map<SpecialFormExpression, VariableReferenceExpression> expressions;
DereferenceReplacer(Map<SpecialFormExpression, VariableReferenceExpression> expressions)
{
this.expressions = requireNonNull(expressions, "expressions is null");
}
@Override
public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
{
if (expressions.containsKey(node)) {
return new VariableReferenceExpression(node.getSourceLocation(), expressions.get(node).getName(), node.getType());
}
return treeRewriter.defaultRewrite(node, context);
}
}
private static List<SpecialFormExpression> extractDereference(RowExpression expression)
{
ImmutableList.Builder<SpecialFormExpression> builder = ImmutableList.builder();
expression.accept(new DefaultRowExpressionTraversalVisitor<ImmutableList.Builder<SpecialFormExpression>>()
{
@Override
public Void visitSpecialForm(SpecialFormExpression node, ImmutableList.Builder<SpecialFormExpression> context)
{
if (isValidDereference(node)) {
context.add(node);
}
else {
node.getArguments().forEach(argument -> argument.accept(this, context));
}
return null;
}
}, builder);
return builder.build();
}
private static Map<SpecialFormExpression, VariableReferenceExpression> getDereferenceSymbolMap(Collection<RowExpression> expressions, Context context, Metadata metadata)
{
Set<SpecialFormExpression> dereferences = expressions.stream()
.flatMap(expression -> extractDereference(expression).stream())
.collect(toImmutableSet());
// DereferenceExpression with the same base will cause unnecessary rewritten
if (dereferences.stream().anyMatch(expression -> baseExists(expression, dereferences))) {
return ImmutableMap.of();
}
return dereferences.stream()
.collect(toImmutableMap(Function.identity(), expression -> createVariable(expression, context)));
}
private static VariableReferenceExpression createVariable(SpecialFormExpression expression, Context context)
{
return context.getVariableAllocator().newVariable(expression);
}
private static boolean baseExists(SpecialFormExpression expression, Set<SpecialFormExpression> dereferences)
{
RowExpression base = expression.getArguments().get(0);
while (base instanceof SpecialFormExpression) {
if (dereferences.contains(base)) {
return true;
}
base = ((SpecialFormExpression) base).getArguments().get(0);
}
return false;
}
private static boolean isValidDereference(SpecialFormExpression dereference)
{
RowExpression expression = dereference;
while (true) {
if (expression instanceof VariableReferenceExpression) {
return true;
}
if (!(expression instanceof SpecialFormExpression) || ((SpecialFormExpression) expression).getForm() != DEREFERENCE) {
return false;
}
expression = ((SpecialFormExpression) expression).getArguments().get(0);
}
}
private static VariableReferenceExpression getBase(RowExpression expression)
{
return getOnlyElement(extractAll(expression));
}
}