IterativeOptimizer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CachingCostProvider;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Match;
import com.facebook.presto.matching.Matcher;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation;
import com.facebook.presto.spi.plan.LogicalPropertiesProvider;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.RuleStatsRecorder;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.rule.RowExpressionRewriteRuleSet;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import io.airlift.units.Duration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static com.facebook.presto.SystemSessionProperties.getOptimizersToEnableVerboseRuntimeStats;
import static com.facebook.presto.SystemSessionProperties.isVerboseOptimizerInfoEnabled;
import static com.facebook.presto.SystemSessionProperties.isVerboseRuntimeStatsEnabled;
import static com.facebook.presto.common.RuntimeUnit.NANO;
import static com.facebook.presto.spi.StandardErrorCode.OPTIMIZER_TIMEOUT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
public class IterativeOptimizer
implements PlanOptimizer
{
// Handed to every per-run Context, which passes it to PlannerUtils.getPlanString when plan text is captured.
private final Metadata metadata;
// Receives per-rule timing, outcome and failure records from transform(); all rules are registered with it at construction.
private final RuleStatsRecorder stats;
// Backs the CachingStatsProvider that optimize() creates for each run.
private final StatsCalculator statsCalculator;
// Backs the CachingCostProvider that optimize() creates for each run.
private final CostCalculator costCalculator;
// Optimizers run instead of the iterative rules when the new optimizer is disabled for the session (and this list is non-empty).
private final List<PlanOptimizer> legacyRules;
// Queried in exploreNode() for the candidate rules of a plan node.
private final RuleIndex ruleIndex;
// Given to the Memo only when the exploit-constraints session property is on; always exposed to rules through Rule.Context.
private final Optional<LogicalPropertiesProvider> logicalPropertiesProvider;
/**
 * Creates an optimizer with no legacy fallback optimizers and no logical properties provider.
 * Delegates to the full constructor with an empty legacy list and {@code Optional.empty()}.
 */
public IterativeOptimizer(
        Metadata metadata,
        RuleStatsRecorder stats,
        StatsCalculator statsCalculator,
        CostCalculator costCalculator,
        Set<Rule<?>> rules)
{
    this(metadata, stats, statsCalculator, costCalculator, ImmutableList.of(), Optional.empty(), rules);
}
/**
 * Creates an optimizer with no legacy fallback optimizers and the given logical properties provider.
 * Delegates to the full constructor with an empty legacy list.
 */
public IterativeOptimizer(
        Metadata metadata,
        RuleStatsRecorder stats,
        StatsCalculator statsCalculator,
        CostCalculator costCalculator,
        Optional<LogicalPropertiesProvider> logicalPropertiesProvider,
        Set<Rule<?>> rules)
{
    this(metadata, stats, statsCalculator, costCalculator, ImmutableList.of(), logicalPropertiesProvider, rules);
}
/**
 * Creates an optimizer with legacy fallback optimizers and no logical properties provider.
 * Delegates to the full constructor with {@code Optional.empty()}.
 */
public IterativeOptimizer(
        Metadata metadata,
        RuleStatsRecorder stats,
        StatsCalculator statsCalculator,
        CostCalculator costCalculator,
        List<PlanOptimizer> legacyRules,
        Set<Rule<?>> newRules)
{
    this(metadata, stats, statsCalculator, costCalculator, legacyRules, Optional.empty(), newRules);
}
/**
 * Creates an optimizer from the complete set of collaborators.
 *
 * @param metadata stored and handed to each per-run context
 * @param stats recorder that {@code newRules} are registered with and that receives per-rule results
 * @param statsCalculator used to build the stats provider for each run
 * @param costCalculator used to build the cost provider for each run
 * @param legacyRules optimizers to run instead of the iterative rules when the new optimizer is
 *        disabled for a session; copied into an immutable list
 * @param logicalPropertiesProvider optional provider, given to the memo only when the
 *        exploit-constraints session property is on
 * @param newRules rules indexed for iterative application
 * @throws NullPointerException if any argument is null
 */
public IterativeOptimizer(
        Metadata metadata,
        RuleStatsRecorder stats,
        StatsCalculator statsCalculator,
        CostCalculator costCalculator,
        List<PlanOptimizer> legacyRules,
        Optional<LogicalPropertiesProvider> logicalPropertiesProvider,
        Set<Rule<?>> newRules)
{
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.stats = requireNonNull(stats, "stats is null");
    this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
    this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
    // Checked explicitly so a null collection fails with a named message like every other argument.
    this.legacyRules = ImmutableList.copyOf(requireNonNull(legacyRules, "legacyRules is null"));
    requireNonNull(newRules, "newRules is null");
    this.ruleIndex = RuleIndex.builder()
            .register(newRules)
            .build();
    this.logicalPropertiesProvider = requireNonNull(logicalPropertiesProvider, "logicalPropertiesProvider is null");
    stats.registerAll(newRules);
}
/**
 * Optimizes {@code plan}.
 *
 * <p>If the new optimizer is disabled for the session and legacy optimizers were supplied, the
 * legacy optimizers are run in order, each on the previous one's output, and the result reports
 * whether any of them triggered. Otherwise the plan is loaded into a {@link Memo}, explored from
 * the root group, and optimizer information is collected; the extracted plan is returned when
 * something changed, and the input plan is returned untouched when nothing did.
 *
 * @throws PrestoException with {@code OPTIMIZER_TIMEOUT} when exploration exceeds the session's optimizer timeout
 */
@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
    // The iterative rules can only be switched off when there is something to fall back to.
    boolean useLegacyRules = !SystemSessionProperties.isNewOptimizerEnabled(session) && !legacyRules.isEmpty();
    if (useLegacyRules) {
        PlanNode current = plan;
        boolean anyTriggered = false;
        for (PlanOptimizer legacyOptimizer : legacyRules) {
            PlanOptimizerResult legacyResult = legacyOptimizer.optimize(
                    current,
                    session,
                    TypeProvider.viewOf(variableAllocator.getVariables()),
                    variableAllocator,
                    idAllocator,
                    warningCollector);
            current = legacyResult.getPlanNode();
            if (!anyTriggered) {
                anyTriggered = legacyResult.isOptimizerTriggered();
            }
        }
        return PlanOptimizerResult.optimizerResult(current, anyTriggered);
    }

    // Logical properties reach the memo only when constraint exploitation is enabled for the session.
    Optional<LogicalPropertiesProvider> memoProperties = SystemSessionProperties.isExploitConstraints(session)
            ? logicalPropertiesProvider
            : Optional.empty();
    Memo memo = new Memo(idAllocator, plan, memoProperties);

    Lookup lookup = Lookup.from(planNode -> Stream.of(memo.resolve(planNode)));
    Matcher matcher = new PlanNodeMatcher(lookup);
    Duration timeout = SystemSessionProperties.getOptimizerTimeout(session);

    StatsProvider statsProvider = new CachingStatsProvider(
            statsCalculator,
            Optional.of(memo),
            lookup,
            session,
            TypeProvider.viewOf(variableAllocator.getVariables()));
    CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(memo), session);

    Context context = new Context(
            memo,
            lookup,
            idAllocator,
            variableAllocator,
            System.nanoTime(),
            timeout.toMillis(),
            session,
            warningCollector,
            costProvider,
            statsProvider,
            metadata,
            types);

    boolean planChanged = exploreGroup(memo.getRootGroup(), context, matcher);
    context.collectOptimizerInformation();

    return planChanged
            ? PlanOptimizerResult.optimizerResult(memo.extract(), true)
            : PlanOptimizerResult.optimizerResult(plan, false);
}
/**
 * Explores {@code group} and, recursively, its child groups until neither level makes progress.
 *
 * @return true if this group's node or any descendant group changed during the call
 */
