CteProjectionAndPredicatePushDown.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.SimplePlanVisitor;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyRowExpressions;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.getCteFilterAndProjectionPushdownEnabled;
import static com.facebook.presto.SystemSessionProperties.isCteMaterializationApplicable;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.planner.PlannerUtils.isConstant;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
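/*
 * CteProjectionAndPredicatePushDown Transformation:
 * This optimizer collects the predicates and column projections applied directly on top of each CTE consumer,
 * rewrites them in terms of the CTE producer's output variables, and pushes their union into the CTE producer.
 * The producer then only materializes rows that satisfy at least one consumer predicate (predicates are OR-ed)
 * and only the columns that at least one consumer reads.
 *
 * Recognized consumer shapes:
 *   Filter -> CTEConsumer                  (all columns, filter predicate)
 *   Project -> CTEConsumer                 (projected columns, no predicate)
 *   Project -> Filter -> CTEConsumer       (projected + predicate columns, filter predicate)
 * The Project must only select variables (no computed expressions). A consumer reached in any other way
 * requires all rows and all columns of its producer, which disables the filter push down for that CTE.
 *
 * Example:
 * Before Transformation:
 * CTEProducer(cteX)
 *   -- SomeOp
 * |-- Projection (C1,C2)
 * |     -- Filter (Pred1)
 * |        -- CTEConsumer(cteX)
 * `-- Projection (C3,C4)
 *       -- Filter (Pred2)
 *          -- CTEConsumer(cteX)
 *
 * After Transformation:
 * CTEProducer(cteX)
 *   -- Projection (C1,C2,C3,C4 + columns referenced by Pred1/Pred2)
 *      -- Filter (Pred1 or Pred2)
 *         -- SomeOp
 * |-- Projection (C1,C2)
 * |     -- Filter (Pred1)
 * |        -- CTEConsumer(cteX)           (outputs pruned to the producer's remaining columns)
 * `-- Projection (C3,C4)
 *       -- Filter (Pred2)
 *          -- CTEConsumer(cteX)*/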
public class CteProjectionAndPredicatePushDown
        implements PlanOptimizer
{
    private final Metadata metadata;
    private final ExpressionOptimizerManager expressionOptimizerManager;

    public CteProjectionAndPredicatePushDown(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null");
    }

    /**
     * Runs the two-pass rewrite when CTE materialization and CTE filter/projection push down are enabled.
     * Pass 1 ({@link CtePredicateAndProjectionExtractor}) collects per-CTE required columns and predicates;
     * pass 2 ({@link CteProducerRewriter}) applies them to producers and prunes consumer outputs.
     *
     * @return the (possibly) rewritten plan, flagged as changed only if a producer was rewritten
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(variableAllocator, "variableAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");
        requireNonNull(warningCollector, "warningCollector is null");

        if (!isCteMaterializationApplicable(session)
                || !getCteFilterAndProjectionPushdownEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }

        CteContext cteContext = new CteContext();
        plan.accept(new CtePredicateAndProjectionExtractor(session, idAllocator, variableAllocator), cteContext);

        CteProducerRewriter cteProducerRewriter = new CteProducerRewriter(session, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(cteProducerRewriter, plan, cteContext);
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, cteProducerRewriter.isPlanRewritten());
    }

    /**
     * First pass: records each CTE producer's output columns and, for every consumer of that CTE, the producer
     * columns it reads and the predicate applied on top of it (expressed in producer variables).
     */
    public class CtePredicateAndProjectionExtractor
            extends SimplePlanVisitor<CteContext>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final Session session;

        public CtePredicateAndProjectionExtractor(Session session, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
            this.session = requireNonNull(session, "session must not be null");
        }

        @Override
        public Void visitCteProducer(CteProducerNode node, CteContext context)
        {
            context.addCteProducerInfo(node.getCteId(), node.getOutputVariables());
            return super.visitCteProducer(node, context);
        }

        /**
         * Handles {@code Filter -> CTEConsumer}: every producer column stays required, and the filter predicate
         * (remapped to producer variables) is recorded. Any other filter is traversed normally.
         */
        @Override
        public Void visitFilter(FilterNode node, CteContext context)
        {
            PlanNode childNode = node.getSource();
            if (!(childNode instanceof CteConsumerNode)) {
                return super.visitFilter(node, context);
            }
            CteConsumerNode cteConsumerNode = (CteConsumerNode) childNode;
            String cteName = cteConsumerNode.getCteId();
            // no projection restricts the columns, so all producer columns are needed
            List<VariableReferenceExpression> producerColumns = getProducerColumns(cteName, context);
            Map<VariableReferenceExpression, VariableReferenceExpression> varMap = constructConsumerToProducerVarMap(cteConsumerNode, context);
            RowExpression newPredicate = remapExpression(node.getPredicate(), varMap);
            context.addCteConsumerInfo(cteName, producerColumns, ImmutableList.of(newPredicate));
            return null;
        }

        /**
         * Handles {@code Project -> CTEConsumer} and {@code Project -> Filter -> CTEConsumer} where the project only
         * selects variables: records the selected columns, plus the columns referenced by the filter predicate,
         * and the predicate itself (or TRUE when there is no filter). Any other project is traversed normally.
         */
        @Override
        public Void visitProject(ProjectNode node, CteContext context)
        {
            if (!isCteConsumerFilterRestrict(node)) {
                return super.visitProject(node, context);
            }
            CteConsumerNode cteConsumerNode = extractCteConsumer(node);
            Map<VariableReferenceExpression, VariableReferenceExpression> varMap = constructConsumerToProducerVarMap(cteConsumerNode, context);

            // Collectors.toList() makes no mutability guarantee, and this list is appended to below
            List<VariableReferenceExpression> usedColumns = node.getAssignments().getExpressions().stream()
                    .map(expression -> (VariableReferenceExpression) remapExpression(expression, varMap))
                    .collect(Collectors.toCollection(ArrayList::new));

            FilterNode filterNode = extractFilterNode(node);
            RowExpression predicate;
            if (filterNode != null) {
                predicate = remapExpression(filterNode.getPredicate(), varMap);
                // the producer must also keep the columns the predicate evaluates
                usedColumns.addAll(VariablesExtractor.extractAll(predicate));
            }
            else {
                predicate = constant(true, BOOLEAN);
            }
            context.addCteConsumerInfo(cteConsumerNode.getCteId(), usedColumns, ImmutableList.of(predicate));
            return null;
        }

        @Override
        public Void visitCteConsumer(CteConsumerNode node, CteContext context)
        {
            // if we reach this point, it means that the CTE consumer had no filter or projection on top of it and we must take all columns and rows
            String cteName = node.getCteId();
            // TODO: support pushing of projections in the cte consumer (similar to table scan projection push down)
            // for now, take the original columns of the CTE producer
            List<VariableReferenceExpression> producerColumns = getProducerColumns(cteName, context);
            // no filter encountered on top of this consumer: all rows must be read
            context.addCteConsumerInfo(cteName, producerColumns, ImmutableList.of(constant(true, BOOLEAN)));
            return null;
        }

        /**
         * Visits the CTE producers (last to first) before the primary source, so producer columns are
         * registered before the consumers in the primary source are reached.
         * NOTE(review): the reverse order presumably ensures a producer is visited before producers that
         * consume it — confirm against how SequenceNode orders its producers.
         */
        @Override
        public Void visitSequence(SequenceNode node, CteContext context)
        {
            List<PlanNode> cteProducers = node.getCteProducers();
            for (int i = cteProducers.size() - 1; i >= 0; i--) {
                cteProducers.get(i).accept(this, context);
            }
            node.getPrimarySource().accept(this, context);
            return null;
        }

        /**
         * Returns true for a project directly over a CTE consumer (optionally with one filter in between)
         * whose assignments are all plain variable references.
         */
        private boolean isCteConsumerFilterRestrict(PlanNode node)
        {
            if (!(node instanceof ProjectNode)) {
                return false;
            }
            ProjectNode projectNode = (ProjectNode) node;
            PlanNode childNode = projectNode.getSource();
            if (isCteConsumerFilter(childNode) || childNode instanceof CteConsumerNode) {
                // check if all project elements are restrict-only
                return projectNode.getAssignments().getExpressions().stream().allMatch(expression -> expression instanceof VariableReferenceExpression);
            }
            return false;
        }

        private boolean isCteConsumerFilter(PlanNode node)
        {
            return node instanceof FilterNode && ((FilterNode) node).getSource() instanceof CteConsumerNode;
        }

        private CteConsumerNode extractCteConsumer(PlanNode node)
        {
            checkState(isCteConsumerFilterRestrict(node));
            PlanNode childNode = ((ProjectNode) node).getSource();
            if (childNode instanceof CteConsumerNode) {
                return (CteConsumerNode) childNode;
            }
            return (CteConsumerNode) ((FilterNode) childNode).getSource();
        }

        /**
         * Returns the filter between the project and the CTE consumer, or null when the project sits
         * directly on the consumer.
         */
        private FilterNode extractFilterNode(PlanNode node)
        {
            checkState(isCteConsumerFilterRestrict(node));
            PlanNode childNode = ((ProjectNode) node).getSource();
            if (childNode instanceof FilterNode) {
                return (FilterNode) childNode;
            }
            // no filter node
            return null;
        }

        /**
         * Maps each consumer output variable to the producer output variable at the same position.
         */
        private Map<VariableReferenceExpression, VariableReferenceExpression> constructConsumerToProducerVarMap(CteConsumerNode cteConsumerNode, CteContext context)
        {
            List<VariableReferenceExpression> producerColumns = getProducerColumns(cteConsumerNode.getCteId(), context);
            return constructVarMap(cteConsumerNode.getOutputVariables(), producerColumns);
        }

        /**
         * Looks up the recorded output columns of the named producer, failing fast if none was registered.
         */
        private List<VariableReferenceExpression> getProducerColumns(String cteName, CteContext context)
        {
            List<VariableReferenceExpression> producerColumns = context.getCteProducerColumns(cteName);
            checkState(producerColumns != null, "No producer with name " + cteName + " found");
            return producerColumns;
        }
    }

    /**
     * Pairs {@code sourceColumns} and {@code targetColumns} positionally into a map. If the lists differ in
     * length, the extra trailing elements of the longer list are ignored (Streams.zip semantics).
     */
    private static Map<VariableReferenceExpression, VariableReferenceExpression> constructVarMap(List<VariableReferenceExpression> sourceColumns, List<VariableReferenceExpression> targetColumns)
    {
        Map<VariableReferenceExpression, VariableReferenceExpression> varMap = new HashMap<>();
        Streams.zip(
                sourceColumns.stream(),
                targetColumns.stream(),
                AbstractMap.SimpleImmutableEntry::new).forEach(pair -> varMap.put(pair.getKey(), pair.getValue()));
        return varMap;
    }

    /**
     * Replaces variables in {@code expression} according to {@code varMap}.
     */
    private RowExpression remapExpression(RowExpression expression, Map<VariableReferenceExpression, VariableReferenceExpression> varMap)
    {
        return RowExpressionVariableInliner.inlineVariables(varMap, expression);
    }

    /**
     * Second pass: adds the OR of all consumer predicates as a filter on top of each CTE producer's source,
     * restricts the producer to the required columns, and prunes every consumer's outputs to match.
     */
    public class CteProducerRewriter
            extends SimplePlanRewriter<CteContext>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final Session session;
        private boolean isPlanRewritten;

        public CteProducerRewriter(Session session, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
            this.session = requireNonNull(session, "session must not be null");
        }

        @Override
        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<CteContext> context)
        {
            String cteName = node.getCteId();
            List<VariableReferenceExpression> usedColumns = context.get().getCteRequiredColumns(cteName);
            List<RowExpression> predicates = context.get().getPredicates(cteName);
            // recursively process child node
            PlanNode newChild = node.getSource().accept(this, context);

            if (usedColumns == null || predicates == null) {
                // no consumers were recorded for this CTE: only propagate changes made below
                PlanNode result = replaceChildren(node, ImmutableList.of(newChild));
                isPlanRewritten = isPlanRewritten || !node.equals(result);
                return result;
            }

            Set<VariableReferenceExpression> usedColumnsSet = new HashSet<>(usedColumns);
            PlanNode newChildWithFilterAndProject = addFilter(newChild, predicates);
            // keep the producer's original column order, dropping columns no consumer reads
            List<VariableReferenceExpression> newProducerColumns = node.getOutputVariables().stream()
                    .filter(usedColumnsSet::contains)
                    .collect(Collectors.toList());
            if (!newProducerColumns.equals(newChildWithFilterAndProject.getOutputVariables())) {
                newChildWithFilterAndProject = PlannerUtils.restrictOutput(newChildWithFilterAndProject, idAllocator, newProducerColumns);
            }

            if (newChildWithFilterAndProject == node.getSource()) {
                return node;
            }
            isPlanRewritten = true;
            return new CteProducerNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getStatsEquivalentPlanNode(),
                    newChildWithFilterAndProject,
                    cteName,
                    node.getRowCountVariable(),
                    newProducerColumns);
        }

        /**
         * Keeps only the consumer outputs whose positionally-matching producer column is still required.
         */
        @Override
        public PlanNode visitCteConsumer(CteConsumerNode node, RewriteContext<CteContext> context)
        {
            List<VariableReferenceExpression> allProducerColumns = context.get().getCteProducerColumns(node.getCteId());
            checkState(allProducerColumns != null, "No producer with name " + node.getCteId() + " found");
            List<VariableReferenceExpression> requiredProducerColumns = context.get().getCteRequiredColumns(node.getCteId());
            checkState(requiredProducerColumns != null, "Required columns for producer " + node.getCteId() + " not found");

            Set<VariableReferenceExpression> requiredProducerColumnsSet = new HashSet<>(requiredProducerColumns);
            List<VariableReferenceExpression> newConsumerColumns = new ArrayList<>();
            Streams.zip(
                    allProducerColumns.stream(),
                    node.getOutputVariables().stream(),
                    AbstractMap.SimpleImmutableEntry::new).forEach(pair -> {
                        if (requiredProducerColumnsSet.contains(pair.getKey())) {
                            newConsumerColumns.add(pair.getValue());
                        }
                    });
            return new CteConsumerNode(node.getSourceLocation(), node.getId(), node.getStatsEquivalentPlanNode(), newConsumerColumns, node.getCteId(), node.getOriginalSource());
        }

        /**
         * @return true if any CTE producer was replaced during the rewrite
         */
        public boolean isPlanRewritten()
        {
            return isPlanRewritten;
        }

        /**
         * Wraps {@code node} in a filter that is the (simplified) OR of {@code predicates}. Returns {@code node}
         * unchanged when the disjunction is trivially true.
         */
        private PlanNode addFilter(PlanNode node, List<RowExpression> predicates)
        {
            if (isConstTrue(predicates)) {
                return node;
            }
            RowExpression resultPredicate = predicates.get(0);
            for (int i = 1; i < predicates.size(); i++) {
                resultPredicate = new SpecialFormExpression(
                        SpecialFormExpression.Form.OR,
                        BOOLEAN,
                        resultPredicate,
                        predicates.get(i));
            }
            resultPredicate = SimplifyRowExpressions.rewrite(resultPredicate, metadata, session, expressionOptimizerManager);
            return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), node, resultPredicate);
        }

        /**
         * A disjunction is trivially true if it is empty or any disjunct is the constant TRUE.
         */
        private boolean isConstTrue(List<RowExpression> predicates)
        {
            return predicates.isEmpty() || predicates.stream().anyMatch(predicate -> isConstant(predicate, BOOLEAN, true));
        }
    }

    /**
     * State shared between the two passes: producer output columns per CTE name, and the accumulated
     * required columns / predicates from all consumers of each CTE.
     */
    public static class CteContext
    {
        private final Map<String, CteInfo> cteNameToTableInfo;
        private final Map<String, List<VariableReferenceExpression>> cteProducerOutputColumnsMap;

        public CteContext()
        {
            cteNameToTableInfo = new HashMap<>();
            cteProducerOutputColumnsMap = new HashMap<>();
        }

        public void addCteProducerInfo(String cteName, List<VariableReferenceExpression> outputColumns)
        {
            requireNonNull(outputColumns, "CTE producer output columns cannot be null");
            checkState(!cteProducerOutputColumnsMap.containsKey(cteName), "CTE producer columns already recorded.");
            cteProducerOutputColumnsMap.put(cteName, outputColumns);
        }

        /**
         * Merges one consumer's required producer columns and predicates into the CTE's accumulated info.
         */
        public void addCteConsumerInfo(String cteName, List<VariableReferenceExpression> columns, List<RowExpression> predicates)
        {
            CteInfo cteInfo = cteNameToTableInfo.computeIfAbsent(cteName, ignored -> new CteInfo(new HashSet<>(), new ArrayList<>()));
            cteInfo.addColumns(columns);
            cteInfo.addPredicates(predicates);
        }

        /**
         * @return the producer's recorded output columns, or null if no producer with this name was seen
         */
        public List<VariableReferenceExpression> getCteProducerColumns(String cteName)
        {
            return cteProducerOutputColumnsMap.get(cteName);
        }

        /**
         * @return a new list of the columns required by all consumers (in no particular order), or null if
         * no consumer of this CTE was recorded
         */
        public List<VariableReferenceExpression> getCteRequiredColumns(String cteName)
        {
            CteInfo cteInfo = cteNameToTableInfo.get(cteName);
            // if no CTE is found, this means no CTE consumers found during exploration: let the caller handle this case
            return cteInfo == null ? null : new ArrayList<>(cteInfo.getColumns());
        }

        /**
         * @return a copy of the predicates collected from all consumers, or null if no consumer of this CTE
         * was recorded
         */
        public List<RowExpression> getPredicates(String cteName)
        {
            CteInfo cteInfo = cteNameToTableInfo.get(cteName);
            // if no CTE is found, this means no CTE consumers found during exploration: let the caller handle this case
            return cteInfo == null ? null : new ArrayList<>(cteInfo.getPredicates());
        }

        /**
         * Accumulated consumer requirements for a single CTE.
         */
        public static class CteInfo
        {
            private final Set<VariableReferenceExpression> columns;
            private final List<RowExpression> predicates;

            public CteInfo(Set<VariableReferenceExpression> columns, List<RowExpression> predicates)
            {
                this.columns = requireNonNull(columns, "columns must not be null");
                this.predicates = requireNonNull(predicates, "predicates must not be null");
            }

            public Set<VariableReferenceExpression> getColumns()
            {
                return columns;
            }

            public List<RowExpression> getPredicates()
            {
                return predicates;
            }

            public void addColumns(List<VariableReferenceExpression> columns)
            {
                this.columns.addAll(columns);
            }

            public void addPredicates(List<RowExpression> predicates)
            {
                this.predicates.addAll(predicates);
            }
        }
    }
}