OptimizerAssert.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;
import com.facebook.presto.Session;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.metadata.InMemoryNodeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.nodeManager.PluginNodeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.sql.Optimizer;
import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.RuleStatsRecorder;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.IterativeOptimizer;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyRowExpressions;
import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin;
import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert.TestingStatsCalculator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs;
import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan;
import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.fail;
/**
 * Fluent test helper that runs a single {@link PlanOptimizer} over an input plan and checks the optimized result.
 * <p>
 * Typical use: optionally adjust the session ({@link #setSystemProperty} / {@link #withSession}), supply the input
 * plan exactly once with one of the {@code on(...)} overloads, then verify the optimized plan with
 * {@link #matches}, {@link #doesNotMatch} or {@link #validates}.
 * <p>
 * Instances hold mutable, unsynchronized state (session, input plan and its types).
 */
public class OptimizerAssert
{
    private final Metadata metadata;
    private final TestingStatsCalculator statsCalculator;
    private final PlanOptimizer optimizer;
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
    private final TransactionManager transactionManager;
    private final AccessControl accessControl;
    private final LocalQueryRunner queryRunner;

    // Session may be replaced until an assertion runs; plan and types are set exactly once by on(...).
    private Session session;
    private TypeProvider types;
    private PlanNode plan;

    /**
     * @param metadata metadata used for building/matching plans and registering catalogs in transactions
     * @param queryRunner runner used to plan SQL input and to supply the minimal preparatory optimizers
     * @param statsCalculator stats calculator, wrapped in a {@link TestingStatsCalculator} for plan matching
     * @param session initial session for planning, optimization and matching
     * @param optimizer the optimizer under test
     * @param transactionManager transaction manager used to open transactions for matching
     * @param accessControl access control used to open transactions for matching
     * @throws NullPointerException if any argument is null
     */
    public OptimizerAssert(Metadata metadata, LocalQueryRunner queryRunner, StatsCalculator statsCalculator, Session session, PlanOptimizer optimizer, TransactionManager transactionManager, AccessControl accessControl)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.statsCalculator = new TestingStatsCalculator(requireNonNull(statsCalculator, "statsCalculator is null"));
        this.session = requireNonNull(session, "session is null");
        this.optimizer = requireNonNull(optimizer, "optimizer is null");
        this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
        this.accessControl = requireNonNull(accessControl, "access control is null");
        this.queryRunner = requireNonNull(queryRunner, "queryRunner is null");
    }

    /**
     * Replaces the current session with a copy that has the given system property set.
     *
     * @return this assert, for chaining
     */
    public OptimizerAssert setSystemProperty(String key, String value)
    {
        return withSession(Session.builder(session)
                .setSystemProperty(key, value)
                .build());
    }

    /**
     * Replaces the session used by subsequent steps (building the input plan, optimizing and matching).
     * An input plan that was already set by {@code on(...)} is not rebuilt.
     *
     * @return this assert, for chaining
     * @throws NullPointerException if {@code session} is null
     */
    public OptimizerAssert withSession(Session session)
    {
        this.session = requireNonNull(session, "session is null");
        return this;
    }

    /**
     * Sets the input plan by building it with a {@link PlanBuilder} bound to the current session.
     *
     * @return this assert, for chaining
     * @throws IllegalStateException if the input plan has already been set
     */
    public OptimizerAssert on(Function<PlanBuilder, PlanNode> planProvider)
    {
        checkState(plan == null, "plan has already been set");
        PlanBuilder builder = new PlanBuilder(session, idAllocator, metadata);
        plan = planProvider.apply(builder);
        types = builder.getTypes();
        return this;
    }

    /**
     * Sets the input plan by planning {@code sql} and applying a minimal set of preparatory optimizers
     * (see {@link #getMinimalOptimizers()}) before the optimizer under test runs.
     *
     * @return this assert, for chaining
     * @throws IllegalStateException if the input plan has already been set
     */
    public OptimizerAssert on(String sql)
    {
        checkState(plan == null, "plan has already been set");
        // Plan with this assert's session so that properties configured via setSystemProperty/withSession
        // also apply to the initial planning, consistent with on(Function), which builds with `session`.
        Plan result = queryRunner.inTransaction(session, transactionSession -> queryRunner.createPlan(
                transactionSession,
                sql,
                getMinimalOptimizers(),
                Optimizer.PlanStage.OPTIMIZED,
                WarningCollector.NOOP));
        plan = result.getRoot();
        types = result.getTypes();
        return this;
    }

    /**
     * Optimizes the input plan inside a transaction and asserts that the result matches {@code pattern}.
     *
     * @throws IllegalStateException if no input plan has been set
     */
    public void matches(PlanMatchPattern pattern)
    {
        inTransaction(transactionSession -> {
            assertPlan(transactionSession, metadata, statsCalculator, applyRules(transactionSession), pattern);
            return null;
        });
    }

    /**
     * Optimizes the input plan inside a transaction and asserts that the result does NOT match {@code pattern}.
     *
     * @throws IllegalStateException if no input plan has been set
     */
    public void doesNotMatch(PlanMatchPattern pattern)
    {
        inTransaction(transactionSession -> {
            assertPlanDoesNotMatch(transactionSession, metadata, statsCalculator, applyRules(transactionSession), pattern);
            return null;
        });
    }

    /**
     * Optimizes the input plan (outside of any transaction) and passes the result to {@code planValidator}.
     *
     * @throws IllegalStateException if no input plan has been set
     */
    public void validates(Consumer<Plan> planValidator)
    {
        planValidator.accept(applyRules(session));
    }

    /**
     * Runs the optimizer under test on the input plan and fails the test if the optimized plan's
     * output variables differ (as a set) from the original plan's.
     *
     * @param optimizerSession session handed to the optimizer; the transaction session when called from
     * {@link #matches}/{@link #doesNotMatch}
     */
    private Plan applyRules(Session optimizerSession)
    {
        checkState(plan != null, "plan has not been set; call on(...) first");
        PlanNode actual = optimizer.optimize(plan, optimizerSession, types, new VariableAllocator(), idAllocator, WarningCollector.NOOP).getPlanNode();
        if (!ImmutableSet.copyOf(plan.getOutputVariables()).equals(ImmutableSet.copyOf(actual.getOutputVariables()))) {
            fail(String.format(
                    "%s: output schema of transformed and original plans are not equivalent\n" +
                            "\texpected: %s\n" +
                            "\tactual: %s",
                    optimizer.getClass().getName(),
                    plan.getOutputVariables(),
                    actual.getOutputVariables()));
        }
        return new Plan(actual, types, StatsAndCosts.empty());
    }

    /**
     * Builds the optimizers applied to SQL input before the optimizer under test:
     * symbol unaliasing, output pruning, IN-subquery rewrites plus identity-projection removal,
     * and row-expression simplification.
     */
    private List<PlanOptimizer> getMinimalOptimizers()
    {
        ImmutableSet<Rule<?>> rules = ImmutableSet.of(
                new TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin(),
                new TransformUncorrelatedInPredicateSubqueryToSemiJoin(),
                new RemoveRedundantIdentityProjections());
        return ImmutableList.of(
                new UnaliasSymbolReferences(queryRunner.getMetadata().getFunctionAndTypeManager()),
                new PruneUnreferencedOutputs(),
                new IterativeOptimizer(
                        queryRunner.getMetadata(),
                        new RuleStatsRecorder(),
                        queryRunner.getStatsCalculator(),
                        queryRunner.getCostCalculator(),
                        rules),
                new IterativeOptimizer(
                        queryRunner.getMetadata(),
                        new RuleStatsRecorder(),
                        queryRunner.getStatsCalculator(),
                        queryRunner.getCostCalculator(),
                        new SimplifyRowExpressions(
                                metadata,
                                new ExpressionOptimizerManager(
                                        new PluginNodeManager(new InMemoryNodeManager()),
                                        queryRunner.getFunctionAndTypeManager())).rules()));
    }

    /**
     * Runs {@code transactionSessionConsumer} in a single-statement transaction started from the current session,
     * after registering the session's catalog (if any) with that transaction.
     */
    private <T> void inTransaction(Function<Session, T> transactionSessionConsumer)
    {
        transaction(transactionManager, accessControl)
                .singleStatement()
                .execute(session, transactionSession -> {
                    // metadata.getCatalogHandle() registers the catalog for the transaction
                    transactionSession.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(transactionSession, catalog));
                    return transactionSessionConsumer.apply(transactionSession);
                });
    }
}