private boolean exploreGroup(int group, Context context, Matcher matcher)
{
    boolean changed = exploreNode(group, context, matcher);

    // After the children change, the current node gets another chance because new rules may now
    // match; once that retry rewrites nothing, the children are not revisited.
    boolean retryChildren = true;
    while (retryChildren && exploreChildren(group, context, matcher)) {
        changed = true;
        retryChildren = exploreNode(group, context, matcher);
    }
    return changed;
}
/**
 * Applies the candidate rules for the node of {@code group} over and over until one complete
 * pass over the candidates rewrites nothing.
 *
 * <p>The timeout is checked at the start of every pass. Disabled rules are never applied to the
 * memo; when verbose optimizer info is on they are only probed and recorded as applicable.
 *
 * @return true if at least one rule replaced the group's node
 */
private boolean exploreNode(int group, Context context, Matcher matcher)
{
    PlanNode node = context.memo.getNode(group);
    boolean progress = false;
    boolean rewrittenInPass;

    do {
        context.checkTimeoutNotExhausted();
        rewrittenInPass = false;

        // Candidates are looked up once per pass, for the node as it was when the pass began.
        Iterator<Rule<?>> candidates = ruleIndex.getCandidates(node).iterator();
        while (candidates.hasNext()) {
            Rule<?> rule = candidates.next();

            if (!rule.isEnabled(context.session)) {
                boolean reportApplicable = isVerboseOptimizerInfoEnabled(context.session) && isApplicable(node, rule, matcher, context);
                if (reportApplicable) {
                    context.addRulesApplicable(getNameOfOptimizerRule(rule));
                }
                continue;
            }

            Rule.Result result = transform(node, rule, matcher, context);
            if (!result.getTransformedPlan().isPresent()) {
                continue;
            }

            PlanNode transformedNode = keepStatsEquivalence(node, result.getTransformedPlan().get(), context);
            context.addRulesTriggered(getNameOfOptimizerRule(rule), node, transformedNode, rule.isCostBased(context.session), rule.getStatsSource());
            node = context.memo.replace(group, transformedNode, rule.getClass().getName());
            rewrittenInPass = true;
            progress = true;
        }
    }
    while (rewrittenInPass);

    return progress;
}

