StatsNormalizer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DateType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.SmallintType;
import com.facebook.presto.common.type.TinyintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.lang.Math.floor;
import static java.lang.Math.pow;
/**
* Makes stats consistent
*/
public class StatsNormalizer
{
public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats)
{
return normalize(stats, Optional.empty());
}
public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Collection<VariableReferenceExpression> outputVariables)
{
return normalize(stats, Optional.of(outputVariables));
}
private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional<Collection<VariableReferenceExpression>> outputVariables)
{
if (stats.isOutputRowCountUnknown() && stats.isTotalSizeUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
PlanNodeStatsEstimate.Builder normalized = PlanNodeStatsEstimate.buildFrom(stats)
.setTotalSize(stats.getOutputSizeInBytes());
Predicate<VariableReferenceExpression> variableFilter = outputVariables
.map(ImmutableSet::copyOf)
.map(set -> (Predicate<VariableReferenceExpression>) set::contains)
.orElse(variable -> true);
for (VariableReferenceExpression variable : stats.getVariablesWithKnownStatistics()) {
if (!variableFilter.test(variable)) {
normalized.removeVariableStatistics(variable);
continue;
}
VariableStatsEstimate variableStats = stats.getVariableStatistics(variable);
VariableStatsEstimate normalizedSymbolStats = stats.getOutputRowCount() == 0 ? VariableStatsEstimate.zero() : normalizeVariableStats(variable, variableStats, stats);
if (normalizedSymbolStats.isUnknown()) {
normalized.removeVariableStatistics(variable);
continue;
}
if (!Objects.equals(normalizedSymbolStats, variableStats)) {
normalized.addVariableStatistics(variable, normalizedSymbolStats);
}
}
return normalized.build();
}
/**
* Calculates consistent stats for a symbol.
*/
private VariableStatsEstimate normalizeVariableStats(VariableReferenceExpression variable, VariableStatsEstimate variableStats, PlanNodeStatsEstimate stats)
{
if (variableStats.isUnknown()) {
return VariableStatsEstimate.unknown();
}
double outputRowCount = stats.getOutputRowCount();
checkArgument(outputRowCount > 0, "outputRowCount must be greater than zero: %s", outputRowCount);
double distinctValuesCount = variableStats.getDistinctValuesCount();
double nullsFraction = variableStats.getNullsFraction();
if (!isNaN(distinctValuesCount)) {
Type type = variable.getType();
double maxDistinctValuesByLowHigh = maxDistinctValuesByLowHigh(variableStats, type);
if (distinctValuesCount > maxDistinctValuesByLowHigh) {
distinctValuesCount = maxDistinctValuesByLowHigh;
}
if (distinctValuesCount > outputRowCount) {
distinctValuesCount = outputRowCount;
}
double nonNullValues = outputRowCount * (1 - nullsFraction);
if (distinctValuesCount > nonNullValues) {
double difference = distinctValuesCount - nonNullValues;
distinctValuesCount -= difference / 2;
nonNullValues += difference / 2;
nullsFraction = 1 - nonNullValues / outputRowCount;
}
}
if (distinctValuesCount == 0.0) {
return VariableStatsEstimate.zero();
}
return VariableStatsEstimate.buildFrom(variableStats)
.setDistinctValuesCount(distinctValuesCount)
.setNullsFraction(nullsFraction)
.build();
}
private double maxDistinctValuesByLowHigh(VariableStatsEstimate variableStats, Type type)
{
if (variableStats.statisticRange().length() == 0.0) {
return 1;
}
if (!isDiscrete(type)) {
return NaN;
}
double length = variableStats.getHighValue() - variableStats.getLowValue();
if (isNaN(length)) {
return NaN;
}
if (type instanceof DecimalType) {
length *= pow(10, ((DecimalType) type).getScale());
}
return floor(length + 1);
}
private static boolean isDiscrete(Type type)
{
return type.equals(IntegerType.INTEGER) ||
type.equals(BigintType.BIGINT) ||
type.equals(SmallintType.SMALLINT) ||
type.equals(TinyintType.TINYINT) ||
type.equals(BooleanType.BOOLEAN) ||
type.equals(DateType.DATE) ||
type instanceof DecimalType;
}
}