RowExpressionRewriteRuleSet.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.StatisticAggregations;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.applyNode;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.spatialJoin;
import static com.facebook.presto.sql.planner.plan.Patterns.tableFinish;
import static com.facebook.presto.sql.planner.plan.Patterns.tableWriterNode;
import static com.facebook.presto.sql.planner.plan.Patterns.values;
import static com.facebook.presto.sql.planner.plan.Patterns.window;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.builder;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * A collection of iterative-optimizer rules that apply a single {@link PlanRowExpressionRewriter}
 * to the {@link RowExpression}s held by several plan node types:
 * <ul>
 *   <li>values rows;</li>
 *   <li>filter predicates;</li>
 *   <li>project and apply assignments;</li>
 *   <li>window function arguments;</li>
 *   <li>join and spatial join filters;</li>
 *   <li>aggregation calls, including the statistics aggregations of table writer and table finish nodes.</li>
 * </ul>
 * <p>
 * Each rule returns {@code Result.empty()} when the matched node has nothing to rewrite or the
 * rewriter leaves every expression equal (by {@code equals}) to the original. Otherwise it
 * returns a replacement node carrying the rewritten expressions.
 */
public class RowExpressionRewriteRuleSet
{
    /**
     * Rewrites one row expression in the context of the rule invocation that encountered it.
     */
    public interface PlanRowExpressionRewriter
    {
        RowExpression rewrite(RowExpression expression, Rule.Context context);
    }

    protected final PlanRowExpressionRewriter rewriter;

    /**
     * @param rewriter the rewriter applied by every rule of this set; must not be null
     * @throws NullPointerException if {@code rewriter} is null
     */
    public RowExpressionRewriteRuleSet(PlanRowExpressionRewriter rewriter)
    {
        this.rewriter = requireNonNull(rewriter, "rewriter is null");
    }

    /**
     * Returns whether the rules of this set are enabled for {@code session}. Every rule defined in
     * this class delegates its {@code isEnabled} check to this method. This implementation always
     * returns {@code true}. The method is not final, so a subclass can override it.
     */
    public boolean isRewriterEnabled(Session session)
    {
        return true;
    }

    /**
     * Returns every rule produced by the factory methods of this class.
     */
    public Set<Rule<?>> rules()
    {
        return ImmutableSet.of(
                valueRowExpressionRewriteRule(),
                filterRowExpressionRewriteRule(),
                projectRowExpressionRewriteRule(),
                applyNodeRowExpressionRewriteRule(),
                windowRowExpressionRewriteRule(),
                joinRowExpressionRewriteRule(),
                spatialJoinRowExpressionRewriteRule(),
                aggregationRowExpressionRewriteRule(),
                tableFinishRowExpressionRewriteRule(),
                tableWriterRowExpressionRewriteRule());
    }

    /** Rule rewriting every expression in the rows of a {@link ValuesNode}. */
    public Rule<ValuesNode> valueRowExpressionRewriteRule()
    {
        return new ValuesRowExpressionRewrite();
    }

    /** Rule rewriting the predicate of a {@link FilterNode}. */
    public Rule<FilterNode> filterRowExpressionRewriteRule()
    {
        return new FilterRowExpressionRewrite();
    }

    /** Rule rewriting the assignment expressions of a {@link ProjectNode}. */
    public Rule<ProjectNode> projectRowExpressionRewriteRule()
    {
        return new ProjectRowExpressionRewrite();
    }

    /** Rule rewriting the subquery assignments of an {@link ApplyNode}. */
    public Rule<ApplyNode> applyNodeRowExpressionRewriteRule()
    {
        return new ApplyRowExpressionRewrite();
    }

    /** Rule rewriting the arguments of each window function call of a {@link WindowNode}. */
    public Rule<WindowNode> windowRowExpressionRewriteRule()
    {
        return new WindowRowExpressionRewrite();
    }

    /** Rule rewriting the optional filter of a {@link JoinNode}. */
    public Rule<JoinNode> joinRowExpressionRewriteRule()
    {
        return new JoinRowExpressionRewrite();
    }

    /** Rule rewriting the filter of a {@link SpatialJoinNode}. */
    public Rule<SpatialJoinNode> spatialJoinRowExpressionRewriteRule()
    {
        return new SpatialJoinRowExpressionRewrite();
    }

    /** Rule rewriting the statistics aggregations (if any) of a {@link TableFinishNode}. */
    public Rule<TableFinishNode> tableFinishRowExpressionRewriteRule()
    {
        return new TableFinishRowExpressionRewrite();
    }

    /** Rule rewriting the statistics aggregations (if any) of a {@link TableWriterNode}. */
    public Rule<TableWriterNode> tableWriterRowExpressionRewriteRule()
    {
        return new TableWriterRowExpressionRewrite();
    }