/**
 * Keeps the topmost node statistically equivalent across a rewrite: when {@code original} has a
 * stats-equivalent plan node and the resolved replacement has none, the original's is assigned
 * to the replacement (through the memo for a group reference, directly on the node otherwise).
 *
 * @return the node to store in the memo in place of {@code original}
 */
private PlanNode keepStatsEquivalence(PlanNode original, PlanNode transformed, Context context)
{
    PlanNode resolved = context.lookup.resolve(transformed);
    boolean needsAssignment = original.getStatsEquivalentPlanNode().isPresent() && !resolved.getStatsEquivalentPlanNode().isPresent();
    if (!needsAssignment) {
        return transformed;
    }
    if (transformed instanceof GroupReference) {
        context.memo.assignStatsEquivalentPlanNode((GroupReference) transformed, original.getStatsEquivalentPlanNode());
        return transformed;
    }
    return transformed.assignStatsEquivalentPlanNode(original.getStatsEquivalentPlanNode());
}
/**
 * Returns the name under which {@code rule} is reported: the rule's own optimizer log name for a
 * {@code RowExpressionRewriteRule}, and the simple class name for every other rule.
 */
private String getNameOfOptimizerRule(Rule<?> rule)
{
    if (rule instanceof RowExpressionRewriteRuleSet.RowExpressionRewriteRule) {
        return ((RowExpressionRewriteRuleSet.RowExpressionRewriteRule) rule).getOptimizerNameForLog();
    }
    return rule.getClass().getSimpleName();
}
/**
 * Matches {@code rule}'s pattern against {@code node} and, on a match, applies the rule.
 *
 * <p>A non-match returns {@code Rule.Result.empty()} without recording anything. A successful
 * application records its duration and whether the result was non-empty; a
 * {@link RuntimeException} from the rule is recorded as a failure and rethrown. When verbose
 * runtime stats are enabled for the session or for this rule, the duration is also added to the
 * session's runtime stats.
 *
 * @return the result produced by the rule, or an empty result when the pattern did not match
 */
