ReplaceConstantVariableReferencesWithConstants.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.OffsetNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import static com.facebook.presto.SystemSessionProperties.isRewriteExpressionWithConstantEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.sql.planner.PlannerUtils.addOverrideProjection;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
/**
 * Collects variables that are known to hold a single constant value and rewrites references to those variables in
 * ancestor plan nodes with the constant itself.
 * <p>
 * Constants are discovered from
 * <ul>
 * <li>filter conjuncts of the form {@code variable = constant} (in either argument order), and</li>
 * <li>project assignments whose value is a non-null constant.</li>
 * </ul>
 * Only BIGINT, INTEGER, DATE and VARCHAR (of any length) values are considered, and NULL constants are ignored.
 * Filters and projections substitute the constants coming from below directly into their expressions. Other handled
 * nodes (joins, unions, aggregations, limit/offset/sort/top-N/sample/semi-join/unnest, and any other node) get their
 * sources wrapped in a projection that overrides the constant variables with the constant values. The rule runs only
 * when {@code isRewriteExpressionWithConstantEnabled(session)} is true.
 * <p>
 * For example, for query
 * {@code select orderkey, orderpriority, avg(totalprice) from orders where orderpriority = '3-MEDIUM' group by 1, 2},
 * the query plan changes from
 * <pre>
 * - OutputNode
 *     orderkey, orderpriority, avg
 *     - Aggregate
 *         avg := avg(totalprice)
 *         Grouping Keys := [orderkey, orderpriority]
 *         - Filter
 *             orderpriority = '3-MEDIUM'
 *             - TableScan
 *                 orderkey := orderkey
 *                 orderpriority := orderpriority
 *                 totalprice := totalprice
 * </pre>
 * to
 * <pre>
 * - OutputNode
 *     orderkey, expr_12, avg
 *     - Project
 *         orderkey := orderkey
 *         expr_12 := '3-MEDIUM'
 *         avg := avg
 *         - Aggregate
 *             avg := avg(totalprice)
 *             Grouping Keys := [orderkey]
 *             - Filter
 *                 orderpriority = '3-MEDIUM'
 *                 - TableScan
 *                     orderkey := orderkey
 *                     orderpriority := orderpriority
 *                     totalprice := totalprice
 * </pre>
 * Note that this rule itself does not change an aggregation's grouping keys. It also adds an overriding projection
 * below the aggregation, which is omitted above. Dropping the constant grouping key, as shown, is presumably done by
 * other optimizations.
 */
