AbstractStatisticsBuilderTest.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc.metadata.statistics;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static com.facebook.presto.orc.metadata.statistics.ColumnStatistics.mergeColumnStatistics;
import static java.lang.Math.min;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
public abstract class AbstractStatisticsBuilderTest<B extends StatisticsBuilder, T>
{
/**
 * Names the kind of type-specific statistics a builder under test is expected to produce.
 * {@code assertColumnStatistics} compares the configured constant against each kind and
 * expects every non-matching kind to be absent.
 */
public enum StatisticsType
{
    // matches none of the type-specific checks, so all of them are expected to be absent
    NONE,
    // checked against getBooleanStatistics()
    BOOLEAN,
    // checked against getIntegerStatistics()
    INTEGER,
    // checked against getDoubleStatistics()
    DOUBLE,
    // checked against getStringStatistics()
    STRING,
    // checked against getDateStatistics()
    DATE,
    // checked against getDecimalStatistics()
    DECIMAL
}
// Kind of type-specific statistics the builder under test should populate; read by assertColumnStatistics.
private final StatisticsType statisticsType;
// Source of the builder under test; get() is invoked at the start of each test and assertion helper.
private final Supplier<B> statisticsBuilderSupplier;
// Applied as adder.accept(builder, value) to feed one value of type T into a builder.
private final BiConsumer<B, T> adder;
/**
 * Stores the collaborators that the shared assertions in this base class rely on.
 *
 * @param statisticsType kind of type-specific statistics the builder is expected to populate
 * @param statisticsBuilderSupplier supplies the builder under test
 * @param adder feeds a single value into a builder
 */
public AbstractStatisticsBuilderTest(StatisticsType statisticsType, Supplier<B> statisticsBuilderSupplier, BiConsumer<B, T> adder)
{
    this.adder = adder;
    this.statisticsBuilderSupplier = statisticsBuilderSupplier;
    this.statisticsType = statisticsType;
}
/**
 * A builder that never received a value must report zero values and no type-specific
 * statistics, both directly and after its output has been merged with itself repeatedly.
 */
@Test
public void testNoValue()
{
    B builder = statisticsBuilderSupplier.get();
    AggregateColumnStatistics aggregate = new AggregateColumnStatistics();

    // three rounds: check the builder output, accumulate it, then check the merged result
    for (int round = 0; round < 3; round++) {
        assertNoColumnStatistics(builder.buildColumnStatistics(), 0);
        aggregate.add(builder.buildColumnStatistics());
        assertNoColumnStatistics(aggregate.getMergedColumnStatistics(Optional.empty()), 0);
    }
}
/**
 * Checks the total value size reported for {@code values}, first when every value goes
 * through a single builder and then when the values are split across two builders whose
 * statistics are merged.
 *
 * @param expectedTotalValueBytes size that {@code getTotalValueSizeInBytes()} must report
 * @param values values to feed into the builders
 */
public void assertTotalValueBytes(long expectedTotalValueBytes, List<T> values)
{
    // every value through one builder
    B singleBuilder = statisticsBuilderSupplier.get();
    values.forEach(value -> adder.accept(singleBuilder, value));
    assertEquals(singleBuilder.buildColumnStatistics().getTotalValueSizeInBytes(), expectedTotalValueBytes);

    // split at the midpoint (the first part is the shorter one for odd sizes) and merge
    int midpoint = values.size() / 2;

    B firstBuilder = statisticsBuilderSupplier.get();
    for (T value : values.subList(0, midpoint)) {
        adder.accept(firstBuilder, value);
    }
    ColumnStatistics firstStats = firstBuilder.buildColumnStatistics();

    B secondBuilder = statisticsBuilderSupplier.get();
    for (T value : values.subList(midpoint, values.size())) {
        adder.accept(secondBuilder, value);
    }
    ColumnStatistics secondStats = secondBuilder.buildColumnStatistics();

    ColumnStatistics merged = mergeColumnStatistics(ImmutableList.of(firstStats, secondStats));
    assertEquals(merged.getTotalValueSizeInBytes(), expectedTotalValueBytes);
}
/**
 * Runs {@link #assertValues} for three inputs: only the minimum, only the maximum, and
 * the minimum together with the maximum.
 *
 * @param expectedMin the smaller value
 * @param expectedMax the larger value
 */
public void assertMinMaxValues(T expectedMin, T expectedMax)
{
    // a single value is expected as both min and max
    List<T> onlyMin = ImmutableList.of(expectedMin);
    assertValues(expectedMin, expectedMin, onlyMin);

    List<T> onlyMax = ImmutableList.of(expectedMax);
    assertValues(expectedMax, expectedMax, onlyMax);

    // both values together
    List<T> minAndMax = ImmutableList.of(expectedMin, expectedMax);
    assertValues(expectedMin, expectedMax, minAndMax);
}
/**
 * Verifies the expected min/max for {@code values} when they are added in the given
 * order, in reverse order, and in a shuffled order produced with a fixed seed.
 *
 * @param expectedMin minimum the statistics must report
 * @param expectedMax maximum the statistics must report
 * @param values values to add; the list itself is not modified
 */
public void assertValues(T expectedMin, T expectedMax, List<T> values)
{
    // as given
    assertValuesInternal(expectedMin, expectedMax, values);

    // back to front
    List<T> reversed = ImmutableList.copyOf(values).reverse();
    assertValuesInternal(expectedMin, expectedMax, reversed);

    // shuffled copy; the seed is constant so every run sees the same permutation
    List<T> shuffled = new ArrayList<>(values);
    Random seededRandom = new Random(42);
    Collections.shuffle(shuffled, seededRandom);
    assertValuesInternal(expectedMin, expectedMax, shuffled);
}
/**
 * Feeds {@code values} into one builder four times over. The builder output is
 * accumulated once before any value is added and again after every single added value;
 * after each full pass the builder output and the accumulated statistics are verified.
 */
private void assertValuesInternal(T expectedMin, T expectedMax, List<T> values)
{
    B builder = statisticsBuilderSupplier.get();
    AggregateColumnStatistics aggregate = new AggregateColumnStatistics();

    // nothing added yet: zero values and no min/max expected
    aggregate.add(builder.buildColumnStatistics());
    assertColumnStatistics(builder.buildColumnStatistics(), 0, null, null, aggregate);

    for (int pass = 1; pass <= 4; pass++) {
        values.forEach(value -> {
            adder.accept(builder, value);
            aggregate.add(builder.buildColumnStatistics());
        });
        // the same builder is reused, so its value count grows by values.size() each pass
        assertColumnStatistics(builder.buildColumnStatistics(), values.size() * pass, expectedMin, expectedMax, aggregate);
    }
}
/**
 * Asserts that {@code columnStatistics} reports the expected number of values and that
 * every type-specific statistics getter checked here (boolean, integer, double, string,
 * date, decimal) as well as the bloom filter getter returns {@code null}.
 *
 * @param columnStatistics statistics to inspect
 * @param expectedNumberOfValues value count the statistics must report
 */
public static void assertNoColumnStatistics(ColumnStatistics columnStatistics, int expectedNumberOfValues)
{
assertEquals(columnStatistics.getNumberOfValues(), expectedNumberOfValues);
// none of the type-specific statistics may be present
assertNull(columnStatistics.getBooleanStatistics());
assertNull(columnStatistics.getIntegerStatistics());
assertNull(columnStatistics.getDoubleStatistics());
assertNull(columnStatistics.getStringStatistics());
assertNull(columnStatistics.getDateStatistics());
assertNull(columnStatistics.getDecimalStatistics());
assertNull(columnStatistics.getBloomFilter());
}
/**
 * Verifies {@code columnStatistics} directly, then verifies that merging everything
 * accumulated in {@code aggregateColumnStatistics} gives the same min/max with the
 * aggregate's total count. Merging is checked in insertion order, in ten random orders,
 * and both as a single merge and as a pairwise merge. Finally, an entry without
 * type-specific statistics is inserted at the front, the end and the middle, and the
 * merged result is expected to carry no type-specific statistics.
 */
private void assertColumnStatistics(
        ColumnStatistics columnStatistics,
        int expectedNumberOfValues,
        T expectedMin,
        T expectedMax,
        AggregateColumnStatistics aggregateColumnStatistics)
{
    assertColumnStatistics(columnStatistics, expectedNumberOfValues, expectedMin, expectedMax);

    int aggregateCount = aggregateColumnStatistics.getTotalCount();

    // insertion order
    ColumnStatistics mergedInOrder = aggregateColumnStatistics.getMergedColumnStatistics(Optional.empty());
    assertColumnStatistics(mergedInOrder, aggregateCount, expectedMin, expectedMax);
    ColumnStatistics mergedPairwiseInOrder = aggregateColumnStatistics.getMergedColumnStatisticsPairwise(Optional.empty());
    assertColumnStatistics(mergedPairwiseInOrder, aggregateCount, expectedMin, expectedMax);

    // random orders
    for (int attempt = 0; attempt < 10; attempt++) {
        ColumnStatistics mergedShuffled = aggregateColumnStatistics.getMergedColumnStatistics(Optional.of(ThreadLocalRandom.current()));
        assertColumnStatistics(mergedShuffled, aggregateCount, expectedMin, expectedMax);
        ColumnStatistics mergedPairwiseShuffled = aggregateColumnStatistics.getMergedColumnStatisticsPairwise(Optional.of(ThreadLocalRandom.current()));
        assertColumnStatistics(mergedPairwiseShuffled, aggregateCount, expectedMin, expectedMax);
    }

    // an entry of 10 values without type-specific statistics, placed at the front, end and middle
    List<ColumnStatistics> accumulated = aggregateColumnStatistics.getStatisticsList();
    int[] insertionPoints = {0, accumulated.size(), accumulated.size() / 2};
    for (int insertionPoint : insertionPoints) {
        List<ColumnStatistics> withEmptyEntry = insertEmptyColumnStatisticsAt(accumulated, insertionPoint, 10);
        assertNoColumnStatistics(mergeColumnStatistics(withEmptyEntry), aggregateCount + 10);
    }
}
/**
 * Returns a mutable copy of {@code statisticsList} with one extra entry inserted at
 * {@code index}. The extra entry is constructed from {@code numberOfValues} and
 * {@code null} for the three remaining constructor arguments. The input list is left
 * untouched.
 *
 * @param statisticsList statistics to copy
 * @param index position of the inserted entry in the returned list
 * @param numberOfValues value count passed to the inserted entry
 * @return the copy including the inserted entry
 */
static List<ColumnStatistics> insertEmptyColumnStatisticsAt(List<ColumnStatistics> statisticsList, int index, long numberOfValues)
{
    List<ColumnStatistics> withEmptyEntry = new ArrayList<>(statisticsList);
    ColumnStatistics emptyEntry = new ColumnStatistics(numberOfValues, null, null, null);
    withEmptyEntry.add(index, emptyEntry);
    return withEmptyEntry;
}
/**
 * Asserts the value count of {@code columnStatistics} and the presence or absence of each
 * kind of type-specific statistics. Only the kind matching this test's
 * {@code statisticsType} may be present, and only when {@code expectedNumberOfValues} is
 * positive. Every other kind, and the bloom filter, must be {@code null}.
 *
 * @param columnStatistics statistics to inspect
 * @param expectedNumberOfValues value count the statistics must report
 * @param expectedMin minimum expected in the matching range statistics
 * @param expectedMax maximum expected in the matching range statistics
 */
public void assertColumnStatistics(ColumnStatistics columnStatistics, int expectedNumberOfValues, T expectedMin, T expectedMax)
{
    assertEquals(columnStatistics.getNumberOfValues(), expectedNumberOfValues);

    // type-specific statistics are only expected once at least one value was recorded
    boolean hasValues = expectedNumberOfValues > 0;

    if (hasValues && statisticsType == StatisticsType.BOOLEAN) {
        // for boolean statistics only presence is checked, not min/max
        assertNotNull(columnStatistics.getBooleanStatistics());
    }
    else {
        assertNull(columnStatistics.getBooleanStatistics());
    }

    if (hasValues && statisticsType == StatisticsType.INTEGER) {
        assertRangeStatistics(columnStatistics.getIntegerStatistics(), expectedMin, expectedMax);
    }
    else {
        assertNull(columnStatistics.getIntegerStatistics());
    }

    if (hasValues && statisticsType == StatisticsType.DOUBLE) {
        assertRangeStatistics(columnStatistics.getDoubleStatistics(), expectedMin, expectedMax);
    }
    else {
        assertNull(columnStatistics.getDoubleStatistics());
    }

    if (hasValues && statisticsType == StatisticsType.STRING) {
        assertRangeStatistics(columnStatistics.getStringStatistics(), expectedMin, expectedMax);
    }
    else {
        assertNull(columnStatistics.getStringStatistics());
    }

    if (hasValues && statisticsType == StatisticsType.DATE) {
        assertRangeStatistics(columnStatistics.getDateStatistics(), expectedMin, expectedMax);
    }
    else {
        assertNull(columnStatistics.getDateStatistics());
    }

    if (hasValues && statisticsType == StatisticsType.DECIMAL) {
        assertRangeStatistics(columnStatistics.getDecimalStatistics(), expectedMin, expectedMax);
    }
    else {
        assertNull(columnStatistics.getDecimalStatistics());
    }

    // a bloom filter is never expected from these builders' column statistics
    assertNull(columnStatistics.getBloomFilter());
}
/**
 * Asserts that {@code rangeStatistics} is present and that its min and max equal the
 * expected values.
 *
 * @param rangeStatistics statistics to inspect; must not be {@code null}
 * @param expectedMin value {@code getMin()} must equal
 * @param expectedMax value {@code getMax()} must equal
 */
void assertRangeStatistics(RangeStatistics<?> rangeStatistics, T expectedMin, T expectedMax)
{
    assertNotNull(rangeStatistics);

    Object actualMin = rangeStatistics.getMin();
    assertEquals(actualMin, expectedMin);

    Object actualMax = rangeStatistics.getMax();
    assertEquals(actualMax, expectedMax);
}
/**
 * Collects {@link ColumnStatistics} instances together with the running sum of their value
 * counts, and merges the collected instances in different orders and groupings so tests
 * can check that the merged result does not depend on how the merge is performed.
 */
public static class AggregateColumnStatistics
{
    private int totalCount;
    private final ImmutableList.Builder<ColumnStatistics> statisticsList = ImmutableList.builder();

