ComparisonStatsCalculator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.statistics.DisjointRangeDomainHistogram;
import com.facebook.presto.spi.statistics.Estimate;
import com.facebook.presto.spi.statistics.HistogramCalculator;
import com.facebook.presto.spi.statistics.UniformDistributionHistogram;
import com.facebook.presto.sql.tree.ComparisonExpression;
import java.util.Optional;
import java.util.OptionalDouble;
import static com.facebook.presto.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT;
import static com.facebook.presto.cost.VariableStatsEstimate.buildFrom;
import static com.facebook.presto.util.MoreMath.firstNonNaN;
import static com.facebook.presto.util.MoreMath.max;
import static com.facebook.presto.util.MoreMath.min;
import static java.lang.Double.NEGATIVE_INFINITY;
import static java.lang.Double.NaN;
import static java.lang.Double.POSITIVE_INFINITY;
import static java.lang.Double.isFinite;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;
/**
 * Estimates plan-node statistics after applying a single comparison predicate, either
 * {@code expression <op> literal} or {@code expression <op> expression}.
 * <p>
 * Range-based filter factors are computed in one of two ways, chosen once per session:
 * <ul>
 * <li>When the optimizer-histogram session property is enabled, they come from the column
 * histogram. If no histogram is present, a {@link UniformDistributionHistogram} over the
 * column's low/high values is used instead.</li>
 * <li>Otherwise, they come from the overlap percentage between the column's statistic range
 * and the filter range.</li>
 * </ul>
 */
public final class ComparisonStatsCalculator
{
    private static final Logger log = Logger.get(ComparisonStatsCalculator.class);

    // Captured from the session at construction time; selects histogram-based vs. range-overlap filter factors.
    private final boolean useHistograms;

    /**
     * @param session session whose optimizer-histogram property controls how filter factors are computed
     * @throws NullPointerException if {@code session} is null
     */
    public ComparisonStatsCalculator(Session session)
    {
        requireNonNull(session, "session is null");
        this.useHistograms = SystemSessionProperties.shouldOptimizerUseHistograms(session);
    }

