TestCostCalculator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;
import com.facebook.presto.Session;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.dispatcher.NoOpQueryManager;
import com.facebook.presto.execution.NodeTaskMap;
import com.facebook.presto.execution.QueryManagerConfig;
import com.facebook.presto.execution.scheduler.LegacyNetworkTopology;
import com.facebook.presto.execution.scheduler.NodeScheduler;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.execution.scheduler.nodeSelection.NodeSelectionStats;
import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelectorConfig;
import com.facebook.presto.metadata.CatalogManager;
import com.facebook.presto.metadata.InMemoryNodeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.CteReferenceNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SimplePlanFragment;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.security.AllowAllAccessControl;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.NodePartitioningManager;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanFragmenter;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JsonCodecSimplePlanFragmentSerde;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManagerConfig;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.tpch.TpchColumnHandle;
import com.facebook.presto.tpch.TpchTableHandle;
import com.facebook.presto.tpch.TpchTableLayoutHandle;
import com.facebook.presto.tpch.TpchTransactionHandle;
import com.facebook.presto.transaction.TransactionManager;
import com.facebook.presto.ttl.nodettlfetchermanagers.ThrowingNodeTtlFetcherManager;
import com.facebook.presto.util.FinalizerService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.count;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.replicatedExchange;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange;
import static com.facebook.presto.sql.relational.Expressions.variable;
import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
public class TestCostCalculator
{
private static final int NUMBER_OF_NODES = 10;
private static final double AVERAGE_ROW_SIZE = 8.;
private static final double IS_NULL_OVERHEAD = 9. / AVERAGE_ROW_SIZE;
private static final double OFFSET_AND_IS_NULL_OVERHEAD = 13. / AVERAGE_ROW_SIZE;
private CostCalculator costCalculatorUsingExchanges;
private CostCalculator costCalculatorWithEstimatedExchanges;
private PlanFragmenter planFragmenter;
private Session session;
private MetadataManager metadata;
private TransactionManager transactionManager;
private FinalizerService finalizerService;
private NodeScheduler nodeScheduler;
private NodePartitioningManager nodePartitioningManager;
private TestingRowExpressionTranslator translator;
private PlanCheckerProviderManager planCheckerProviderManager;
@BeforeClass
public void setUp()
{
TaskCountEstimator taskCountEstimator = new TaskCountEstimator(() -> NUMBER_OF_NODES);
costCalculatorUsingExchanges = new CostCalculatorUsingExchanges(taskCountEstimator);
costCalculatorWithEstimatedExchanges = new CostCalculatorWithEstimatedExchanges(costCalculatorUsingExchanges, taskCountEstimator);
session = testSessionBuilder().setCatalog("tpch").build();
CatalogManager catalogManager = new CatalogManager();
catalogManager.registerCatalog(createBogusTestingCatalog("tpch"));
transactionManager = createTestTransactionManager(catalogManager);
metadata = createTestMetadataManager(transactionManager);
finalizerService = new FinalizerService();
finalizerService.start();
nodeScheduler = new NodeScheduler(
new LegacyNetworkTopology(),
new InMemoryNodeManager(),
new NodeSelectionStats(),
new NodeSchedulerConfig().setIncludeCoordinator(true),
new NodeTaskMap(finalizerService),
new ThrowingNodeTtlFetcherManager(),
new NoOpQueryManager(),
new SimpleTtlNodeSelectorConfig());
PartitioningProviderManager partitioningProviderManager = new PartitioningProviderManager();
nodePartitioningManager = new NodePartitioningManager(nodeScheduler, partitioningProviderManager, new NodeSelectionStats());
planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(jsonCodec(SimplePlanFragment.class)), new PlanCheckerProviderManagerConfig());
planFragmenter = new PlanFragmenter(metadata, nodePartitioningManager, new QueryManagerConfig(), new FeaturesConfig(), planCheckerProviderManager);
translator = new TestingRowExpressionTranslator();
}
@AfterClass(alwaysRun = true)
public void tearDown()
{
costCalculatorUsingExchanges = null;
costCalculatorWithEstimatedExchanges = null;
planFragmenter = null;
session = null;
transactionManager = null;
metadata = null;
finalizerService.destroy();
finalizerService = null;
nodeScheduler.stop();
nodeScheduler = null;
nodePartitioningManager = null;
}
@Test
public void testTableScan()
{
TableScanNode tableScan = tableScan("ts", "orderkey");
Map<String, Type> types = ImmutableMap.of("orderkey", BIGINT);
assertCost(tableScan, ImmutableMap.of(), ImmutableMap.of("ts", statsEstimate(tableScan, 1000)))
.cpu(1000 * IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostEstimatedExchanges(tableScan, ImmutableMap.of(), ImmutableMap.of("ts", statsEstimate(tableScan, 1000)))
.cpu(1000 * IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostSingleStageFragmentedPlan(tableScan, ImmutableMap.of(), ImmutableMap.of("ts", statsEstimate(tableScan, 1000)), types)
.cpu(1000 * IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostHasUnknownComponentsForUnknownStats(tableScan);
}
@Test
public void testProject()
{
TableScanNode tableScan = tableScan("ts", "orderkey");
RowExpression cast = translator.translate(new Cast(new SymbolReference("orderkey"), "VARCHAR"), TypeProvider.viewOf(ImmutableMap.of("orderkey", BIGINT)));
PlanNode project = project("project", tableScan, new VariableReferenceExpression(Optional.empty(), "string", VARCHAR), cast);
Map<String, PlanCostEstimate> costs = ImmutableMap.of("ts", cpuCost(1000));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"project", statsEstimate(project, 4000),
"ts", statsEstimate(tableScan, 1000));
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"string", VARCHAR);
assertCost(project, costs, stats)
.cpu(1000 + 4000 * OFFSET_AND_IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostEstimatedExchanges(project, costs, stats)
.cpu(1000 + 4000 * OFFSET_AND_IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostSingleStageFragmentedPlan(project, costs, stats, types)
.cpu(1000 + 4000 * OFFSET_AND_IS_NULL_OVERHEAD)
.memory(0)
.network(0);
assertCostHasUnknownComponentsForUnknownStats(project);
}
@Test
public void testRepartitionedJoin()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
TableScanNode ts2 = tableScan("ts2", "orderkey_0");
JoinNode join = join("join",
ts1,
ts2,
JoinDistributionType.PARTITIONED,
"orderkey",
"orderkey_0");
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"ts1", cpuCost(6000),
"ts2", cpuCost(1000));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"join", statsEstimate(join, 12000),
"ts1", statsEstimate(ts1, 6000),
"ts2", statsEstimate(ts2, 1000));
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"orderkey_0", BIGINT);
assertCost(join, costs, stats)
.cpu(6000 + 1000 + (12000 + 6000 + 1000) * IS_NULL_OVERHEAD)
.memory(1000 * IS_NULL_OVERHEAD)
.network(0);
assertCostEstimatedExchanges(join, costs, stats)
.cpu(6000 + 1000 + (12000 + 6000 + 1000 + 6000 + 1000 + 1000) * IS_NULL_OVERHEAD)
.memory(1000 * IS_NULL_OVERHEAD)
.network((6000 + 1000) * IS_NULL_OVERHEAD);
assertCostSingleStageFragmentedPlan(join, costs, stats, types)
.cpu(6000 + 1000 + (12000 + 6000 + 1000) * IS_NULL_OVERHEAD)
.memory(1000 * IS_NULL_OVERHEAD)
.network(0);
assertCostHasUnknownComponentsForUnknownStats(join);
}
@Test
public void testReplicatedJoin()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
TableScanNode ts2 = tableScan("ts2", "orderkey_0");
JoinNode join = join("join",
ts1,
ts2,
JoinDistributionType.REPLICATED,
"orderkey",
"orderkey_0");
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"ts1", cpuCost(6000),
"ts2", cpuCost(1000));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"join", statsEstimate(join, 12000),
"ts1", statsEstimate(ts1, 6000),
"ts2", statsEstimate(ts2, 1000));
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"orderkey_0", BIGINT);
assertCost(join, costs, stats)
.cpu(1000 + 6000 + (12000 + 6000 + 10000 + 1000 * (NUMBER_OF_NODES - 1)) * IS_NULL_OVERHEAD)
.memory(1000 * NUMBER_OF_NODES * IS_NULL_OVERHEAD)
.network(0);
assertCostEstimatedExchanges(join, costs, stats)
.cpu(1000 + 6000 + (12000 + 6000 + 10000 + 1000 * NUMBER_OF_NODES) * IS_NULL_OVERHEAD)
.memory(1000 * NUMBER_OF_NODES * IS_NULL_OVERHEAD)
.network(1000 * NUMBER_OF_NODES * IS_NULL_OVERHEAD);
assertCostSingleStageFragmentedPlan(join, costs, stats, types)
.cpu(1000 + 6000 + (12000 + 6000 + 10000 + 1000 * (NUMBER_OF_NODES - 1)) * IS_NULL_OVERHEAD)
.memory(1000 * NUMBER_OF_NODES * IS_NULL_OVERHEAD)
.network(0);
assertCostHasUnknownComponentsForUnknownStats(join);
}
@Test
public void testMemoryCostJoinAboveJoin()
{
// join
// / \
// ts1 join23
// / \
// ts2 ts3
TableScanNode ts1 = tableScan("ts1", "key1");
TableScanNode ts2 = tableScan("ts2", "key2");
TableScanNode ts3 = tableScan("ts3", "key3");
JoinNode join23 = join(
"join23",
ts2,
ts3,
JoinDistributionType.PARTITIONED,
"key2",
"key3");
JoinNode join = join(
"join",
ts1,
join23,
JoinDistributionType.PARTITIONED,
"key1",
"key2");
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"ts1", new PlanCostEstimate(0, 128, 128, 0),
"ts2", new PlanCostEstimate(0, 64, 64, 0),
"ts3", new PlanCostEstimate(0, 32, 32, 0));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"join", statsEstimate(join, 10_000),
"join23", statsEstimate(join23, 2_000),
"ts1", statsEstimate(ts1, 10_000),
"ts2", statsEstimate(ts2, 1_000),
"ts3", statsEstimate(ts3, 100));
Map<String, Type> types = ImmutableMap.of("key1", BIGINT, "key2", BIGINT, "key3", BIGINT);
assertCost(join23, costs, stats)
.memory(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64 + 32) // ts2, ts3 memory footprint
.memoryWhenOutputting(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64); // ts2 memory footprint
assertCost(join, costs, stats)
.memory(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 100 * IS_NULL_OVERHEAD + 64 // join23 total memory when outputting
+ 128) // ts1 memory footprint
.memoryWhenOutputting(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 128); // ts1 memory footprint
assertCostEstimatedExchanges(join23, costs, stats)
.memory(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64 + 32) // ts2, ts3 memory footprint
.memoryWhenOutputting(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64); // ts2 memory footprint
assertCostEstimatedExchanges(join, costs, stats)
.memory(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 100 * IS_NULL_OVERHEAD + 64 // join23 total memory when outputting
+ 128) // ts1 memory footprint
.memoryWhenOutputting(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 128); // ts1 memory footprint
assertCostSingleStageFragmentedPlan(join23, costs, stats, types)
.memory(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64 + 32) // ts2, ts3 memory footprint
.memoryWhenOutputting(
100 * IS_NULL_OVERHEAD // join23 memory footprint
+ 64); // ts2 memory footprint
assertCostSingleStageFragmentedPlan(join, costs, stats, types)
.memory(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 100 * IS_NULL_OVERHEAD + 64 // join23 total memory when outputting
+ 128) // ts1 memory footprint
.memoryWhenOutputting(
2000 * IS_NULL_OVERHEAD // join memory footprint
+ 128); // ts1 memory footprint
}
@Test
public void testAggregation()
{
TableScanNode tableScan = tableScan("ts", "orderkey");
AggregationNode aggregation = aggregation("agg", tableScan);
Map<String, PlanCostEstimate> costs = ImmutableMap.of("ts", cpuCost(6000));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"ts", statsEstimate(tableScan, 6000),
"agg", statsEstimate(aggregation, 13));
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"count", BIGINT);
assertCost(aggregation, costs, stats)
.cpu(6000 * IS_NULL_OVERHEAD + 6000)
.memory(13 * IS_NULL_OVERHEAD)
.network(0);
assertCostEstimatedExchanges(aggregation, costs, stats)
.cpu((6000 + 6000 + 6000) * IS_NULL_OVERHEAD + 6000)
.memory(13 * IS_NULL_OVERHEAD)
.network(6000 * IS_NULL_OVERHEAD);
assertCostSingleStageFragmentedPlan(aggregation, costs, stats, types)
.cpu(6000 + 6000 * IS_NULL_OVERHEAD)
.memory(13 * IS_NULL_OVERHEAD)
.network(0 * IS_NULL_OVERHEAD);
assertCostHasUnknownComponentsForUnknownStats(aggregation);
}
@Test
public void testCteProducer()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
CteProducerNode cteProducerNode = new CteProducerNode(
Optional.empty(),
new PlanNodeId("cteProducer"),
ts1,
"test_cte",
new VariableReferenceExpression(Optional.empty(), "rows", BIGINT),
ts1.getOutputVariables());
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"ts1", statsEstimate(ts1, 4000));
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"ts1", new PlanCostEstimate(1000, 10, 10, 1000));
assertCost(cteProducerNode, costs, stats)
.cpu(14500)
.memory(10)
.network(14500);
assertCostEstimatedExchanges(cteProducerNode, costs, stats)
.cpu(14500)
.memory(10)
.network(14500);
}
@Test
public void testCteConsumer()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
CteConsumerNode cteConsumerNode = new CteConsumerNode(
Optional.empty(),
new PlanNodeId("cteConsumer"),
ts1.getOutputVariables(),
"test_cte", ts1);
// This just symbolizes that the tablescan(original planNode) was more expensive but we used the stats from the stats store
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"cteConsumer", statsEstimate(ts1, 4000),
"ts1", statsEstimate(ts1, 10000000));
assertCost(cteConsumerNode, ImmutableMap.of(), stats)
.cpu(4500)
.memory(0)
.network(0);
}
@Test
public void testCteReferenceCost()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
CteReferenceNode cteReferenceNode = new CteReferenceNode(
Optional.empty(),
new PlanNodeId("cteReference"),
ts1,
"test");
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"ts1", statsEstimate(ts1, 4000));
assertCost(cteReferenceNode, ImmutableMap.of(), stats);
}
@Test
public void testSequence()
{
// Create PlanNodes
TableScanNode ts1 = tableScan("ts1", "orderkey");
TableScanNode ts2 = tableScan("ts2", "custkey");
CteProducerNode cteProducerNode1 = new CteProducerNode(
Optional.empty(),
new PlanNodeId("cteProducer1"),
ts1,
"cte1",
new VariableReferenceExpression(Optional.empty(), "rows", BIGINT),
ts1.getOutputVariables());
CteProducerNode cteProducerNode2 = new CteProducerNode(
Optional.empty(),
new PlanNodeId("cteProducer2"),
ts2,
"cte2",
new VariableReferenceExpression(Optional.empty(), "rows", BIGINT),
ts2.getOutputVariables());
// Define the CTE consumer nodes that would be used in the join
CteConsumerNode cteConsumerNode1 = new CteConsumerNode(
Optional.empty(),
new PlanNodeId("cteConsumer1"),
ts1.getOutputVariables(),
"cte1",
ts1);
CteConsumerNode cteConsumerNode2 = new CteConsumerNode(
Optional.empty(),
new PlanNodeId("cteConsumer2"),
ts2.getOutputVariables(),
"cte2",
ts1);
JoinNode joinNode = join("join",
cteConsumerNode1,
cteConsumerNode2,
JoinDistributionType.PARTITIONED,
"orderkey", "custkey");
MutableGraph<Integer> sequenceGraph = GraphBuilder.directed().build();
// Add indexes to the graph
sequenceGraph.addNode(0);
sequenceGraph.addNode(1);
SequenceNode sequenceNode = new SequenceNode(
Optional.empty(),
new PlanNodeId("sequence"),
ImmutableList.of(cteProducerNode1, cteProducerNode2),
joinNode,
sequenceGraph);
// Define cost of sequence children
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"join", new PlanCostEstimate(5000, 5000, 5000, 5000),
"cteProducer1", new PlanCostEstimate(4000, 4000, 4000, 4000),
"cteProducer2", new PlanCostEstimate(3000, 3000, 3000, 3000));
// Assert costs for the sequence node
assertCost(sequenceNode, costs, new HashMap<>())
.cpu(12000)
.memory(12000)
.network(12000);
}
@Test
public void testRepartitionedJoinWithExchange()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
TableScanNode ts2 = tableScan("ts2", "orderkey_0");
PlanNode p1 = project("p1", ts1, variable("orderkey_1", BIGINT), variable("orderkey", BIGINT));
ExchangeNode remoteExchange1 = systemPartitionedExchange(
new PlanNodeId("re1"),
REMOTE_STREAMING,
p1,
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_1", BIGINT)),
Optional.empty());
ExchangeNode remoteExchange2 = systemPartitionedExchange(
new PlanNodeId("re2"),
REMOTE_STREAMING,
ts2,
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)),
Optional.empty());
ExchangeNode localExchange = systemPartitionedExchange(
new PlanNodeId("le"),
LOCAL,
remoteExchange2,
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)),
Optional.empty());
JoinNode join = join("join",
remoteExchange1,
localExchange,
JoinDistributionType.PARTITIONED,
"orderkey_1",
"orderkey_0");
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.<String, PlanNodeStatsEstimate>builder()
.put("join", statsEstimate(join, 12000))
.put("re1", statsEstimate(remoteExchange1, 10000))
.put("re2", statsEstimate(remoteExchange2, 10000))
.put("le", statsEstimate(localExchange, 6000))
.put("p1", statsEstimate(p1, 6000))
.put("ts1", statsEstimate(ts1, 6000))
.put("ts2", statsEstimate(ts2, 1000))
.build();
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"orderkey_1", BIGINT,
"orderkey_0", BIGINT);
assertFragmentedEqualsUnfragmented(join, stats, types);
}
@Test
public void testReplicatedJoinWithExchange()
{
TableScanNode ts1 = tableScan("ts1", ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey", BIGINT)));
TableScanNode ts2 = tableScan("ts2", ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)));
ExchangeNode remoteExchange2 = replicatedExchange(new PlanNodeId("re2"), REMOTE_STREAMING, ts2);
ExchangeNode localExchange = systemPartitionedExchange(
new PlanNodeId("le"),
LOCAL,
remoteExchange2,
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)),
Optional.empty());
JoinNode join = join("join",
ts1,
localExchange,
JoinDistributionType.REPLICATED,
"orderkey",
"orderkey_0");
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.<String, PlanNodeStatsEstimate>builder()
.put("join", statsEstimate(join, 12000))
.put("re2", statsEstimate(remoteExchange2, 10000))
.put("le", statsEstimate(localExchange, 6000))
.put("ts1", statsEstimate(ts1, 6000))
.put("ts2", statsEstimate(ts2, 1000))
.build();
Map<String, Type> types = ImmutableMap.of(
"orderkey", BIGINT,
"orderkey_0", BIGINT);
assertFragmentedEqualsUnfragmented(join, stats, types);
}
@Test
public void testUnion()
{
TableScanNode ts1 = tableScan("ts1", "orderkey");
TableScanNode ts2 = tableScan("ts2", "orderkey_0");
UnionNode union = new UnionNode(
Optional.empty(),
new PlanNodeId("union"),
ImmutableList.of(ts1, ts2),
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_1", BIGINT)),
ImmutableMap.of(
new VariableReferenceExpression(Optional.empty(), "orderkey_1", BIGINT),
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey", BIGINT), new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT))));
Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
"ts1", statsEstimate(ts1, 4000),
"ts2", statsEstimate(ts2, 1000),
"union", statsEstimate(ts1, 5000));
Map<String, PlanCostEstimate> costs = ImmutableMap.of(
"ts1", cpuCost(1000),
"ts2", cpuCost(1000));
assertCost(union, costs, stats)
.cpu(2000)
.memory(0)
.network(0);
assertCostEstimatedExchanges(union, costs, stats)
.cpu(2000)
.memory(0)
.network(5000 * IS_NULL_OVERHEAD);
}
private CostAssertionBuilder assertCost(
PlanNode node,
Map<String, PlanCostEstimate> costs,
Map<String, PlanNodeStatsEstimate> stats)
{
return assertCost(costCalculatorUsingExchanges, node, costs, stats);
}
private CostAssertionBuilder assertCostEstimatedExchanges(
PlanNode node,
Map<String, PlanCostEstimate> costs,
Map<String, PlanNodeStatsEstimate> stats)
{
return assertCost(costCalculatorWithEstimatedExchanges, node, costs, stats);
}
private CostAssertionBuilder assertCostSingleStageFragmentedPlan(
PlanNode node,
Map<String, PlanCostEstimate> costs,
Map<String, PlanNodeStatsEstimate> stats,
Map<String, Type> types)
{
TypeProvider typeProvider = TypeProvider.copyOf(types);
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider);
CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session);
// Explicitly generate the statsAndCosts, bypass fragment generation and sanity checks for mock plans.
StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider, session).getForSubplan(node);
return new CostAssertionBuilder(statsAndCosts.getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()));
}
private static class TestingCostProvider
implements CostProvider
{
private final Map<String, PlanCostEstimate> costs;
private final CostCalculator costCalculator;
private final StatsProvider statsProvider;
private final Session session;
private TestingCostProvider(Map<String, PlanCostEstimate> costs, CostCalculator costCalculator, StatsProvider statsProvider, Session session)
{
this.costs = ImmutableMap.copyOf(requireNonNull(costs, "costs is null"));
this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
this.statsProvider = requireNonNull(statsProvider, "statsProvider is null");
this.session = requireNonNull(session, "session is null");
}
@Override
public PlanCostEstimate getCost(PlanNode node)
{
if (costs.containsKey(node.getId().toString())) {
return costs.get(node.getId().toString());
}
return costCalculator.calculateCost(node, statsProvider, this, session);
}
}
private CostAssertionBuilder assertCost(
CostCalculator costCalculator,
PlanNode node,
Map<String, PlanCostEstimate> costs,
Map<String, PlanNodeStatsEstimate> stats)
{
Function<PlanNode, PlanNodeStatsEstimate> statsProvider = planNode -> stats.get(planNode.getId().toString());
PlanCostEstimate cost = calculateCost(
costCalculator,
node,
sourceCostProvider(costCalculator, costs, statsProvider),
statsProvider);
return new CostAssertionBuilder(cost);
}
private Function<PlanNode, PlanCostEstimate> sourceCostProvider(
CostCalculator costCalculator,
Map<String, PlanCostEstimate> costs,
Function<PlanNode, PlanNodeStatsEstimate> statsProvider)
{
return node -> {
PlanCostEstimate providedCost = costs.get(node.getId().toString());
if (providedCost != null) {
return providedCost;
}
return calculateCost(
costCalculator,
node,
sourceCostProvider(costCalculator, costs, statsProvider),
statsProvider);
};
}
private void assertCostHasUnknownComponentsForUnknownStats(PlanNode node)
{
new CostAssertionBuilder(calculateCost(
costCalculatorUsingExchanges,
node,
planNode -> PlanCostEstimate.unknown(),
planNode -> PlanNodeStatsEstimate.unknown()))
.hasUnknownComponents();
new CostAssertionBuilder(calculateCost(
costCalculatorWithEstimatedExchanges,
node,
planNode -> PlanCostEstimate.unknown(),
planNode -> PlanNodeStatsEstimate.unknown()))
.hasUnknownComponents();
}
private void assertFragmentedEqualsUnfragmented(PlanNode node, Map<String, PlanNodeStatsEstimate> stats, Map<String, Type> types)
{
StatsCalculator statsCalculator = statsCalculator(stats);
PlanCostEstimate costWithExchanges = calculateCost(node, costCalculatorUsingExchanges, statsCalculator, types);
PlanCostEstimate costWithFragments = calculateCostFragmentedPlan(node, statsCalculator, types);
assertEquals(costWithExchanges, costWithFragments);
}
private StatsCalculator statsCalculator(Map<String, PlanNodeStatsEstimate> stats)
{
return (node, sourceStats, lookup, session, types) ->
requireNonNull(stats.get(node.getId().toString()), "no stats for node");
}
private PlanCostEstimate calculateCost(
CostCalculator costCalculator,
PlanNode node,
Function<PlanNode, PlanCostEstimate> costs,
Function<PlanNode, PlanNodeStatsEstimate> stats)
{
return costCalculator.calculateCost(
node,
planNode -> requireNonNull(stats.apply(planNode), "no stats for node"),
source -> requireNonNull(costs.apply(source), format("no cost for source: %s", source.getId())),
session);
}
private PlanCostEstimate calculateCost(PlanNode node, CostCalculator costCalculator, StatsCalculator statsCalculator, Map<String, Type> types)
{
TypeProvider typeProvider = TypeProvider.copyOf(types);
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider);
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session);
return costProvider.getCost(node);
}
private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalculator statsCalculator, Map<String, Type> types)
{
TypeProvider typeProvider = TypeProvider.copyOf(types);
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider);
CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session);
SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider, session)));
return subPlan.getFragment().getStatsAndCosts().orElse(StatsAndCosts.empty()).getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown());
}
private static class CostAssertionBuilder
{
private final PlanCostEstimate actual;
CostAssertionBuilder(PlanCostEstimate actual)
{
this.actual = requireNonNull(actual, "actual is null");
}
CostAssertionBuilder cpu(double value)
{
assertEquals(actual.getCpuCost(), value, 0.000001);
return this;
}
CostAssertionBuilder memory(double value)
{
assertEquals(actual.getMaxMemory(), value, 0.000001);
return this;
}
CostAssertionBuilder memoryWhenOutputting(double value)
{
assertEquals(actual.getMaxMemoryWhenOutputting(), value, 0.000001);
return this;
}
CostAssertionBuilder network(double value)
{
assertEquals(actual.getNetworkCost(), value, 0.000001);
return this;
}
CostAssertionBuilder hasUnknownComponents()
{
assertTrue(actual.hasUnknownComponents());
return this;
}
}
private static PlanNodeStatsEstimate statsEstimate(PlanNode node, double outputSizeInBytes)
{
return statsEstimate(node.getOutputVariables(), outputSizeInBytes);
}
private static PlanNodeStatsEstimate statsEstimate(Collection<VariableReferenceExpression> variables, double outputSizeInBytes)
{
checkArgument(variables.size() > 0, "No variables");
checkArgument(ImmutableSet.copyOf(variables).size() == variables.size(), "Duplicate variables");
double rowCount = outputSizeInBytes / variables.size() / AVERAGE_ROW_SIZE;
PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
.setOutputRowCount(rowCount);
for (VariableReferenceExpression variable : variables) {
builder.addVariableStatistics(
variable,
VariableStatsEstimate.builder()
.setNullsFraction(0)
.setAverageRowSize(AVERAGE_ROW_SIZE)
.build());
}
return builder.build();
}
private TableScanNode tableScan(String id, String... symbols)
{
List<VariableReferenceExpression> variables = Arrays.stream(symbols)
.map(symbol -> new VariableReferenceExpression(Optional.empty(), symbol, BIGINT))
.collect(toImmutableList());
return tableScan(id, variables);
}
private TableScanNode tableScan(String id, List<VariableReferenceExpression> variables)
{
ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> assignments = ImmutableMap.builder();
for (VariableReferenceExpression variable : variables) {
assignments.put(variable, new TpchColumnHandle("orderkey", BIGINT));
}
TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0);
return new TableScanNode(
Optional.empty(),
new PlanNodeId(id),
new TableHandle(
new ConnectorId("tpch"),
tableHandle,
TpchTransactionHandle.INSTANCE,
Optional.of(new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))),
variables,
assignments.build(),
TupleDomain.all(),
TupleDomain.all(), Optional.empty());
}
private PlanNode project(String id, PlanNode source, VariableReferenceExpression variable, RowExpression expression)
{
return new ProjectNode(
new PlanNodeId(id),
source,
assignment(variable, expression));
}
private AggregationNode aggregation(String id, PlanNode source)
{
AggregationNode.Aggregation aggregation = count(metadata.getFunctionAndTypeManager());
return new AggregationNode(
Optional.empty(),
new PlanNodeId(id),
source,
ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "count", BIGINT), aggregation),
singleGroupingSet(source.getOutputVariables()),
ImmutableList.of(),
AggregationNode.Step.FINAL,
Optional.empty(),
Optional.empty(),
Optional.empty());
}
/**
* EquiJoinClause is created from symbols in form of:
* symbol[0] = symbol[1] AND symbol[2] = symbol[3] AND ...
*/
private JoinNode join(String planNodeId, PlanNode left, PlanNode right, JoinDistributionType distributionType, String... symbols)
{
checkArgument(symbols.length % 2 == 0);
ImmutableList.Builder<EquiJoinClause> criteria = ImmutableList.builder();
for (int i = 0; i < symbols.length; i += 2) {
criteria.add(new EquiJoinClause(new VariableReferenceExpression(Optional.empty(), symbols[i], BIGINT), new VariableReferenceExpression(Optional.empty(), symbols[i + 1], BIGINT)));
}
return new JoinNode(
Optional.empty(),
new PlanNodeId(planNodeId),
JoinType.INNER,
left,
right,
criteria.build(),
ImmutableList.<VariableReferenceExpression>builder()
.addAll(left.getOutputVariables())
.addAll(right.getOutputVariables())
.build(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.of(distributionType),
ImmutableMap.of());
}
private SubPlan fragment(Plan plan)
{
return inTransaction(session -> planFragmenter.createSubPlans(session, plan, false, new PlanNodeIdAllocator(), WarningCollector.NOOP));
}
private <T> T inTransaction(Function<Session, T> transactionSessionConsumer)
{
return transaction(transactionManager, new AllowAllAccessControl())
.singleStatement()
.execute(session, session -> {
// metadata.getCatalogHandle() registers the catalog for the transaction
session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog));
return transactionSessionConsumer.apply(session);
});
}
private static PlanCostEstimate cpuCost(double cpuCost)
{
return new PlanCostEstimate(cpuCost, 0, 0, 0);
}
}