private <T> Rule.Result transform(PlanNode node, Rule<T> rule, Matcher matcher, Context context)
{
    Match<T> match = matcher.match(rule.getPattern(), node);
    if (match.isEmpty()) {
        return Rule.Result.empty();
    }

    Rule.Result outcome;
    long elapsedNanos;
    try {
        long startNanos = System.nanoTime();
        outcome = rule.apply(match.value(), match.captures(), ruleContext(context));
        elapsedNanos = System.nanoTime() - startNanos;
    }
    catch (RuntimeException e) {
        stats.recordFailure(rule);
        throw e;
    }
    stats.record(rule, elapsedNanos, !outcome.isEmpty());

    boolean publishRuntime = isVerboseRuntimeStatsEnabled(context.session) || trackOptimizerRuntime(context.session, rule);
    if (publishRuntime) {
        String metricName = format("rule%sTimeNanos", getNameOfOptimizerRule(rule));
        context.session.getRuntimeStats().addMetricValue(metricName, NANO, elapsedNanos);
    }
    return outcome;
}
/**
 * Tells whether runtime tracking was requested for {@code rule} specifically.
 *
 * <p>The session property is read as a comma-separated list of names with surrounding whitespace
 * trimmed; the rule is tracked when its reported name appears in that list.
 *
 * @param session session whose verbose-runtime-stats optimizer list is consulted
 * @param rule rule to look up by its reported name
 * @return false when the list is empty or does not contain the rule's name
 */
private boolean trackOptimizerRuntime(Session session, Rule<?> rule)
{
    String optimizerString = getOptimizersToEnableVerboseRuntimeStats(session);
    if (optimizerString.isEmpty()) {
        return false;
    }
    List<String> optimizers = Splitter.on(",").trimResults().splitToList(optimizerString);
    return optimizers.contains(getNameOfOptimizerRule(rule));
}
/**
 * Probes whether {@code rule} would produce a result for {@code node}: the pattern must match
 * and applying the rule must yield a non-empty result. The result itself is discarded.
 */
private <T> boolean isApplicable(PlanNode node, Rule<T> rule, Matcher matcher, Context context)
{
    Match<T> match = matcher.match(rule.getPattern(), node);
    // The rule is applied only after a successful match.
    return !match.isEmpty()
            && !rule.apply(match.value(), match.captures(), ruleContext(context)).isEmpty();
}
/**
 * Explores every source group of the node currently stored for {@code group}.
 *
 * @return true if exploring any child group made progress
 * @throws IllegalStateException if a source of the node is not a {@link GroupReference}
 */
private boolean exploreChildren(int group, Context context, Matcher matcher)
{
    PlanNode parent = context.memo.getNode(group);
    boolean anyChildChanged = false;
    for (PlanNode source : parent.getSources()) {
        checkState(source instanceof GroupReference, "Expected child to be a group reference. Found: " + source.getClass().getName());
        int childGroup = ((GroupReference) source).getGroupId();
        // Non-short-circuit on purpose: every child is explored even after one has changed.
        anyChildChanged |= exploreGroup(childGroup, context, matcher);
    }
    return anyChildChanged;
}
/**
 * Builds the {@link Rule.Context} view handed to a rule. Every accessor forwards to the per-run
 * {@code context}, except the logical properties provider, which comes from this optimizer.
 * A new adapter object is created on each call.
 */
