PlanNodeStatsEstimateMath.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;
import com.facebook.presto.spi.statistics.ConnectorHistogram;
import com.facebook.presto.spi.statistics.DisjointRangeDomainHistogram;
import java.util.Optional;
import static com.facebook.presto.spi.statistics.DisjointRangeDomainHistogram.addConjunction;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.lang.Double.max;
import static java.lang.Double.min;
import static java.util.stream.Stream.concat;
public class PlanNodeStatsEstimateMath
{
private final boolean shouldUseHistograms;
public PlanNodeStatsEstimateMath(boolean shouldUseHistograms)
{
this.shouldUseHistograms = shouldUseHistograms;
}
/**
* Subtracts subset stats from supersets stats.
* It is assumed that each NDV from subset has a matching NDV in superset.
*/
public PlanNodeStatsEstimate subtractSubsetStats(PlanNodeStatsEstimate superset, PlanNodeStatsEstimate subset)
{
if (superset.isOutputRowCountUnknown() || subset.isOutputRowCountUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
double supersetRowCount = superset.getOutputRowCount();
double subsetRowCount = subset.getOutputRowCount();
double outputRowCount = max(supersetRowCount - subsetRowCount, 0);
// everything will be filtered out after applying negation
if (outputRowCount == 0) {
return createZeroStats(superset);
}
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
result.setOutputRowCount(outputRowCount);
superset.getVariablesWithKnownStatistics().forEach(symbol -> {
VariableStatsEstimate supersetSymbolStats = superset.getVariableStatistics(symbol);
VariableStatsEstimate subsetSymbolStats = subset.getVariableStatistics(symbol);
VariableStatsEstimate.Builder newSymbolStats = VariableStatsEstimate.builder();
// for simplicity keep the average row size the same as in the input
// in most cases the average row size doesn't change after applying filters
newSymbolStats.setAverageRowSize(supersetSymbolStats.getAverageRowSize());
// nullsCount
double supersetNullsCount = supersetSymbolStats.getNullsFraction() * supersetRowCount;
double subsetNullsCount = subsetSymbolStats.getNullsFraction() * subsetRowCount;
double newNullsCount = max(supersetNullsCount - subsetNullsCount, 0);
newSymbolStats.setNullsFraction(min(newNullsCount, outputRowCount) / outputRowCount);
// distinctValuesCount
double supersetDistinctValues = supersetSymbolStats.getDistinctValuesCount();
double subsetDistinctValues = subsetSymbolStats.getDistinctValuesCount();
double newDistinctValuesCount;
if (isNaN(supersetDistinctValues) || isNaN(subsetDistinctValues)) {
newDistinctValuesCount = NaN;
}
else if (supersetDistinctValues == 0) {
newDistinctValuesCount = 0;
}
else if (subsetDistinctValues == 0) {
newDistinctValuesCount = supersetDistinctValues;
}
else {
double supersetNonNullsCount = supersetRowCount - supersetNullsCount;
double subsetNonNullsCount = subsetRowCount - subsetNullsCount;
double supersetValuesPerDistinctValue = supersetNonNullsCount / supersetDistinctValues;
double subsetValuesPerDistinctValue = subsetNonNullsCount / subsetDistinctValues;
if (supersetValuesPerDistinctValue <= subsetValuesPerDistinctValue) {
newDistinctValuesCount = max(supersetDistinctValues - subsetDistinctValues, 0);
}
else {
newDistinctValuesCount = supersetDistinctValues;
}
}
newSymbolStats.setDistinctValuesCount(newDistinctValuesCount);
// range
newSymbolStats.setLowValue(supersetSymbolStats.getLowValue());
newSymbolStats.setHighValue(supersetSymbolStats.getHighValue());
result.addVariableStatistics(symbol, newSymbolStats.build());
});
return result.build();
}
public PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap)
{
if (stats.isOutputRowCountUnknown() || cap.isOutputRowCountUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
double cappedRowCount = min(stats.getOutputRowCount(), cap.getOutputRowCount());
result.setOutputRowCount(cappedRowCount);
stats.getVariablesWithKnownStatistics().forEach(symbol -> {
VariableStatsEstimate symbolStats = stats.getVariableStatistics(symbol);
VariableStatsEstimate capSymbolStats = cap.getVariableStatistics(symbol);
VariableStatsEstimate.Builder newSymbolStats = VariableStatsEstimate.builder();
// for simplicity keep the average row size the same as in the input
// in most cases the average row size doesn't change after applying filters
newSymbolStats.setAverageRowSize(symbolStats.getAverageRowSize());
newSymbolStats.setDistinctValuesCount(min(symbolStats.getDistinctValuesCount(), capSymbolStats.getDistinctValuesCount()));
double newLow = max(symbolStats.getLowValue(), capSymbolStats.getLowValue());
double newHigh = min(symbolStats.getHighValue(), capSymbolStats.getHighValue());
newSymbolStats.setLowValue(newLow);
newSymbolStats.setHighValue(newHigh);
double numberOfNulls = stats.getOutputRowCount() * symbolStats.getNullsFraction();
double capNumberOfNulls = cap.getOutputRowCount() * capSymbolStats.getNullsFraction();
double cappedNumberOfNulls = min(numberOfNulls, capNumberOfNulls);
double cappedNullsFraction = cappedRowCount == 0 ? 1 : cappedNumberOfNulls / cappedRowCount;
newSymbolStats.setNullsFraction(cappedNullsFraction);
if (shouldUseHistograms) {
newSymbolStats.setHistogram(symbolStats.getHistogram().map(symbolHistogram -> addConjunction(symbolHistogram, new StatisticRange(newLow, newHigh, 0).toPrestoRange())));
}
result.addVariableStatistics(symbol, newSymbolStats.build());
});
return result.build();
}
private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats)
{
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
result.setOutputRowCount(0);
stats.getVariablesWithKnownStatistics().forEach(symbol -> result.addVariableStatistics(symbol, VariableStatsEstimate.zero()));
return result.build();
}
protected enum RangeAdditionStrategy
{
ADD_AND_SUM_DISTINCT(StatisticRange::addAndSumDistinctValues),
ADD_AND_MAX_DISTINCT(StatisticRange::addAndMaxDistinctValues),
ADD_AND_COLLAPSE_DISTINCT(StatisticRange::addAndCollapseDistinctValues),
INTERSECT(StatisticRange::intersect);
private final RangeAdditionFunction rangeAdditionFunction;
RangeAdditionStrategy(RangeAdditionFunction rangeAdditionFunction)
{
this.rangeAdditionFunction = rangeAdditionFunction;
}
public RangeAdditionFunction getRangeAdditionFunction()
{
return rangeAdditionFunction;
}
}
@FunctionalInterface
protected interface RangeAdditionFunction
{
StatisticRange add(StatisticRange leftRange, StatisticRange rightRange);
}
public PlanNodeStatsEstimate addStatsAndSumDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
{
return addStats(left, right, RangeAdditionStrategy.ADD_AND_SUM_DISTINCT);
}
public PlanNodeStatsEstimate addStatsAndMaxDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
{
return addStats(left, right, RangeAdditionStrategy.ADD_AND_MAX_DISTINCT);
}
public PlanNodeStatsEstimate addStatsAndCollapseDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
{
return addStats(left, right, RangeAdditionStrategy.ADD_AND_COLLAPSE_DISTINCT);
}
public PlanNodeStatsEstimate addStatsAndIntersect(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right)
{
if (left.isOutputRowCountUnknown() || right.isOutputRowCountUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
double estimatedRowCount = Math.min(left.getOutputRowCount(), right.getOutputRowCount());
double rowCount = concat(
left.getVariablesWithKnownStatistics().stream(),
right.getVariablesWithKnownStatistics().stream())
.distinct()
.map(symbol -> {
StatisticRange lstats = StatisticRange.from(left.getVariableStatistics(symbol));
StatisticRange rstats = StatisticRange.from(right.getVariableStatistics(symbol));
return Math.min(
left.getOutputRowCount() * lstats.overlapPercentWith(rstats),
right.getOutputRowCount() * rstats.overlapPercentWith(lstats));
}).reduce(Math::min).orElse(estimatedRowCount);
buildVariableStatistics(left, right, statsBuilder, rowCount, RangeAdditionStrategy.INTERSECT);
return statsBuilder.setOutputRowCount(rowCount).build();
}
private PlanNodeStatsEstimate addStats(
PlanNodeStatsEstimate left,
PlanNodeStatsEstimate right,
RangeAdditionStrategy strategy)
{
double rowCount = left.getOutputRowCount() + right.getOutputRowCount();
double totalSize = left.getTotalSize() + right.getTotalSize();
if (isNaN(rowCount) && isNaN(totalSize)) {
return PlanNodeStatsEstimate.unknown();
}
PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
buildVariableStatistics(left, right, statsBuilder, rowCount, strategy);
return statsBuilder.setOutputRowCount(rowCount)
.setTotalSize(totalSize).build();
}
private void buildVariableStatistics(
PlanNodeStatsEstimate left,
PlanNodeStatsEstimate right,
PlanNodeStatsEstimate.Builder statsBuilder,
double estimatedRowCount,
RangeAdditionStrategy strategy)
{
concat(left.getVariablesWithKnownStatistics().stream(), right.getVariablesWithKnownStatistics().stream())
.distinct()
.forEach(symbol -> {
VariableStatsEstimate symbolStats = VariableStatsEstimate.unknown();
if (estimatedRowCount <= 0) {
symbolStats = VariableStatsEstimate.zero();
}
else if (estimatedRowCount > 0) {
symbolStats = addColumnStats(
left.getVariableStatistics(symbol),
left.getOutputRowCount(),
right.getVariableStatistics(symbol),
right.getOutputRowCount(),
estimatedRowCount,
strategy);
}
statsBuilder.addVariableStatistics(symbol, symbolStats);
});
}
private VariableStatsEstimate addColumnStats(
VariableStatsEstimate leftStats,
double leftRows,
VariableStatsEstimate rightStats,
double rightRows,
double newRowCount,
RangeAdditionStrategy strategy)
{
checkArgument(newRowCount > 0, "newRowCount must be greater than zero");
StatisticRange leftRange = StatisticRange.from(leftStats);
StatisticRange rightRange = StatisticRange.from(rightStats);
StatisticRange sum = strategy.getRangeAdditionFunction().add(leftRange, rightRange);
double nullsCountRight = rightStats.getNullsFraction() * rightRows;
double nullsCountLeft = leftStats.getNullsFraction() * leftRows;
double totalSizeLeft = (leftRows - nullsCountLeft) * leftStats.getAverageRowSize();
double totalSizeRight = (rightRows - nullsCountRight) * rightStats.getAverageRowSize();
double newNullsFraction = Math.min((nullsCountLeft + nullsCountRight) / newRowCount, 1);
double newNonNullsRowCount = newRowCount * (1.0 - newNullsFraction);
// FIXME, weights to average. left and right should be equal in most cases anyway
double newAverageRowSize = newNonNullsRowCount == 0 ? 0 : ((totalSizeLeft + totalSizeRight) / newNonNullsRowCount);
VariableStatsEstimate.Builder statistics = VariableStatsEstimate.builder()
.setStatisticsRange(sum)
.setAverageRowSize(newAverageRowSize)
.setNullsFraction(newNullsFraction);
if (shouldUseHistograms) {
Optional<ConnectorHistogram> newHistogram = RangeAdditionStrategy.INTERSECT == strategy ?
leftStats.getHistogram().map(leftHistogram -> DisjointRangeDomainHistogram.addConjunction(leftHistogram, rightRange.toPrestoRange())) :
leftStats.getHistogram().map(leftHistogram -> DisjointRangeDomainHistogram.addDisjunction(leftHistogram, rightRange.toPrestoRange()));
statistics.setHistogram(newHistogram);
}
return statistics.build();
}
}