    /**
     * Estimates the statistics that remain after filtering by {@code expression <operator> literal}.
     *
     * @param inputStatistics statistics of the input to the filter
     * @param expressionStatistics statistics of the compared expression
     * @param expressionVariable variable whose statistics should be updated, if the expression is a plain variable
     * @param literalValue numeric value of the literal. When empty, each operator falls back to an
     *        unbounded (infinite) bound, as described on the individual estimators.
     * @param operator comparison operator
     * @return the estimated statistics. For {@code IS_DISTINCT_FROM} this is
     *         {@link PlanNodeStatsEstimate#unknown()}.
     * @throws IllegalArgumentException for any operator not handled here
     */
    public PlanNodeStatsEstimate estimateExpressionToLiteralComparison(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue,
            ComparisonExpression.Operator operator)
    {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue);
            case LESS_THAN:
                return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, false);
            case LESS_THAN_OR_EQUAL:
                return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, true);
            case GREATER_THAN:
                return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, false);
            case GREATER_THAN_OR_EQUAL:
                return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, true);
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
            default:
                throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
        }
    }

    /**
     * Models {@code expression = literal} as a filter range holding exactly one distinct value.
     * <ul>
     * <li>With a literal, the range is the single point {@code [literal, literal]}.</li>
     * <li>Without one, the range is unbounded.</li>
     * </ul>
     */
    private PlanNodeStatsEstimate estimateExpressionEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue)
    {
        StatisticRange filterRange;
        if (literalValue.isPresent()) {
            filterRange = new StatisticRange(literalValue.getAsDouble(), false, literalValue.getAsDouble(), false, 1);
        }
        else {
            filterRange = new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, 1);
        }
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange);
    }

    /**
     * Models {@code expression <> literal} as the complement of the equality filter factor.
     * <p>
     * The output row count is {@code (1 - equalityFactor) * nonNullFraction * inputRows}. The
     * variable's statistics keep their range but lose all nulls and one distinct value
     * (floored at zero).
     */
    private PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue)
    {
        StatisticRange filterRange;
        if (literalValue.isPresent()) {
            filterRange = new StatisticRange(literalValue.getAsDouble(), false, literalValue.getAsDouble(), false, 1);
        }
        else {
            filterRange = new StatisticRange(NEGATIVE_INFINITY, true, POSITIVE_INFINITY, true, 1);
        }
        double filterFactor = 1 - calculateFilterFactor(expressionStatistics, filterRange);
        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        // NULL <> literal is never true, so null rows are removed as well.
        estimate.setOutputRowCount(filterFactor * (1 - expressionStatistics.getNullsFraction()) * inputStatistics.getOutputRowCount());
        if (expressionVariable.isPresent()) {
            VariableStatsEstimate symbolNewEstimate = buildFrom(expressionStatistics)
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(max(expressionStatistics.getDistinctValuesCount() - 1, 0))
                    .build();
            estimate = estimate.addVariableStatistics(expressionVariable.get(), symbolNewEstimate);
        }
        return estimate.build();
    }

    /**
     * Models {@code expression < literal} (or {@code <=} when {@code equals} is true) as the
     * filter range {@code (-inf, literal}.
     * <p>
     * The distinct-value count is unknown (NaN). The upper-bound flag passed to
     * {@link StatisticRange} is {@code !equals}, presumably marking the bound open for strict
     * comparisons. A missing literal becomes {@code +inf}.
     */
    private PlanNodeStatsEstimate estimateExpressionLessThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue,
            boolean equals)
    {
        StatisticRange filterRange = new StatisticRange(NEGATIVE_INFINITY, true, literalValue.orElse(POSITIVE_INFINITY), !equals, NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange);
    }

    /**
     * Models {@code expression > literal} (or {@code >=} when {@code equals} is true) as the
     * filter range {@code literal, +inf)}.
     * <p>
     * The distinct-value count is unknown (NaN). The lower-bound flag passed to
     * {@link StatisticRange} is {@code !equals}. A missing literal becomes {@code -inf}.
     */
    private PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            OptionalDouble literalValue,
            boolean equals)
    {
        StatisticRange filterRange = new StatisticRange(literalValue.orElse(NEGATIVE_INFINITY), !equals, POSITIVE_INFINITY, true, NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange);
    }

    /**
     * Applies a range filter to the input statistics.
     * <ul>
     * <li>The row count is scaled by the filter factor and by the expression's non-null fraction.</li>
     * <li>If a variable is given, its statistics are replaced by a null-free estimate. That
     * estimate keeps the original average row size and uses the intersection of the
     * expression's range with {@code filterRange}.</li>
     * <li>When histograms are enabled, any existing histogram is narrowed to that intersection.</li>
     * </ul>
     */
    private PlanNodeStatsEstimate estimateFilterRange(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate expressionStatistics,
            Optional<VariableReferenceExpression> expressionVariable,
            StatisticRange filterRange)
    {
        double filterFactor = calculateFilterFactor(expressionStatistics, filterRange);
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(filterRange);
        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - expressionStatistics.getNullsFraction()) * rowCount);
        if (expressionVariable.isPresent()) {
            VariableStatsEstimate.Builder symbolNewEstimate =
                    VariableStatsEstimate.builder()
                            .setAverageRowSize(expressionStatistics.getAverageRowSize())
                            .setStatisticsRange(intersectRange)
                            .setNullsFraction(0.0);
            if (useHistograms) {
                symbolNewEstimate.setHistogram(expressionStatistics.getHistogram().map(expressionHistogram -> DisjointRangeDomainHistogram.addConjunction(expressionHistogram, intersectRange.toPrestoRange())));
            }
            estimate = estimate.mapVariableColumnStatistics(expressionVariable.get(), oldStats -> symbolNewEstimate.build());
        }
        return estimate;
    }

    /**
     * Computes the fraction of the variable's values expected to fall inside {@code filterRange}.
     * <ul>
     * <li>When histograms are enabled, the variable's histogram is used. If there is none, a
     * uniform histogram over its low/high values is used instead.</li>
     * <li>Otherwise the range-overlap percentage is used.</li>
     * </ul>
     *
     * @return the filter factor, or {@code UNKNOWN_FILTER_COEFFICIENT} when the estimate is unknown
     */
    private double calculateFilterFactor(VariableStatsEstimate variableStatistics, StatisticRange filterRange)
    {
        StatisticRange variableRange = StatisticRange.from(variableStatistics);
        StatisticRange intersectRange = variableRange.intersect(filterRange);
        Estimate filterEstimate;
        if (useHistograms) {
            Estimate distinctEstimate = isNaN(variableStatistics.getDistinctValuesCount()) ? Estimate.unknown() : Estimate.of(variableRange.getDistinctValuesCount());
            filterEstimate = HistogramCalculator.calculateFilterFactor(intersectRange.toPrestoRange(), intersectRange.getDistinctValuesCount(),
                    variableStatistics.getHistogram().orElseGet(() -> new UniformDistributionHistogram(variableStatistics.getLowValue(), variableStatistics.getHighValue())), distinctEstimate, true);
            if (log.isDebugEnabled()) {
                // Report when the histogram estimate diverges from the range-overlap (uniformity) estimate.
                double expressionFilter = variableRange.overlapPercentWith(intersectRange);
                if (!isNaN(expressionFilter) &&
                        !filterEstimate.fuzzyEquals(Estimate.of(expressionFilter), .0001)) {
                    log.debug(String.format("histogram-calculated filter factor differs from the uniformity assumption:%n" +
                            "expression range: %s%n" +
                            "intersect range: %s%n" +
                            "overlapPercent: %s%n" +
                            "histogram: %s%n" +
                            "histogramFilterIntersect: %s%n", variableRange, intersectRange, expressionFilter, variableStatistics.getHistogram(), filterEstimate));
                }
            }
        }
        else {
            filterEstimate = Estimate.estimateFromDouble(variableRange.overlapPercentWith(intersectRange));
        }
        return filterEstimate.orElse(() -> UNKNOWN_FILTER_COEFFICIENT);
    }

    /**
     * Estimates the statistics that remain after filtering by {@code left <operator> right}.
     * Only {@code EQUAL} and {@code NOT_EQUAL} are estimated. The inequality operators and
     * {@code IS_DISTINCT_FROM} yield {@link PlanNodeStatsEstimate#unknown()}.
     *
     * @throws IllegalArgumentException for any operator not handled here
     */
    public PlanNodeStatsEstimate estimateExpressionToExpressionComparison(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable,
            ComparisonExpression.Operator operator)
    {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionVariable, rightExpressionStatistics, rightExpressionVariable);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
            default:
                throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
        }
    }

    /**
     * Estimates {@code left = right}.
     * <ul>
     * <li>Returns unknown if either side's distinct-value count is unknown (NaN).</li>
     * <li>The row count is scaled by both sides' non-null fractions and by
     * {@code 1 / max(leftNdv, rightNdv, 1)}.</li>
     * <li>Both variables (when present) receive the same null-free statistics: the
     * intersection of the two ranges, {@code min(leftNdv, rightNdv)} distinct values, and the
     * NaN-tolerant average of the two average row sizes.</li>
     * </ul>
     */
    private PlanNodeStatsEstimate estimateExpressionEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable)
    {
        if (isNaN(leftExpressionStatistics.getDistinctValuesCount()) || isNaN(rightExpressionStatistics.getDistinctValuesCount())) {
            return PlanNodeStatsEstimate.unknown();
        }
        StatisticRange leftExpressionRange = StatisticRange.from(leftExpressionStatistics);
        StatisticRange rightExpressionRange = StatisticRange.from(rightExpressionStatistics);
        StatisticRange intersect = leftExpressionRange.intersect(rightExpressionRange);
        double nullsFilterFactor = (1 - leftExpressionStatistics.getNullsFraction()) * (1 - rightExpressionStatistics.getNullsFraction());
        double leftNdv = leftExpressionRange.getDistinctValuesCount();
        double rightNdv = rightExpressionRange.getDistinctValuesCount();
        // The max with 1 guards against dividing by zero when both sides have no distinct values.
        double filterFactor = 1.0 / max(leftNdv, rightNdv, 1);
        double retainedNdv = min(leftNdv, rightNdv);
        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor);
        VariableStatsEstimate equalityStats = VariableStatsEstimate.builder()
                .setAverageRowSize(averageExcludingNaNs(leftExpressionStatistics.getAverageRowSize(), rightExpressionStatistics.getAverageRowSize()))
                .setNullsFraction(0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        leftExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats));
        rightExpressionVariable.ifPresent(variable -> estimate.addVariableStatistics(variable, equalityStats));
        return estimate.build();
    }

    /**
     * Estimates {@code left <> right} as the complement of the equality estimate, computed over
     * the null-filtered input.
     * <ul>
     * <li>Returns unknown when the equality estimate's row count is unknown.</li>
     * <li>A non-finite equality factor (for example, zero null-filtered input rows) is treated
     * as 0.</li>
     * <li>Each variable (when present) keeps its own statistics with the nulls fraction set
     * to 0.</li>
     * </ul>
     */
    private PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
            PlanNodeStatsEstimate inputStatistics,
            VariableStatsEstimate leftExpressionStatistics,
            Optional<VariableReferenceExpression> leftExpressionVariable,
            VariableStatsEstimate rightExpressionStatistics,
            Optional<VariableReferenceExpression> rightExpressionVariable)
    {
        double nullsFilterFactor = (1 - leftExpressionStatistics.getNullsFraction()) * (1 - rightExpressionStatistics.getNullsFraction());
        PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor);
        VariableStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        VariableStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        PlanNodeStatsEstimate equalityStats = estimateExpressionEqualToExpression(
                inputNullsFiltered,
                leftNullsFiltered,
                leftExpressionVariable,
                rightNullsFiltered,
                rightExpressionVariable);
        if (equalityStats.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(inputNullsFiltered);
        double equalityFilterFactor = equalityStats.getOutputRowCount() / inputNullsFiltered.getOutputRowCount();
        if (!isFinite(equalityFilterFactor)) {
            equalityFilterFactor = 0.0;
        }
        result.setOutputRowCount(inputNullsFiltered.getOutputRowCount() * (1 - equalityFilterFactor));
        leftExpressionVariable.ifPresent(symbol -> result.addVariableStatistics(symbol, leftNullsFiltered));
        rightExpressionVariable.ifPresent(symbol -> result.addVariableStatistics(symbol, rightNullsFiltered));
        return result.build();
    }

    /**
     * Averages two values while ignoring NaNs.
     *
     * @return NaN if both inputs are NaN, the non-NaN input if exactly one is NaN, otherwise
     *         their arithmetic mean
     */
    private static double averageExcludingNaNs(double first, double second)
    {
        if (isNaN(first) && isNaN(second)) {
            return NaN;
        }
        if (!isNaN(first) && !isNaN(second)) {
            return (first + second) / 2;
        }
        return firstNonNaN(first, second);
    }
}