private Rule.Context ruleContext(Context context)
{
    return new Rule.Context()
    {
        // ---- session-level collaborators ----

        @Override
        public Session getSession()
        {
            return context.session;
        }

        @Override
        public WarningCollector getWarningCollector()
        {
            return context.warningCollector;
        }

        // ---- plan access and allocation ----

        @Override
        public Lookup getLookup()
        {
            return context.lookup;
        }

        @Override
        public PlanNodeIdAllocator getIdAllocator()
        {
            return context.idAllocator;
        }

        @Override
        public VariableAllocator getVariableAllocator()
        {
            return context.variableAllocator;
        }

        // ---- statistics, cost and logical properties ----

        @Override
        public StatsProvider getStatsProvider()
        {
            return context.statsProvider;
        }

        @Override
        public CostProvider getCostProvider()
        {
            return context.costProvider;
        }

        @Override
        public Optional<LogicalPropertiesProvider> getLogicalPropertiesProvider()
        {
            return logicalPropertiesProvider;
        }

        // ---- time budget ----

        @Override
        public void checkTimeoutNotExhausted()
        {
            context.checkTimeoutNotExhausted();
        }
    };
}
/**
 * Immutable record of a single rule application: the rule's reported name, optional plan text
 * from before and after the rewrite, whether the rule was cost based, and an optional stats source.
 *
 * <p>{@code equals}/{@code hashCode} are not overridden, so two instances are never equal to each
 * other; every recorded application stays a separate entry when stored in a set.
 */
private static class RuleTriggered
{
    private final String rule;
    private final Optional<String> oldNode;
    private final Optional<String> newNode;
    private final boolean isCostBased;
    private final Optional<String> statsSource;

    /**
     * @param rule reported rule name; must not be null
     * @param oldNode plan text before the rewrite, if captured; must not be null
     * @param newNode plan text after the rewrite, if captured; must not be null
     * @param isCostBased whether the rule was cost based
     * @param statsSource stats source, or null when there is none
     * @throws NullPointerException if {@code rule}, {@code oldNode} or {@code newNode} is null
     */
    public RuleTriggered(String rule, Optional<String> oldNode, Optional<String> newNode, boolean isCostBased, String statsSource)
    {
        this.rule = requireNonNull(rule, "rule is null");
        this.oldNode = requireNonNull(oldNode, "oldNode is null");
        this.newNode = requireNonNull(newNode, "newNode is null");
        this.isCostBased = isCostBased;
        // A null stats source is a legal input and is stored as an empty Optional.
        this.statsSource = Optional.ofNullable(statsSource);
    }

    public String getRule()
    {
        return rule;
    }

    public Optional<String> getOldNode()
    {
        return oldNode;
    }

    public Optional<String> getNewNode()
    {
        return newNode;
    }

    public boolean isCostBased()
    {
        return isCostBased;
    }

    public Optional<String> getStatsSource()
    {
        return statsSource;
    }
}
/**
 * State for one optimization run: the memo and lookup, allocators, session-level collaborators,
 * the time budget, and the rules that were triggered or found applicable along the way.
 */
private static class Context
{
    private final Memo memo;
    private final Lookup lookup;
    private final PlanNodeIdAllocator idAllocator;
    private final VariableAllocator variableAllocator;
    private final long startTimeInNanos;
    private final long timeoutInMilliseconds;
    private final Session session;
    private final WarningCollector warningCollector;
    private final CostProvider costProvider;
    private final StatsProvider statsProvider;
    private final Set<RuleTriggered> rulesTriggered;
    private final Set<String> rulesApplicable;
    private final Metadata metadata;
    private final TypeProvider types;