public class ReplaceConstantVariableReferencesWithConstants
        implements PlanOptimizer
{
    // Any VarcharType (not just unbounded VARCHAR) is additionally accepted by isSupportedType
    private static final List<Type> SUPPORTED_TYPES = ImmutableList.of(BIGINT, INTEGER, VARCHAR, DATE);

    private final FunctionAndTypeManager functionAndTypeManager;

    public ReplaceConstantVariableReferencesWithConstants(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Returns whether constants of the expression's type are propagated by this rule:
     * BIGINT, INTEGER, DATE, or VARCHAR of any length.
     */
    private static boolean isSupportedType(RowExpression rowExpression)
    {
        Type type = rowExpression.getType();
        return SUPPORTED_TYPES.contains(type) || type instanceof VarcharType;
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (!isRewriteExpressionWithConstantEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }
        Rewriter rewriter = new Rewriter(idAllocator, functionAndTypeManager);
        PlanNode rewrittenPlan = plan.accept(rewriter, null).getPlanNode();
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    /**
     * Bottom-up plan rewriter. Each visit method rewrites the subtree rooted at the visited node and returns it
     * together with the constants known to hold for that node's output variables.
     */
    private static class Rewriter
            extends InternalPlanVisitor<PlanNodeWithConstant, Void>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final FunctionResolution functionResolution;
        // Set whenever a non-null constant is recorded, even if it is never substituted anywhere afterwards
        private boolean planChanged;

        public Rewriter(PlanNodeIdAllocator idAllocator, FunctionAndTypeManager functionAndTypeManager)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        }

        public boolean isPlanChanged()
        {
            return planChanged;
        }

        private PlanNodeWithConstant accept(PlanNode node)
        {
            return node.accept(this, null);
        }

        /**
         * Rewrites every source of {@code node}. Each source with known constants is wrapped in a projection that
         * overrides those variables with their constant values.
         *
         * @param keepConstantConstraint whether the sources' constants still hold for {@code node}'s output; when
         * false, the returned constant map is empty
         */
        private PlanNodeWithConstant planAndReplace(PlanNode node, boolean keepConstantConstraint)
        {
            List<PlanNodeWithConstant> children = node.getSources().stream().map(this::accept).collect(toImmutableList());
            ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
            for (PlanNodeWithConstant planNodeWithConstant : children) {
                PlanNode child = planNodeWithConstant.getPlanNode();
                Map<VariableReferenceExpression, ConstantExpression> constantExpressionMap = planNodeWithConstant.getConstantExpressionMap();
                // PlanNodeWithConstant keeps only constants for the child's output variables,
                // so an empty map means nothing in the child's output needs overriding
                if (constantExpressionMap.isEmpty()) {
                    newSources.add(child);
                }
                else {
                    newSources.add(addOverrideProjection(child, idAllocator, constantExpressionMap));
                }
            }
            PlanNode result = replaceChildren(node, newSources.build());
            if (!keepConstantConstraint) {
                return new PlanNodeWithConstant(result, ImmutableMap.of());
            }
            ImmutableMap.Builder<VariableReferenceExpression, ConstantExpression> properties = ImmutableMap.builder();
            children.stream().map(PlanNodeWithConstant::getConstantExpressionMap).forEach(properties::putAll);
            return new PlanNodeWithConstant(result, properties.build());
        }

        @Override
        public PlanNodeWithConstant visitPlan(PlanNode node, Void context)
        {
            return planAndReplace(node, false);
        }

        @Override
        public PlanNodeWithConstant visitFilter(FilterNode node, Void context)
        {
            PlanNodeWithConstant rewrittenChild = accept(node.getSource());
            RowExpression predicate = node.getPredicate();
            Map<VariableReferenceExpression, ConstantExpression> newConstantMap = new HashMap<>(rewrittenChild.getConstantExpressionMap());
            for (RowExpression conjunct : extractConjuncts(predicate)) {
                if (!(conjunct instanceof CallExpression) || !functionResolution.isEqualsFunction(((CallExpression) conjunct).getFunctionHandle())) {
                    continue;
                }
                List<RowExpression> arguments = ((CallExpression) conjunct).getArguments();
                RowExpression argument0 = arguments.get(0);
                RowExpression argument1 = arguments.get(1);
                if (!isSupportedType(argument0) || !isSupportedType(argument1)) {
                    continue;
                }
                boolean isVariableEqualsConstant = (argument0 instanceof VariableReferenceExpression && argument1 instanceof ConstantExpression)
                        || (argument1 instanceof VariableReferenceExpression && argument0 instanceof ConstantExpression);
                if (!isVariableEqualsConstant) {
                    continue;
                }
                VariableReferenceExpression variable = (VariableReferenceExpression) (argument0 instanceof VariableReferenceExpression ? argument0 : argument1);
                ConstantExpression constant = (ConstantExpression) (argument0 instanceof ConstantExpression ? argument0 : argument1);
                // PlanNodeWithConstant requires each constant to have exactly its variable's type. If the equality
                // was resolved for differently-typed arguments, leave the conjunct alone instead of failing.
                if (!variable.getType().equals(constant.getType())) {
                    continue;
                }
                // Conflicting constants for the same variable (from this filter or from below): give up on this filter
                if (newConstantMap.containsKey(variable) && !newConstantMap.get(variable).equals(constant)) {
                    return new PlanNodeWithConstant(replaceChildren(node, ImmutableList.of(rewrittenChild.getPlanNode())), ImmutableMap.of());
                }
                if (!constant.isNull()) {
                    planChanged = true;
                    newConstantMap.put(variable, constant);
                }
            }
            FilterNode newFilterNode = node;
            if (!rewrittenChild.getConstantExpressionMap().isEmpty()) {
                // Only constants coming from below are substituted into the predicate;
                // constants discovered in this filter are exported to the parent instead
                RowExpression rewrittenPredicate = predicate.accept(new ExpressionRewriter(rewrittenChild.getConstantExpressionMap()), null);
                newFilterNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), node.getSource(), rewrittenPredicate);
            }
            return new PlanNodeWithConstant(replaceChildren(newFilterNode, ImmutableList.of(rewrittenChild.getPlanNode())), newConstantMap);
        }

        @Override
        public PlanNodeWithConstant visitProject(ProjectNode node, Void context)
        {
            PlanNodeWithConstant rewrittenChild = accept(node.getSource());
            ProjectNode newProjectNode = node;
            if (!rewrittenChild.getConstantExpressionMap().isEmpty()) {
                // A single rewriter serves all assignments
                ExpressionRewriter expressionRewriter = new ExpressionRewriter(rewrittenChild.getConstantExpressionMap());
                Map<VariableReferenceExpression, RowExpression> newAssignments = node.getAssignments().getMap().entrySet().stream()
                        .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().accept(expressionRewriter, null)));
                // Keep the original projection's source location and locality
                newProjectNode = new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), node.getSource(), Assignments.copyOf(newAssignments), node.getLocality());
            }
            ImmutableMap.Builder<VariableReferenceExpression, ConstantExpression> newConstantMap = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : newProjectNode.getAssignments().getMap().entrySet()) {
                if (!(entry.getValue() instanceof ConstantExpression) || !isSupportedType(entry.getKey()) || !isSupportedType(entry.getValue())) {
                    continue;
                }
                ConstantExpression constantExpression = (ConstantExpression) entry.getValue();
                // Type equality is required by PlanNodeWithConstant; NULL constants are not propagated
                if (!constantExpression.isNull() && entry.getKey().getType().equals(constantExpression.getType())) {
                    planChanged = true;
                    newConstantMap.put(entry.getKey(), constantExpression);
                }
            }
            return new PlanNodeWithConstant(replaceChildren(newProjectNode, ImmutableList.of(rewrittenChild.getPlanNode())), newConstantMap.build());
        }

        @Override
        public PlanNodeWithConstant visitJoin(JoinNode node, Void context)
        {
            PlanNodeWithConstant rewrittenLeft = accept(node.getLeft());
            PlanNodeWithConstant rewrittenRight = accept(node.getRight());
            JoinType joinType = node.getType();
            ImmutableMap.Builder<VariableReferenceExpression, ConstantExpression> outputConstantMap = ImmutableMap.builder();
            // Output from the inner side of an outer join can be NULL when there is no match, so only the side(s)
            // that are never NULL-padded keep their constants; for other join types nothing is propagated
            if (joinType == JoinType.LEFT || joinType == JoinType.INNER) {
                outputConstantMap.putAll(rewrittenLeft.getConstantExpressionMap());
            }
            if (joinType == JoinType.RIGHT || joinType == JoinType.INNER) {
                outputConstantMap.putAll(rewrittenRight.getConstantExpressionMap());
            }
            // Add a projection with constant assignments on top of each source
            List<PlanNode> sourceWithConstantProjection = ImmutableList.of(
                    addOverrideProjection(rewrittenLeft.getPlanNode(), idAllocator, rewrittenLeft.getConstantExpressionMap()),
                    addOverrideProjection(rewrittenRight.getPlanNode(), idAllocator, rewrittenRight.getConstantExpressionMap()));
            return new PlanNodeWithConstant(replaceChildren(node, sourceWithConstantProjection), outputConstantMap.build());
        }

        @Override
        public PlanNodeWithConstant visitUnion(UnionNode node, Void context)
        {
            List<PlanNodeWithConstant> rewrittenSources = node.getSources().stream().map(this::accept).collect(toImmutableList());
            ImmutableMap.Builder<VariableReferenceExpression, ConstantExpression> outputConstantMap = ImmutableMap.builder();
            if (rewrittenSources.stream().noneMatch(source -> source.getConstantExpressionMap().isEmpty())) {
                for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> entry : node.getVariableMapping().entrySet()) {
                    VariableReferenceExpression outputVariable = entry.getKey();
                    List<VariableReferenceExpression> inputList = entry.getValue();
                    // Output variable is constant only when all corresponding input variables are the same constant
                    boolean allInputsConstant = IntStream.range(0, inputList.size())
                            .allMatch(index -> rewrittenSources.get(index).getConstantExpressionMap().containsKey(inputList.get(index)));
                    if (!allInputsConstant) {
                        continue;
                    }
                    List<ConstantExpression> distinctConstants = IntStream.range(0, inputList.size())
                            .mapToObj(index -> rewrittenSources.get(index).getConstantExpressionMap().get(inputList.get(index)))
                            .distinct()
                            .collect(toImmutableList());
                    if (distinctConstants.size() == 1 && distinctConstants.get(0).getType().equals(outputVariable.getType())) {
                        outputConstantMap.put(outputVariable, distinctConstants.get(0));
                    }
                }
            }
            List<PlanNode> sourceWithConstantProjection = rewrittenSources.stream()
                    .map(source -> addOverrideProjection(source.getPlanNode(), idAllocator, source.getConstantExpressionMap()))
                    .collect(toImmutableList());
            return new PlanNodeWithConstant(replaceChildren(node, sourceWithConstantProjection), outputConstantMap.build());
        }

        @Override
        public PlanNodeWithConstant visitAggregation(AggregationNode node, Void context)
        {
            PlanNodeWithConstant rewrittenChild = accept(node.getSource());
            List<PlanNode> sourceWithConstantProjection = ImmutableList.of(addOverrideProjection(rewrittenChild.getPlanNode(), idAllocator, rewrittenChild.getConstantExpressionMap()));
            PlanNode newAggregation = replaceChildren(node, sourceWithConstantProjection);
            // With several grouping sets, a key absent from a set is output as NULL, so nothing is propagated
            if (node.getGroupingSetCount() != 1) {
                return new PlanNodeWithConstant(newAggregation, ImmutableMap.of());
            }
            // Only grouping keys carry their input value through the aggregation
            Map<VariableReferenceExpression, ConstantExpression> constantGroupByKeys = rewrittenChild.getConstantExpressionMap().entrySet().stream()
                    .filter(entry -> node.getGroupingKeys().contains(entry.getKey()))
                    .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
            return new PlanNodeWithConstant(newAggregation, constantGroupByKeys);
        }

        @Override
        public PlanNodeWithConstant visitTopN(TopNNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitSort(SortNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitLimit(LimitNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitSample(SampleNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitSemiJoin(SemiJoinNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitOffset(OffsetNode node, Void context)
        {
            return planAndReplace(node, true);
        }

        @Override
        public PlanNodeWithConstant visitUnnest(UnnestNode node, Void context)
        {
            return planAndReplace(node, true);
        }
    }

    /**
     * A rewritten plan node paired with the constants known to hold for its output variables.
     * Entries whose variable is not an output of the node are dropped on construction.
     */
    private static class PlanNodeWithConstant
    {
        private final PlanNode planNode;
        private final Map<VariableReferenceExpression, ConstantExpression> constantExpressionMap;

        /**
         * @throws IllegalArgumentException if some constant's type differs from its variable's type
         */
        public PlanNodeWithConstant(PlanNode planNode, Map<VariableReferenceExpression, ConstantExpression> constantExpressionMap)
        {
            requireNonNull(planNode, "planNode is null");
            requireNonNull(constantExpressionMap, "constantExpressionMap is null");
            checkArgument(constantExpressionMap.entrySet().stream().allMatch(entry -> entry.getKey().getType().equals(entry.getValue().getType())),
                    "key and value in constantExpressionMap not of the same type");
            this.planNode = planNode;
            this.constantExpressionMap = constantExpressionMap.entrySet().stream()
                    .filter(entry -> planNode.getOutputVariables().contains(entry.getKey()))
                    .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
        }

        public PlanNode getPlanNode()
        {
            return planNode;
        }

        public Map<VariableReferenceExpression, ConstantExpression> getConstantExpressionMap()
        {
            return constantExpressionMap;
        }
    }

    /**
     * Replaces variable references found in the given map with their constant values. Calls and special forms are
     * rebuilt with rewritten arguments; lambdas, input references and constants are returned unchanged.
     */
    private static class ExpressionRewriter
            implements RowExpressionVisitor<RowExpression, Void>
    {
        private final Map<VariableReferenceExpression, ConstantExpression> expressionMap;

        public ExpressionRewriter(Map<VariableReferenceExpression, ConstantExpression> expressionMap)
        {
            this.expressionMap = ImmutableMap.copyOf(expressionMap);
        }

        @Override
        public RowExpression visitCall(CallExpression call, Void context)
        {
            return new CallExpression(
                    call.getSourceLocation(),
                    call.getDisplayName(),
                    call.getFunctionHandle(),
                    call.getType(),
                    call.getArguments().stream().map(argument -> argument.accept(this, null)).collect(toImmutableList()));
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
        {
            return reference;
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, Void context)
        {
            return literal;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return lambda;
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            ConstantExpression constant = expressionMap.get(reference);
            if (constant != null) {
                return constant;
            }
            return reference;
        }

        @Override
        public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            // Preserve the source location, as visitCall does
            return new SpecialFormExpression(
                    specialForm.getSourceLocation(),
                    specialForm.getForm(),
                    specialForm.getType(),
                    specialForm.getArguments().stream().map(argument -> argument.accept(this, null)).collect(toImmutableList()));
        }
    }
}