    /**
     * Records {@code columnStatistics} and adds its value count to the running total.
     *
     * @param columnStatistics statistics to record
     * @throws ArithmeticException if the running total no longer fits in an {@code int}
     */
    public void add(ColumnStatistics columnStatistics)
    {
        // A plain "totalCount += ..." would silently narrow a wider count into the int field
        // and wrap on overflow; the exact variants turn that into a visible failure.
        totalCount = Math.toIntExact(Math.addExact(totalCount, columnStatistics.getNumberOfValues()));
        statisticsList.add(columnStatistics);
    }

    /**
     * @return sum of the value counts of all statistics recorded so far
     */
    public int getTotalCount()
    {
        return totalCount;
    }

    /**
     * @return the statistics recorded so far, in insertion order
     */
    public List<ColumnStatistics> getStatisticsList()
    {
        return statisticsList.build();
    }

    /**
     * Merges all recorded statistics with a single {@code mergeColumnStatistics} call.
     *
     * @param random when present, the statistics are shuffled with it before merging;
     *         when empty, insertion order is used
     * @return the merged statistics
     */
    public ColumnStatistics getMergedColumnStatistics(Optional<Random> random)
    {
        return mergeColumnStatistics(snapshotInMergeOrder(random));
    }

    /**
     * Merges all recorded statistics two at a time, repeating on the intermediate results
     * until a single instance remains.
     *
     * @param random when present, the statistics are shuffled with it before merging;
     *         when empty, insertion order is used
     * @return the merged statistics
     */
    public ColumnStatistics getMergedColumnStatisticsPairwise(Optional<Random> random)
    {
        return getMergedColumnStatisticsPairwise(snapshotInMergeOrder(random));
    }

    // Copies the recorded statistics, optionally shuffled, into an immutable list.
    private List<ColumnStatistics> snapshotInMergeOrder(Optional<Random> random)
    {
        List<ColumnStatistics> statistics = new ArrayList<>(statisticsList.build());
        random.ifPresent(rand -> Collections.shuffle(statistics, rand));
        return ImmutableList.copyOf(statistics);
    }

    // Reduces the list round by round, merging neighbours in pairs. With an odd number of
    // entries the last one is passed to mergeColumnStatistics on its own. The list must
    // contain at least one entry, since the final get(0) fails on an empty list.
    private static ColumnStatistics getMergedColumnStatisticsPairwise(List<ColumnStatistics> statistics)
    {
        while (statistics.size() > 1) {
            ImmutableList.Builder<ColumnStatistics> mergedStatistics = ImmutableList.builder();
            for (int i = 0; i < statistics.size(); i += 2) {
                mergedStatistics.add(mergeColumnStatistics(statistics.subList(i, min(i + 2, statistics.size()))));
            }
            statistics = mergedStatistics.build();
        }
        return statistics.get(0);
    }
}
}