KeyBasedSampler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.Varchars;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.PrestoWarning;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static com.facebook.presto.SystemSessionProperties.getKeyBasedSamplingFunction;
import static com.facebook.presto.SystemSessionProperties.getKeyBasedSamplingPercentage;
import static com.facebook.presto.SystemSessionProperties.isKeyBasedSamplingEnabled;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS;
import static com.facebook.presto.spi.StandardWarningCode.SEMANTIC_WARNING;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static java.util.Objects.requireNonNull;
public class KeyBasedSampler
implements PlanOptimizer
{
private final Metadata metadata;
private boolean isEnabledForTesting;
public KeyBasedSampler(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}
@Override
public void setEnabledForTesting(boolean isSet)
{
isEnabledForTesting = isSet;
}
@Override
public boolean isEnabled(Session session)
{
return isEnabledForTesting || isKeyBasedSamplingEnabled(session);
}
@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
if (isEnabled(session)) {
List<String> sampledFields = new ArrayList<>(2);
PlanNode rewritten = SimplePlanRewriter.rewriteWith(new Rewriter(session, metadata.getFunctionAndTypeManager(), idAllocator, sampledFields), plan, null);
if (!isEnabledForTesting) {
if (!sampledFields.isEmpty()) {
warningCollector.add(new PrestoWarning(SAMPLED_FIELDS, String.format("Sampled the following columns/derived columns at %s percent:%n\t%s", getKeyBasedSamplingPercentage(session) * 100., String.join("\n\t", sampledFields))));
}
else {
warningCollector.add(new PrestoWarning(SEMANTIC_WARNING, "Sampling could not be performed due to the query structure"));
}
}
return PlanOptimizerResult.optimizerResult(rewritten, true);
}
return PlanOptimizerResult.optimizerResult(plan, false);
}
private static class Rewriter
extends SimplePlanRewriter<Void>
{
private final Session session;
private final FunctionAndTypeManager functionAndTypeManager;
private final PlanNodeIdAllocator idAllocator;
private final List<String> sampledFields;
private Rewriter(Session session, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator idAllocator, List<String> sampledFields)
{
this.session = requireNonNull(session, "session is null");
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.sampledFields = requireNonNull(sampledFields, "sampledFields is null");
}
private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional<VariableReferenceExpression> rowExpressionOptional, FunctionAndTypeManager functionAndTypeManager)
{
if (!rowExpressionOptional.isPresent()) {
return tableScanNode;
}
RowExpression rowExpression = rowExpressionOptional.get();
RowExpression arg;
Type type = rowExpression.getType();
if (!Varchars.isVarcharType(type)) {
arg = call(
"CAST",
functionAndTypeManager.lookupCast(CAST, rowExpression.getType(), VARCHAR),
VARCHAR,
rowExpression);
}
else {
arg = rowExpression;
}
RowExpression sampledArg;
try {
sampledArg = call(
functionAndTypeManager,
QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, getKeyBasedSamplingFunction(session)),
DOUBLE,
ImmutableList.of(arg));
}
catch (PrestoException prestoException) {
throw new PrestoException(FUNCTION_NOT_FOUND, String.format("Sampling function: %s not cannot be resolved", getKeyBasedSamplingFunction(session)), prestoException);
}
RowExpression predicate = call(
LESS_THAN_OR_EQUAL.name(),
functionAndTypeManager.resolveOperator(OperatorType.LESS_THAN_OR_EQUAL, fromTypes(DOUBLE, DOUBLE)),
BOOLEAN,
sampledArg,
new ConstantExpression(arg.getSourceLocation(), getKeyBasedSamplingPercentage(session), DOUBLE));
FilterNode filterNode = new FilterNode(
tableScanNode.getSourceLocation(),
idAllocator.getNextId(),
tableScanNode,
predicate);
String tableName;
while (tableScanNode instanceof FilterNode || tableScanNode instanceof ProjectNode) {
tableScanNode = tableScanNode.getSources().get(0);
}
if (tableScanNode instanceof TableScanNode) {
tableName = ((TableScanNode) tableScanNode).getTable().getConnectorHandle().toString();
}
else {
tableName = "plan node: " + tableScanNode.getId();
}
sampledFields.add(String.format("%s from %s", rowExpression, tableName));
return filterNode;
}
private Optional<VariableReferenceExpression> findSuitableKey(List<VariableReferenceExpression> keys)
{
Optional<VariableReferenceExpression> variableReferenceExpression = keys.stream()
.filter(x -> TypeUtils.isIntegralType(x.getType().getTypeSignature(), functionAndTypeManager))
.findFirst();
if (!variableReferenceExpression.isPresent()) {
variableReferenceExpression = keys.stream()
.filter(x -> Varchars.isVarcharType(x.getType()))
.findFirst();
}
return variableReferenceExpression;
}
private PlanNode sampleSourceNodeWithKey(PlanNode planNode, PlanNode source, List<VariableReferenceExpression> keys)
{
PlanNode rewrittenSource = rewriteWith(this, source);
if (rewrittenSource == source) {
// Source not rewritten so we sample here.
rewrittenSource = addSamplingFilter(source, findSuitableKey(keys), functionAndTypeManager);
}
// Always return new
return replaceChildren(planNode, ImmutableList.of(rewrittenSource));
}
@Override
public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
{
PlanNode left = node.getLeft();
PlanNode right = node.getRight();
PlanNode rewrittenLeft = rewriteWith(this, left);
PlanNode rewrittenRight = rewriteWith(this, right);
// If at least one of them is unchanged means it had no join. So one side has a table scan.
// So we apply filter on both sides.
if (left == rewrittenLeft || right == rewrittenRight) {
// Sample both sides if at least one side is not already sampled
// Find the best equijoin clause so we sample both sides the same way optimally
// First see if there is a int/bigint key
Optional<EquiJoinClause> equiJoinClause = node.getCriteria().stream()
.filter(x -> TypeUtils.isIntegralType(x.getLeft().getType().getTypeSignature(), functionAndTypeManager))
.findFirst();
if (!equiJoinClause.isPresent()) {
// See if there is a varchar key
equiJoinClause = node.getCriteria().stream()
.filter(x -> Varchars.isVarcharType(x.getLeft().getType()))
.findFirst();
}
if (equiJoinClause.isPresent()) {
rewrittenLeft = addSamplingFilter(rewrittenLeft, Optional.of(equiJoinClause.get().getLeft()), functionAndTypeManager);
rewrittenRight = addSamplingFilter(rewrittenRight, Optional.of(equiJoinClause.get().getRight()), functionAndTypeManager);
}
}
// We always return new from join so others won't be applied if not needed.
return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
}
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Void> context)
{
PlanNode source = node.getSource();
PlanNode filteringSource = node.getFilteringSource();
PlanNode rewrittenSource = rewriteWith(this, source);
PlanNode rewrittenFilteringSource = rewriteWith(this, filteringSource);
if (rewrittenSource == source || rewrittenFilteringSource == filteringSource) {
rewrittenSource = addSamplingFilter(rewrittenSource, findSuitableKey(ImmutableList.of(node.getSourceJoinVariable())), functionAndTypeManager);
rewrittenFilteringSource = addSamplingFilter(rewrittenFilteringSource, findSuitableKey(ImmutableList.of(node.getFilteringSourceJoinVariable())), functionAndTypeManager);
}
return replaceChildren(node, ImmutableList.of(rewrittenSource, rewrittenFilteringSource));
}
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
return sampleSourceNodeWithKey(node, node.getSource(), node.getGroupingKeys());
}
@Override
public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
{
return sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
}
@Override
public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Void> context)
{
return sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
}
@Override
public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Void> context)
{
return sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
}
@Override
public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext<Void> context)
{
return sampleSourceNodeWithKey(node, node.getSource(), node.getDistinctVariables());
}
}
}