TestSemiJoinStatsCalculator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.cost;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.Optional;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.cost.PlanNodeStatsAssertion.assertThat;
import static com.facebook.presto.cost.SemiJoinStatsCalculator.computeAntiJoin;
import static com.facebook.presto.cost.SemiJoinStatsCalculator.computeSemiJoin;
import static java.lang.Double.NEGATIVE_INFINITY;
import static java.lang.Double.NaN;
import static java.lang.Double.POSITIVE_INFINITY;
public class TestSemiJoinStatsCalculator
{
private PlanNodeStatsEstimate inputStatistics;
private VariableStatsEstimate uStats;
private VariableStatsEstimate wStats;
private VariableStatsEstimate xStats;
private VariableStatsEstimate yStats;
private VariableStatsEstimate zStats;
private VariableStatsEstimate leftOpenStats;
private VariableStatsEstimate rightOpenStats;
private VariableStatsEstimate unknownRangeStats;
private VariableStatsEstimate emptyRangeStats;
private VariableStatsEstimate fractionalNdvStats;
private VariableReferenceExpression u = new VariableReferenceExpression(Optional.empty(), "u", BIGINT);
private VariableReferenceExpression w = new VariableReferenceExpression(Optional.empty(), "w", BIGINT);
private VariableReferenceExpression x = new VariableReferenceExpression(Optional.empty(), "x", BIGINT);
private VariableReferenceExpression y = new VariableReferenceExpression(Optional.empty(), "y", BIGINT);
private VariableReferenceExpression z = new VariableReferenceExpression(Optional.empty(), "z", BIGINT);
private VariableReferenceExpression leftOpen = new VariableReferenceExpression(Optional.empty(), "leftOpen", BIGINT);
private VariableReferenceExpression rightOpen = new VariableReferenceExpression(Optional.empty(), "rightOpen", BIGINT);
private VariableReferenceExpression unknownRange = new VariableReferenceExpression(Optional.empty(), "unknownRange", BIGINT);
private VariableReferenceExpression emptyRange = new VariableReferenceExpression(Optional.empty(), "emptyRange", BIGINT);
private VariableReferenceExpression unknown = new VariableReferenceExpression(Optional.empty(), "unknown", BIGINT);
private VariableReferenceExpression fractionalNdv = new VariableReferenceExpression(Optional.empty(), "fractionalNdv", BIGINT);
@BeforeClass
public void setUp()
throws Exception
{
uStats = VariableStatsEstimate.builder()
.setAverageRowSize(8.0)
.setDistinctValuesCount(300)
.setLowValue(0)
.setHighValue(20)
.setNullsFraction(0.1)
.build();
wStats = VariableStatsEstimate.builder()
.setAverageRowSize(8.0)
.setDistinctValuesCount(30)
.setLowValue(0)
.setHighValue(20)
.setNullsFraction(0.1)
.build();
xStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(40.0)
.setLowValue(-10.0)
.setHighValue(10.0)
.setNullsFraction(0.25)
.build();
yStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(20.0)
.setLowValue(0.0)
.setHighValue(5.0)
.setNullsFraction(0.5)
.build();
zStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(5.0)
.setLowValue(-100.0)
.setHighValue(100.0)
.setNullsFraction(0.1)
.build();
leftOpenStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(50.0)
.setLowValue(NEGATIVE_INFINITY)
.setHighValue(15.0)
.setNullsFraction(0.1)
.build();
rightOpenStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(50.0)
.setLowValue(-15.0)
.setHighValue(POSITIVE_INFINITY)
.setNullsFraction(0.1)
.build();
unknownRangeStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(50.0)
.setLowValue(NEGATIVE_INFINITY)
.setHighValue(POSITIVE_INFINITY)
.setNullsFraction(0.1)
.build();
emptyRangeStats = VariableStatsEstimate.builder()
.setAverageRowSize(4.0)
.setDistinctValuesCount(0.0)
.setLowValue(NaN)
.setHighValue(NaN)
.setNullsFraction(NaN)
.build();
fractionalNdvStats = VariableStatsEstimate.builder()
.setAverageRowSize(NaN)
.setDistinctValuesCount(0.1)
.setNullsFraction(0)
.build();
inputStatistics = PlanNodeStatsEstimate.builder()
.addVariableStatistics(u, uStats)
.addVariableStatistics(w, wStats)
.addVariableStatistics(x, xStats)
.addVariableStatistics(y, yStats)
.addVariableStatistics(z, zStats)
.addVariableStatistics(leftOpen, leftOpenStats)
.addVariableStatistics(rightOpen, rightOpenStats)
.addVariableStatistics(unknownRange, unknownRangeStats)
.addVariableStatistics(emptyRange, emptyRangeStats)
.addVariableStatistics(unknown, VariableStatsEstimate.unknown())
.addVariableStatistics(fractionalNdv, fractionalNdvStats)
.setOutputRowCount(1000.0)
.build();
}
@Test
public void testSemiJoin()
{
// overlapping ranges
assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, w))
.variableStats(x, stats -> stats
.lowValue(xStats.getLowValue())
.highValue(xStats.getHighValue())
.nullsFraction(0)
.distinctValuesCount(wStats.getDistinctValuesCount()))
.variableStats(w, stats -> stats.isEqualTo(wStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * (wStats.getDistinctValuesCount() / xStats.getDistinctValuesCount()));
// overlapping ranges, nothing filtered out
assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, u))
.variableStats(x, stats -> stats
.lowValue(xStats.getLowValue())
.highValue(xStats.getHighValue())
.nullsFraction(0)
.distinctValuesCount(xStats.getDistinctValuesCount()))
.variableStats(u, stats -> stats.isEqualTo(uStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction());
// source stats are unknown
assertThat(computeSemiJoin(inputStatistics, inputStatistics, unknown, u))
.variableStats(unknown, stats -> stats
.nullsFraction(0)
.distinctValuesCountUnknown()
.unknownRange())
.variableStats(u, stats -> stats.isEqualTo(uStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCountUnknown();
// filtering stats are unknown
assertThat(computeSemiJoin(inputStatistics, inputStatistics, x, unknown))
.variableStats(x, stats -> stats
.nullsFraction(0)
.lowValue(xStats.getLowValue())
.highValue(xStats.getHighValue())
.distinctValuesCountUnknown())
.variableStatsUnknown(unknown)
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCountUnknown();
// zero distinct values
assertThat(computeSemiJoin(inputStatistics, inputStatistics, emptyRange, emptyRange))
.outputRowsCount(0);
// fractional distinct values
assertThat(computeSemiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv))
.outputRowsCount(1000)
.variableStats(fractionalNdv, stats -> stats
.nullsFraction(0)
.distinctValuesCount(0.1));
}
@Test
public void testAntiJoin()
{
// overlapping ranges
assertThat(computeAntiJoin(inputStatistics, inputStatistics, u, x))
.variableStats(u, stats -> stats
.lowValue(uStats.getLowValue())
.highValue(uStats.getHighValue())
.nullsFraction(0)
.distinctValuesCount(uStats.getDistinctValuesCount() - xStats.getDistinctValuesCount()))
.variableStats(x, stats -> stats.isEqualTo(xStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCount(inputStatistics.getOutputRowCount() * uStats.getValuesFraction() * (1 - xStats.getDistinctValuesCount() / uStats.getDistinctValuesCount()));
// overlapping ranges, everything filtered out (but we leave 0.5 due to safety coefficient)
assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, u))
.variableStats(x, stats -> stats
.lowValue(xStats.getLowValue())
.highValue(xStats.getHighValue())
.nullsFraction(0)
.distinctValuesCount(xStats.getDistinctValuesCount() * 0.5))
.variableStats(u, stats -> stats.isEqualTo(uStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * 0.5);
// source stats are unknown
assertThat(computeAntiJoin(inputStatistics, inputStatistics, unknown, u))
.variableStats(unknown, stats -> stats
.nullsFraction(0)
.distinctValuesCountUnknown()
.unknownRange())
.variableStats(u, stats -> stats.isEqualTo(uStats))
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCountUnknown();
// filtering stats are unknown
assertThat(computeAntiJoin(inputStatistics, inputStatistics, x, unknown))
.variableStats(x, stats -> stats
.nullsFraction(0)
.lowValue(xStats.getLowValue())
.highValue(xStats.getHighValue())
.distinctValuesCountUnknown())
.variableStatsUnknown(unknown)
.variableStats(z, stats -> stats.isEqualTo(zStats))
.outputRowsCountUnknown();
// zero distinct values
assertThat(computeAntiJoin(inputStatistics, inputStatistics, emptyRange, emptyRange))
.outputRowsCount(0);
// fractional distinct values
assertThat(computeAntiJoin(inputStatistics, inputStatistics, fractionalNdv, fractionalNdv))
.outputRowsCount(500)
.variableStats(fractionalNdv, stats -> stats
.nullsFraction(0)
.distinctValuesCount(0.05));
}
}