    /**
     * @param startTimeInNanos {@code System.nanoTime()} reading from which elapsed time is measured
     * @param timeoutInMilliseconds time budget for the run; must not be negative
     * @throws IllegalArgumentException if {@code timeoutInMilliseconds} is negative
     */
    public Context(
            Memo memo,
            Lookup lookup,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator,
            long startTimeInNanos,
            long timeoutInMilliseconds,
            Session session,
            WarningCollector warningCollector,
            CostProvider costProvider,
            StatsProvider statsProvider,
            Metadata metadata,
            TypeProvider types)
    {
        checkArgument(timeoutInMilliseconds >= 0, "Timeout has to be a non-negative number [milliseconds]");

        // plan state
        this.memo = memo;
        this.lookup = lookup;
        this.idAllocator = idAllocator;
        this.variableAllocator = variableAllocator;

        // time budget
        this.startTimeInNanos = startTimeInNanos;
        this.timeoutInMilliseconds = timeoutInMilliseconds;

        // session-level collaborators
        this.session = session;
        this.warningCollector = warningCollector;
        this.costProvider = costProvider;
        this.statsProvider = statsProvider;
        this.metadata = metadata;
        this.types = types;

        // bookkeeping filled in while rules run
        this.rulesTriggered = new HashSet<>();
        this.rulesApplicable = new HashSet<>();
    }

    /**
     * Fails the run once the elapsed time, in whole milliseconds, reaches the configured timeout.
     *
     * @throws PrestoException with {@code OPTIMIZER_TIMEOUT} when the time limit is used up
     */
    public void checkTimeoutNotExhausted()
    {
        long elapsedMillis = NANOSECONDS.toMillis(System.nanoTime() - startTimeInNanos);
        if (elapsedMillis < timeoutInMilliseconds) {
            return;
        }
        throw new PrestoException(OPTIMIZER_TIMEOUT, format("The optimizer exhausted the time limit of %d ms", timeoutInMilliseconds));
    }

    /**
     * Records that {@code rule} rewrote {@code oldNode} into {@code newNode}. Plan text for both
     * nodes is captured only when verbose optimizer results are enabled for that rule.
     */
    public void addRulesTriggered(String rule, PlanNode oldNode, PlanNode newNode, boolean isCostBased, String statsSource)
    {
        boolean capturePlans = SystemSessionProperties.isVerboseOptimizerResults(session, rule);
        Optional<String> before = capturePlans ? Optional.of(planString(oldNode)) : Optional.empty();
        Optional<String> after = capturePlans ? Optional.of(planString(newNode)) : Optional.empty();
        rulesTriggered.add(new RuleTriggered(rule, before, after, isCostBased, statsSource));
    }

    /** Renders {@code node} as plan text using this run's session, types and metadata. */
    private String planString(PlanNode node)
    {
        return PlannerUtils.getPlanString(node, session, types, metadata, false);
    }

    /** Records that a disabled rule would have produced a result. */
    public void addRulesApplicable(String rule)
    {
        rulesApplicable.add(rule);
    }

    /**
     * Publishes what was gathered during the run to the session's collectors: one information
     * entry per distinct triggered rule, the before/after plan text of each triggered rule that
     * captured it (only when verbose optimizer results are enabled), and one information entry
     * per applicable rule.
     */
    public void collectOptimizerInformation()
    {
        // Several applications of one rule can produce equal entries; each distinct entry is reported once.
        Set<PlanOptimizerInformation> reported = new HashSet<>();
        for (RuleTriggered triggered : rulesTriggered) {
            PlanOptimizerInformation information = new PlanOptimizerInformation(
                    triggered.getRule(),
                    true,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.of(triggered.isCostBased()),
                    triggered.getStatsSource());
            if (reported.add(information)) {
                session.getOptimizerInformationCollector().addInformation(information);
            }
        }

        if (SystemSessionProperties.isVerboseOptimizerResults(session)) {
            for (RuleTriggered triggered : rulesTriggered) {
                if (triggered.getNewNode().isPresent()) {
                    session.getOptimizerResultCollector().addOptimizerResult(triggered.getRule(), triggered.getOldNode().get(), triggered.getNewNode().get());
                }
            }
        }

        for (String applicableRule : rulesApplicable) {
            session.getOptimizerInformationCollector().addInformation(
                    new PlanOptimizerInformation(applicableRule, false, Optional.of(true), Optional.empty(), Optional.empty(), Optional.empty()));
        }
    }
}
}