    /** Rule rewriting the calls and filters of each aggregation of an {@link AggregationNode}. */
    public Rule<AggregationNode> aggregationRowExpressionRewriteRule()
    {
        return new AggregationRowExpressionRewrite();
    }

    /**
     * Common base for the rules of this set; supplies a log-friendly name that combines the
     * rewriter's class name and the concrete rule's class name.
     */
    public abstract class RowExpressionRewriteRule<T>
            implements Rule<T>
    {
        /**
         * Returns {@code "<rewriter>:<rule>"}, where {@code <rewriter>} is the rewriter's class
         * name with everything up to and including the last '.' removed, and {@code <rule>} is
         * this rule's simple class name.
         */
        public String getOptimizerNameForLog()
        {
            String rewriterName = rewriter.getClass().getName();
            return format("%s:%s", rewriterName.substring(rewriterName.lastIndexOf('.') + 1), this.getClass().getSimpleName());
        }
    }

    private final class ProjectRowExpressionRewrite
            extends RowExpressionRewriteRule<ProjectNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<ProjectNode> getPattern()
        {
            return project();
        }

        @Override
        public Result apply(ProjectNode projectNode, Captures captures, Context context)
        {
            Assignments.Builder builder = Assignments.builder();
            boolean anyRewritten = false;
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : projectNode.getAssignments().getMap().entrySet()) {
                RowExpression rewritten = rewriter.rewrite(entry.getValue(), context);
                if (!rewritten.equals(entry.getValue())) {
                    anyRewritten = true;
                }
                builder.put(entry.getKey(), rewritten);
            }
            Assignments assignments = builder.build();
            if (anyRewritten) {
                return Result.ofPlanNode(new ProjectNode(projectNode.getSourceLocation(), projectNode.getId(), projectNode.getSource(), assignments, projectNode.getLocality()));
            }
            return Result.empty();
        }
    }

    private final class SpatialJoinRowExpressionRewrite
            extends RowExpressionRewriteRule<SpatialJoinNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<SpatialJoinNode> getPattern()
        {
            return spatialJoin();
        }

        @Override
        public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context context)
        {
            RowExpression filter = spatialJoinNode.getFilter();
            RowExpression rewritten = rewriter.rewrite(filter, context);
            if (filter.equals(rewritten)) {
                return Result.empty();
            }
            return Result.ofPlanNode(new SpatialJoinNode(
                    spatialJoinNode.getSourceLocation(),
                    spatialJoinNode.getId(),
                    spatialJoinNode.getType(),
                    spatialJoinNode.getLeft(),
                    spatialJoinNode.getRight(),
                    spatialJoinNode.getOutputVariables(),
                    rewritten,
                    spatialJoinNode.getLeftPartitionVariable(),
                    spatialJoinNode.getRightPartitionVariable(),
                    spatialJoinNode.getKdbTree()));
        }
    }

    private final class JoinRowExpressionRewrite
            extends RowExpressionRewriteRule<JoinNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<JoinNode> getPattern()
        {
            return join();
        }

        @Override
        public Result apply(JoinNode joinNode, Captures captures, Context context)
        {
            // Only the optional filter is rewritten; a join without one is left alone.
            if (!joinNode.getFilter().isPresent()) {
                return Result.empty();
            }
            RowExpression filter = joinNode.getFilter().get();
            RowExpression rewritten = rewriter.rewrite(filter, context);
            if (filter.equals(rewritten)) {
                return Result.empty();
            }
            return Result.ofPlanNode(new JoinNode(
                    joinNode.getSourceLocation(),
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    joinNode.getRight(),
                    joinNode.getCriteria(),
                    joinNode.getOutputVariables(),
                    Optional.of(rewritten),
                    joinNode.getLeftHashVariable(),
                    joinNode.getRightHashVariable(),
                    joinNode.getDistributionType(),
                    joinNode.getDynamicFilters()));
        }
    }

    private final class WindowRowExpressionRewrite
            extends RowExpressionRewriteRule<WindowNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<WindowNode> getPattern()
        {
            return window();
        }

        @Override
        public Result apply(WindowNode windowNode, Captures captures, Context context)
        {
            checkState(windowNode.getSource() != null);
            boolean anyRewritten = false;
            ImmutableMap.Builder<VariableReferenceExpression, WindowNode.Function> functions = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, WindowNode.Function> entry : windowNode.getWindowFunctions().entrySet()) {
                ImmutableList.Builder<RowExpression> newArguments = ImmutableList.builder();
                CallExpression callExpression = entry.getValue().getFunctionCall();
                for (RowExpression argument : callExpression.getArguments()) {
                    RowExpression rewritten = rewriter.rewrite(argument, context);
                    // Compare by value, as every other rule here does. A rewriter that returns an
                    // equal but distinct instance must not be reported as a change; otherwise this
                    // rule would claim progress on every invocation.
                    if (!rewritten.equals(argument)) {
                        anyRewritten = true;
                    }
                    newArguments.add(rewritten);
                }
                functions.put(
                        entry.getKey(),
                        new WindowNode.Function(
                                call(
                                        callExpression.getDisplayName(),
                                        callExpression.getFunctionHandle(),
                                        callExpression.getType(),
                                        newArguments.build()),
                                entry.getValue().getFrame(),
                                entry.getValue().isIgnoreNulls()));
            }
            if (anyRewritten) {
                return Result.ofPlanNode(new WindowNode(
                        windowNode.getSourceLocation(),
                        windowNode.getId(),
                        windowNode.getSource(),
                        windowNode.getSpecification(),
                        functions.build(),
                        windowNode.getHashVariable(),
                        windowNode.getPrePartitionedInputs(),
                        windowNode.getPreSortedOrderPrefix()));
            }
            return Result.empty();
        }
    }

    private final class ApplyRowExpressionRewrite
            extends RowExpressionRewriteRule<ApplyNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<ApplyNode> getPattern()
        {
            return applyNode();
        }

        @Override
        public Result apply(ApplyNode applyNode, Captures captures, Context context)
        {
            Assignments assignments = applyNode.getSubqueryAssignments();
            Optional<Assignments> rewrittenAssignments = translateAssignments(assignments, context);
            if (!rewrittenAssignments.isPresent()) {
                return Result.empty();
            }
            return Result.ofPlanNode(new ApplyNode(
                    applyNode.getSourceLocation(),
                    applyNode.getId(),
                    applyNode.getInput(),
                    applyNode.getSubquery(),
                    rewrittenAssignments.get(),
                    applyNode.getCorrelation(),
                    applyNode.getOriginSubqueryError(),
                    applyNode.getMayParticipateInAntiJoin()));
        }
    }

    /**
     * Rewrites every expression of {@code assignments}, keeping the same keys.
     *
     * @return the rewritten assignments, or {@code Optional.empty()} if they are equal to the input
     */
    private Optional<Assignments> translateAssignments(Assignments assignments, Rule.Context context)
    {
        Assignments.Builder builder = Assignments.builder();
        for (Map.Entry<VariableReferenceExpression, RowExpression> entry : assignments.getMap().entrySet()) {
            builder.put(entry.getKey(), rewriter.rewrite(entry.getValue(), context));
        }
        Assignments rewritten = builder.build();
        if (rewritten.equals(assignments)) {
            return Optional.empty();
        }
        return Optional.of(rewritten);
    }

    private final class FilterRowExpressionRewrite
            extends RowExpressionRewriteRule<FilterNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<FilterNode> getPattern()
        {
            return filter();
        }

        @Override
        public Result apply(FilterNode filterNode, Captures captures, Context context)
        {
            checkState(filterNode.getSource() != null);
            RowExpression rewritten = rewriter.rewrite(filterNode.getPredicate(), context);
            if (filterNode.getPredicate().equals(rewritten)) {
                return Result.empty();
            }
            return Result.ofPlanNode(new FilterNode(filterNode.getSourceLocation(), filterNode.getId(), filterNode.getSource(), rewritten));
        }
    }

    private final class ValuesRowExpressionRewrite
            extends RowExpressionRewriteRule<ValuesNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<ValuesNode> getPattern()
        {
            return values();
        }

        @Override
        public Result apply(ValuesNode valuesNode, Captures captures, Context context)
        {
            boolean anyRewritten = false;
            ImmutableList.Builder<List<RowExpression>> rows = ImmutableList.builder();
            for (List<RowExpression> row : valuesNode.getRows()) {
                ImmutableList.Builder<RowExpression> newRow = ImmutableList.builder();
                for (RowExpression rowExpression : row) {
                    RowExpression rewritten = rewriter.rewrite(rowExpression, context);
                    if (!rewritten.equals(rowExpression)) {
                        anyRewritten = true;
                    }
                    newRow.add(rewritten);
                }
                rows.add(newRow.build());
            }
            if (anyRewritten) {
                return Result.ofPlanNode(new ValuesNode(valuesNode.getSourceLocation(), valuesNode.getId(), valuesNode.getOutputVariables(), rows.build(), valuesNode.getValuesNodeLabel()));
            }
            return Result.empty();
        }
    }

    private final class AggregationRowExpressionRewrite
            extends RowExpressionRewriteRule<AggregationNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<AggregationNode> getPattern()
        {
            return aggregation();
        }

        @Override
        public Result apply(AggregationNode node, Captures captures, Context context)
        {
            checkState(node.getSource() != null);
            boolean changed = false;
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> rewrittenAggregation = builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation rewritten = rewriteAggregation(entry.getValue(), context);
                rewrittenAggregation.put(entry.getKey(), rewritten);
                if (!rewritten.equals(entry.getValue())) {
                    changed = true;
                }
            }
            if (changed) {
                AggregationNode aggregationNode = new AggregationNode(
                        node.getSourceLocation(),
                        node.getId(),
                        node.getSource(),
                        rewrittenAggregation.build(),
                        node.getGroupingSets(),
                        node.getPreGroupedVariables(),
                        node.getStep(),
                        node.getHashVariable(),
                        node.getGroupIdVariable(),
                        node.getAggregationId());
                return Result.ofPlanNode(aggregationNode);
            }
            return Result.empty();
        }
    }

    private final class TableFinishRowExpressionRewrite
            extends RowExpressionRewriteRule<TableFinishNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<TableFinishNode> getPattern()
        {
            return tableFinish();
        }

        @Override
        public Result apply(TableFinishNode node, Captures captures, Context context)
        {
            checkState(node.getSource() != null);
            if (!node.getStatisticsAggregation().isPresent()) {
                return Result.empty();
            }
            Optional<StatisticAggregations> rewrittenStatisticsAggregation = translateStatisticAggregation(node.getStatisticsAggregation().get(), context);
            if (rewrittenStatisticsAggregation.isPresent()) {
                return Result.ofPlanNode(new TableFinishNode(
                        node.getSourceLocation(),
                        node.getId(),
                        node.getSource(),
                        node.getTarget(),
                        node.getRowCountVariable(),
                        rewrittenStatisticsAggregation,
                        node.getStatisticsAggregationDescriptor(),
                        node.getCteMaterializationInfo()));
            }
            return Result.empty();
        }
    }

    /**
     * Rewrites every aggregation of {@code statisticAggregations}, keeping its grouping variables.
     *
     * @return the rewritten statistics aggregations, or {@code Optional.empty()} if no aggregation changed
     */
    private Optional<StatisticAggregations> translateStatisticAggregation(StatisticAggregations statisticAggregations, Rule.Context context)
    {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> rewrittenAggregation = builder();
        boolean changed = false;
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : statisticAggregations.getAggregations().entrySet()) {
            AggregationNode.Aggregation rewritten = rewriteAggregation(entry.getValue(), context);
            rewrittenAggregation.put(entry.getKey(), rewritten);
            if (!rewritten.equals(entry.getValue())) {
                changed = true;
            }
        }
        if (changed) {
            return Optional.of(new StatisticAggregations(rewrittenAggregation.build(), statisticAggregations.getGroupingVariables()));
        }
        return Optional.empty();
    }

    private final class TableWriterRowExpressionRewrite
            extends RowExpressionRewriteRule<TableWriterNode>
    {
        @Override
        public boolean isEnabled(Session session)
        {
            return isRewriterEnabled(session);
        }

        @Override
        public Pattern<TableWriterNode> getPattern()
        {
            return tableWriterNode();
        }

        @Override
        public Result apply(TableWriterNode node, Captures captures, Context context)
        {
            checkState(node.getSource() != null);
            if (!node.getStatisticsAggregation().isPresent()) {
                return Result.empty();
            }
            Optional<StatisticAggregations> rewrittenStatisticsAggregation = translateStatisticAggregation(node.getStatisticsAggregation().get(), context);
            if (rewrittenStatisticsAggregation.isPresent()) {
                return Result.ofPlanNode(new TableWriterNode(
                        node.getSourceLocation(),
                        node.getId(),
                        node.getStatsEquivalentPlanNode(),
                        node.getSource(),
                        node.getTarget(),
                        node.getRowCountVariable(),
                        node.getFragmentVariable(),
                        node.getTableCommitContextVariable(),
                        node.getColumns(),
                        node.getColumnNames(),
                        node.getNotNullColumnVariables(),
                        node.getTablePartitioningScheme(),
                        rewrittenStatisticsAggregation,
                        node.getTaskCountIfScaledWriter(),
                        node.getIsTemporaryTableWriter()));
            }
            return Result.empty();
        }
    }

    /**
     * Rewrites the call and the optional filter of {@code aggregation}. The order-by, distinct
     * flag and mask are carried over unchanged.
     *
     * @throws IllegalArgumentException if the rewriter turns the call into something other than a
     *         {@link CallExpression}
     */
    private AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Rule.Context context)
    {
        RowExpression rewrittenCall = rewriter.rewrite(aggregation.getCall(), context);
        checkArgument(rewrittenCall instanceof CallExpression, "Aggregation CallExpression must be rewritten to CallExpression");
        return new AggregationNode.Aggregation(
                (CallExpression) rewrittenCall,
                aggregation.getFilter().map(filter -> rewriter.rewrite(filter, context)),
                aggregation.getOrderBy(),
                aggregation.isDistinct(),
                aggregation.getMask());
    }
}