TestLocalExecutionPlanner.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.execution.FragmentResultCacheContext;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.DriverFactory;
import com.facebook.presto.operator.Operator;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.TableScanOperator;
import com.facebook.presto.operator.TaskOutputOperator;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.StageExecutionDescriptor;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.TestingHandle;
import com.facebook.presto.testing.TestingMetadata;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.facebook.presto.tpch.TpchConnectorFactory;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.SystemSessionProperties.ENABLE_INTERMEDIATE_AGGREGATIONS;
import static com.facebook.presto.SystemSessionProperties.FRAGMENT_RESULT_CACHING_ENABLED;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.execution.TaskTestUtils.createTestingPlanner;
import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR;
import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
import static java.util.Collections.nCopies;
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
public class TestLocalExecutionPlanner
{
private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-%s"));
private static final ScheduledExecutorService SCHEDULED_EXECUTOR = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
private LocalQueryRunner runner;
@BeforeClass
public void setUp()
{
runner = new LocalQueryRunner(TEST_SESSION);
runner.createCatalog(runner.getDefaultSession().getCatalog().get(),
new TpchConnectorFactory(1),
ImmutableMap.of());
}
@AfterClass(alwaysRun = true)
public void cleanup()
{
closeAllRuntimeException(runner);
runner = null;
}
@Test(enabled = false)
public void testCompilerFailure()
{
// structure the query this way to avoid stack overflow when parsing
String inner = "(" + Joiner.on(" + ").join(nCopies(100, "rand()")) + ")";
String outer = Joiner.on(" + ").join(nCopies(100, inner));
assertFails("SELECT " + outer, COMPILER_ERROR);
}
private void assertFails(@Language("SQL") String sql, ErrorCodeSupplier supplier)
{
try {
runner.execute(sql);
fail("expected exception");
}
catch (PrestoException e) {
assertEquals(e.getErrorCode(), supplier.toErrorCode());
}
}
@Test
public void testCreatingFragmentResultCacheContext()
{
Session session = Session.builder(runner.getDefaultSession())
.setSystemProperty(FRAGMENT_RESULT_CACHING_ENABLED, "true")
.build();
LocalExecutionPlan planWithoutIntermediateAggregation = getLocalExecutionPlan(session);
// Expect one driver factory: partial aggregation
assertEquals(planWithoutIntermediateAggregation.getDriverFactories().size(), 1);
Optional<FragmentResultCacheContext> contextWithoutIntermediateAggregation = planWithoutIntermediateAggregation.getDriverFactories().get(0).getFragmentResultCacheContext();
assertTrue(contextWithoutIntermediateAggregation.isPresent());
session = Session.builder(runner.getDefaultSession())
.setSystemProperty(FRAGMENT_RESULT_CACHING_ENABLED, "true")
.setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
.build();
LocalExecutionPlan planWithIntermediateAggregation = getLocalExecutionPlan(session);
// Expect twp driver factories: partial aggregation and intermediate aggregation
assertEquals(planWithIntermediateAggregation.getDriverFactories().size(), 2);
Optional<FragmentResultCacheContext> contextWithIntermediateAggregation = planWithIntermediateAggregation.getDriverFactories().get(0).getFragmentResultCacheContext();
assertTrue(contextWithIntermediateAggregation.isPresent());
assertEquals(contextWithIntermediateAggregation.get().getHashedCanonicalPlanFragment(), contextWithoutIntermediateAggregation.get().getHashedCanonicalPlanFragment());
}
@Test
public void testCustomPlanTranslator()
{
VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), "column", VARCHAR);
PlanNode scan = new TableScanNode(
Optional.empty(),
new PlanNodeId("sourceId"),
new TableHandle(new ConnectorId("test"), new TestingMetadata.TestingTableHandle(), TestingTransactionHandle.create(), Optional.of(TestingHandle.INSTANCE)),
ImmutableList.of(variable),
ImmutableMap.of(variable, new TestingMetadata.TestingColumnHandle("column")),
TupleDomain.all(),
TupleDomain.all(), Optional.empty());
PlanNode node1 = new CustomNodeA(new PlanNodeId("node1"), scan);
PlanNode node2 = new CustomNodeB(new PlanNodeId("node2"), node1);
LocalExecutionPlan plan = getLocalExecutionPlan(
runner.getDefaultSession(),
node2,
ImmutableList.of(new CustomOperatorAFactory.PlanTranslator(), new CustomOperatorBFactory.PlanTranslator()));
List<DriverFactory> driverFactories = plan.getDriverFactories();
assertEquals(driverFactories.size(), 1);
List<OperatorFactory> operatorFactories = driverFactories.get(0).getOperatorFactories();
assertEquals(operatorFactories.size(), 4);
assertTrue(operatorFactories.get(0) instanceof TableScanOperator.TableScanOperatorFactory);
assertTrue(operatorFactories.get(1) instanceof CustomOperatorAFactory);
assertTrue(operatorFactories.get(2) instanceof CustomOperatorBFactory);
assertTrue(operatorFactories.get(3) instanceof TaskOutputOperator.TaskOutputOperatorFactory);
}
private LocalExecutionPlan getLocalExecutionPlan(Session session, PlanNode plan, List<LocalExecutionPlanner.CustomPlanTranslator> customPlanTranslators)
{
PlanFragment testFragment = new PlanFragment(
new PlanFragmentId(0),
plan,
ImmutableSet.of(),
SOURCE_DISTRIBUTION,
ImmutableList.of(new PlanNodeId("sourceId")),
new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()),
StageExecutionDescriptor.ungroupedExecution(),
false,
Optional.of(StatsAndCosts.empty()),
Optional.empty());
return createTestingPlanner().plan(
createTaskContext(EXECUTOR, SCHEDULED_EXECUTOR, session),
testFragment,
new TestingOutputBuffer(),
new TestingRemoteSourceFactory(),
new TableWriteInfo(Optional.empty(), Optional.empty()),
customPlanTranslators);
}
private LocalExecutionPlan getLocalExecutionPlan(Session session)
{
SubPlan subPlan = runner.inTransaction(session, transactionSession -> {
Plan plan = runner.createPlan(transactionSession, "SELECT avg(totalprice) FROM orders", OPTIMIZED_AND_VALIDATED, false, WarningCollector.NOOP);
return runner.createSubPlans(transactionSession, plan, false);
});
// Expect only one child sub plan doing partial aggregation.
assertEquals(subPlan.getChildren().size(), 1);
PlanFragment leafFragment = subPlan.getChildren().get(0).getFragment();
return createTestingPlanner().plan(
createTaskContext(EXECUTOR, SCHEDULED_EXECUTOR, session),
leafFragment,
new TestingOutputBuffer(),
new TestingRemoteSourceFactory(),
new TableWriteInfo(Optional.empty(), Optional.empty()));
}
private static class CustomNodeA
extends CustomNode
{
protected CustomNodeA(PlanNodeId id, PlanNode source)
{
super(Optional.empty(), id, Optional.empty(), source);
}
}
private static class CustomNodeB
extends CustomNode
{
protected CustomNodeB(PlanNodeId id, PlanNode source)
{
super(Optional.empty(), id, Optional.empty(), source);
}
}
private static class CustomNode
extends PlanNode
{
private final PlanNode source;
protected CustomNode(Optional<SourceLocation> sourceLocation, PlanNodeId id, Optional<PlanNode> statsEquivalentPlanNode, PlanNode source)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.source = source;
}
public PlanNode getSource()
{
return source;
}
@Override
public List<PlanNode> getSources()
{
return singletonList(source);
}
@Override
public List<VariableReferenceExpression> getOutputVariables()
{
return ImmutableList.of();
}
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
throw new UnsupportedOperationException();
}
@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
throw new UnsupportedOperationException();
}
}
public static class CustomOperatorAFactory
extends CustomOperatorFactory
{
public CustomOperatorAFactory(int operatorId, PlanNodeId sourceId)
{
super(operatorId, sourceId);
}
@Override
public Operator createOperator(DriverContext driverContext)
{
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, CustomOperatorA.class.getSimpleName());
return new CustomOperatorA(operatorContext, sourceId);
}
public static class PlanTranslator
extends LocalExecutionPlanner.CustomPlanTranslator
{
@Override
public Optional<LocalExecutionPlanner.PhysicalOperation> translate(
PlanNode node,
LocalExecutionPlanner.LocalExecutionPlanContext context,
InternalPlanVisitor<LocalExecutionPlanner.PhysicalOperation, LocalExecutionPlanner.LocalExecutionPlanContext> visitor)
{
if (node instanceof CustomNodeA) {
OperatorFactory operatorFactory = new CustomOperatorAFactory(
context.getNextOperatorId(),
node.getId());
LocalExecutionPlanner.PhysicalOperation sourceOperator = ((CustomNodeA) node).getSource().accept(visitor, context);
return Optional.of(
new LocalExecutionPlanner.PhysicalOperation(operatorFactory, makeLayout(node), context, sourceOperator));
}
return Optional.empty();
}
}
public static class CustomOperatorA
extends CustomOperator
{
public CustomOperatorA(OperatorContext operatorContext, PlanNodeId planNodeId)
{
super(operatorContext, planNodeId);
}
}
}
public static class CustomOperatorBFactory
extends CustomOperatorFactory
{
public CustomOperatorBFactory(int operatorId, PlanNodeId sourceId)
{
super(operatorId, sourceId);
}
@Override
public Operator createOperator(DriverContext driverContext)
{
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, CustomOperatorB.class.getSimpleName());
return new CustomOperatorB(operatorContext, sourceId);
}
public static class PlanTranslator
extends LocalExecutionPlanner.CustomPlanTranslator
{
@Override
public Optional<LocalExecutionPlanner.PhysicalOperation> translate(
PlanNode node,
LocalExecutionPlanner.LocalExecutionPlanContext context,
InternalPlanVisitor<LocalExecutionPlanner.PhysicalOperation, LocalExecutionPlanner.LocalExecutionPlanContext> visitor)
{
if (node instanceof CustomNodeB) {
OperatorFactory operatorFactory = new CustomOperatorBFactory(
context.getNextOperatorId(),
node.getId());
LocalExecutionPlanner.PhysicalOperation sourceOperator = ((CustomNodeB) node).getSource().accept(visitor, context);
return Optional.of(
new LocalExecutionPlanner.PhysicalOperation(operatorFactory, makeLayout(node), context, sourceOperator));
}
return Optional.empty();
}
}
public class CustomOperatorB
extends CustomOperator
{
public CustomOperatorB(OperatorContext operatorContext, PlanNodeId planNodeId)
{
super(operatorContext, planNodeId);
}
}
}
public abstract static class CustomOperatorFactory
implements OperatorFactory
{
protected final int operatorId;
protected final PlanNodeId sourceId;
public CustomOperatorFactory(
int operatorId,
PlanNodeId sourceId)
{
this.operatorId = operatorId;
this.sourceId = requireNonNull(sourceId, "sourceId is null");
}
public abstract Operator createOperator(DriverContext driverContext);
@Override
public synchronized void noMoreOperators(Lifespan lifespan)
{
}
@Override
public OperatorFactory duplicate()
{
throw new UnsupportedOperationException();
}
@Override
public void noMoreOperators()
{
}
public static class CustomOperator
implements Operator
{
private final OperatorContext operatorContext;
private final PlanNodeId planNodeId;
private boolean finished;
public CustomOperator(
OperatorContext operatorContext,
PlanNodeId planNodeId)
{
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
}
@Override
public OperatorContext getOperatorContext()
{
return operatorContext;
}
@Override
public void close()
{
finish();
}
@Override
public void finish()
{
finished = true;
}
@Override
public boolean isFinished()
{
return finished;
}
@Override
public boolean needsInput()
{
return false;
}
@Override
public void addInput(Page page)
{
throw new UnsupportedOperationException(getClass().getName() + " can not take input");
}
@Override
public Page getOutput()
{
return null;
}
}
}
}