StatisticsAggregationPlanner.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.StatisticAggregations;
import com.facebook.presto.spi.plan.StatisticAggregationsDescriptor;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
import com.facebook.presto.spi.statistics.ColumnStatisticType;
import com.facebook.presto.spi.statistics.TableStatisticType;
import com.facebook.presto.spi.statistics.TableStatisticsMetadata;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT;
import static com.facebook.presto.sql.relational.SqlFunctionUtils.sqlFunctionToRowExpression;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
/**
 * Builds the aggregations that compute the table-wide and per-column statistics described by a
 * {@link TableStatisticsMetadata}.
 */
public class StatisticsAggregationPlanner
{
    private final VariableAllocator variableAllocator;
    private final FunctionAndTypeResolver functionAndTypeResolver;
    private final boolean useHistograms;
    private final Session session;
    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * @param variableAllocator allocator for every output and intermediate variable created by this planner
     * @param functionAndTypeManager used to resolve statistic functions and to translate SQL-expression statistics
     * @param session the current session; it decides whether histogram statistics are planned and whether
     *         native-execution type matching rules apply
     */
    public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager, Session session)
    {
        this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
        this.session = requireNonNull(session, "session is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
        this.useHistograms = shouldOptimizerUseHistograms(session);
    }

    /**
     * Creates one aggregation per requested statistic.
     * <p>
     * The only supported table-wide statistic is {@link TableStatisticType#ROW_COUNT}, computed with {@code count()}.
     * Column statistics of type {@link ColumnStatisticType#HISTOGRAM} are skipped unless the session enables
     * histogram use in the optimizer.
     *
     * @param statisticsMetadata the statistics to compute and the columns to group them by
     * @param columnToVariableMap maps each referenced column name (grouping and statistic columns) to its variable
     * @return the aggregations, a descriptor mapping each statistic to its output variable, and the extra
     *         variables (with their defining expressions) that the aggregations use as arguments
     * @throws PrestoException with {@code NOT_SUPPORTED} for any table statistic other than {@code ROW_COUNT}
     * @throws NullPointerException if a grouping column has no entry in {@code columnToVariableMap}
     */
    public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map<String, VariableReferenceExpression> columnToVariableMap)
    {
        StatisticAggregationsDescriptor.Builder<VariableReferenceExpression> descriptor = StatisticAggregationsDescriptor.builder();

        List<String> groupingColumns = statisticsMetadata.getGroupingColumns();
        List<VariableReferenceExpression> groupingVariables = groupingColumns.stream()
                // name the offending column instead of failing with an anonymous NPE inside the collector
                .map(column -> requireNonNull(columnToVariableMap.get(column), () -> "no variable for grouping column: " + column))
                .collect(toImmutableList());
        for (int i = 0; i < groupingVariables.size(); i++) {
            descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i));
        }

        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> additionalVariables = ImmutableMap.builder();
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();

        StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeResolver);
        for (TableStatisticType type : statisticsMetadata.getTableStatistics()) {
            if (type != ROW_COUNT) {
                throw new PrestoException(NOT_SUPPORTED, "Table-wide statistic type not supported: " + type);
            }
            AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                    new CallExpression(
                            "count",
                            functionResolution.countFunction(),
                            BIGINT,
                            ImmutableList.of()),
                    Optional.empty(),
                    Optional.empty(),
                    false,
                    Optional.empty());
            VariableReferenceExpression variable = variableAllocator.newVariable("rowCount", BIGINT);
            aggregations.put(variable, aggregation);
            descriptor.addTableStatistic(ROW_COUNT, variable);
        }

        for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) {
            ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType();
            if (!useHistograms && statisticType == ColumnStatisticType.HISTOGRAM) {
                // the session does not enable optimizer histograms, so do not compute them
                continue;
            }
            String columnName = columnStatisticMetadata.getColumnName();
            VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName);
            verify(inputVariable != null, "inputVariable is null for column: %s", columnName);
            ColumnStatisticsAggregation aggregation = createColumnAggregation(
                    columnStatisticMetadata,
                    inputVariable,
                    ImmutableMap.of(columnName, inputVariable.getName()));
            additionalVariables.putAll(aggregation.getInputProjections());
            VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType());
            aggregations.put(variable, aggregation.getAggregation());
            descriptor.addColumnStatistic(columnStatisticMetadata, variable);
        }

        StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables);
        return new TableStatisticAggregation(aggregation, descriptor.build(), additionalVariables.build());
    }

    /**
     * Translates a SQL-expression column statistic into an aggregation over {@code input}.
     * <p>
     * The expression must translate to an aggregate function call. Every call argument that is not already a
     * variable reference is replaced by a newly allocated variable, and that variable's defining expression is
     * returned in the input projections.
     */
    private ColumnStatisticsAggregation createColumnAggregationFromSqlFunction(
            String sqlFunction,
            VariableReferenceExpression input,
            Map<String, String> columnNameToInputVariableNameMap)
    {
        RowExpression expression = sqlFunctionToRowExpression(
                sqlFunction,
                ImmutableSet.of(input),
                functionAndTypeManager,
                session,
                columnNameToInputVariableNameMap);
        verify(expression instanceof CallExpression, "column statistic SQL expressions must represent a function call");
        CallExpression call = (CallExpression) expression;
        FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(call.getFunctionHandle());
        verify(functionMeta.getFunctionKind() == AGGREGATE, "column statistic function must be aggregates");

        // Aggregation input arguments are required to be variable reference expressions.
        // For each one that isn't, allocate a new variable to reference and remember its projection.
        // The lambda below records into inputProjections; the stream is sequential, so argument order is kept.
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> inputProjections = ImmutableMap.builder();
        List<RowExpression> callVariableArguments = call.getArguments()
                .stream()
                .map(argument -> {
                    if (argument instanceof VariableReferenceExpression) {
                        return argument;
                    }
                    VariableReferenceExpression newArgument = variableAllocator.newVariable(argument);
                    inputProjections.put(newArgument, argument);
                    return newArgument;
                })
                .collect(Collectors.toList());

        CallExpression callWithVariables = new CallExpression(
                call.getSourceLocation(),
                call.getDisplayName(),
                call.getFunctionHandle(),
                call.getType(),
                callVariableArguments);
        return new ColumnStatisticsAggregation(
                new AggregationNode.Aggregation(
                        callWithVariables,
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()),
                functionAndTypeResolver.getType(functionMeta.getReturnType()),
                inputProjections.build());
    }

    /**
     * Resolves the statistic's function by name for the type of {@code input} and wraps a single-argument call
     * to it in an aggregation. This path never needs input projections.
     */
    private ColumnStatisticsAggregation createColumnAggregationFromFunctionName(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
    {
        FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(
                columnStatisticMetadata.getFunction(),
                TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType())));
        FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle);
        Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes()));
        Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType());

        // todo: fix this hack
        // Native clusters do not support parameterized varchar, so when the input is a varchar
        // it is enough for the type signature bases to match.
        boolean isVarcharType = input.getType().getTypeSignature().getBase().equals(VARCHAR);
        boolean isTypeSignatureBaseMatching = inputType.getTypeSignature().getBase().equals(input.getType().getTypeSignature().getBase());
        verify(
                inputType.equals(input.getType())
                        || input.getType().equals(UNKNOWN)
                        || (isNativeExecutionEnabled(session) && isVarcharType && isTypeSignatureBaseMatching),
                "resolved function input type does not match the input type: %s != %s", inputType, input.getType());

        return new ColumnStatisticsAggregation(
                new AggregationNode.Aggregation(
                        new CallExpression(
                                input.getSourceLocation(),
                                columnStatisticMetadata.getFunction(),
                                functionHandle,
                                outputType,
                                ImmutableList.of(input)),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.empty()),
                outputType,
                ImmutableMap.of());
    }

    /**
     * Dispatches to the SQL-expression or the function-name strategy depending on how the statistic is declared.
     */
    private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input,
            Map<String, String> columnNameToInputVariableNameMap)
    {
        if (columnStatisticMetadata.isSqlExpression()) {
            return createColumnAggregationFromSqlFunction(columnStatisticMetadata.getFunction(), input, columnNameToInputVariableNameMap);
        }
        return createColumnAggregationFromFunctionName(columnStatisticMetadata, input);
    }

    /**
     * Result of {@link #createStatisticsAggregation}: the statistic aggregations, a descriptor locating each
     * statistic's output variable, and the extra variables (with defining expressions) used as aggregation inputs.
     */
    public static class TableStatisticAggregation
    {
        private final StatisticAggregations aggregations;
        private final StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor;
        private final Map<VariableReferenceExpression, RowExpression> additionalVariables;

        private TableStatisticAggregation(
                StatisticAggregations aggregations,
                StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor,
                Map<VariableReferenceExpression, RowExpression> additionalVariables)
        {
            this.aggregations = requireNonNull(aggregations, "aggregations is null");
            this.descriptor = requireNonNull(descriptor, "descriptor is null");
            this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null");
        }

        public StatisticAggregations getAggregations()
        {
            return aggregations;
        }

        public StatisticAggregationsDescriptor<VariableReferenceExpression> getDescriptor()
        {
            return descriptor;
        }

        public Map<VariableReferenceExpression, RowExpression> getAdditionalVariables()
        {
            return additionalVariables;
        }
    }

    /**
     * A single column statistic's aggregation, its output type, and the projections defining any
     * non-column variables it takes as arguments.
     */
    public static class ColumnStatisticsAggregation
    {
        private final AggregationNode.Aggregation aggregation;
        private final Type outputType;
        private final Map<VariableReferenceExpression, RowExpression> inputProjections;

        private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType, Map<VariableReferenceExpression, RowExpression> inputProjections)
        {
            this.aggregation = requireNonNull(aggregation, "aggregation is null");
            this.outputType = requireNonNull(outputType, "outputType is null");
            this.inputProjections = requireNonNull(inputProjections, "inputProjections is null");
        }

        public AggregationNode.Aggregation getAggregation()
        {
            return aggregation;
        }

        public Type getOutputType()
        {
            return outputType;
        }

        public Map<VariableReferenceExpression, RowExpression> getInputProjections()
        {
            return inputProjections;
        }
    }
}