TableStatisticsRecorder.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.tpch.statistics;
import io.airlift.tpch.TpchColumn;
import io.airlift.tpch.TpchColumnType;
import io.airlift.tpch.TpchEntity;
import io.airlift.tpch.TpchTable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.lang.String.format;
class TableStatisticsRecorder
{
<E extends TpchEntity> TableStatisticsData recordStatistics(TpchTable<E> tpchTable, Predicate<E> constraint, double scaleFactor)
{
int parallelPartsToProcess = Runtime.getRuntime().availableProcessors();
if (tpchTable.equals(TpchTable.NATION) || tpchTable.equals(TpchTable.REGION)) {
//These tables have too few rows to benefit from parallel processing
parallelPartsToProcess = 1;
}
ArrayList<CompletableFuture<TablePartStatistics>> statsRecorders = new ArrayList<>();
for (int i = 0; i < parallelPartsToProcess; i++) {
//Generate a part of data
Iterable<E> rows = tpchTable.createGenerator(scaleFactor, i + 1, parallelPartsToProcess);
//Record its statistics on a separate thread
statsRecorders.add(CompletableFuture.supplyAsync(() -> recordStatistics(rows, tpchTable.getColumns(), constraint)));
}
try {
//Wait for all parts to finish processing
CompletableFuture.allOf(statsRecorders.toArray(new CompletableFuture[0])).get();
}
catch (ExecutionException | InterruptedException e) {
throw new RuntimeException(e);
}
Optional<TablePartStatistics> combinedStatistics = statsRecorders.stream().map(CompletableFuture::join).reduce((x, y) -> {
x.setRowCount(x.getRowCount() + y.getRowCount());
x.setRawColStats(Stream.of(x.getRawColStats(), y.getRawColStats())
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, ColumnStatisticsRecorder::mergeWith)));
return x;
});
checkState(combinedStatistics.isPresent(), "combinedStatistics empty, no stats were produced");
final Map<String, ColumnStatisticsData> combinedTableStatsRecording = combinedStatistics.get()
.getRawColStats().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, y -> y.getValue().getRecording()));
return new TableStatisticsData(combinedStatistics.get().getRowCount(), combinedTableStatsRecording);
}
private <E extends TpchEntity> TablePartStatistics recordStatistics(Iterable<E> rows, List<TpchColumn<E>> columns, Predicate<E> constraint)
{
Map<String, ColumnStatisticsRecorder> statisticsRecorders = createStatisticsRecorders(columns);
long rowCount = 0;
for (E row : rows) {
if (constraint.test(row)) {
rowCount++;
for (TpchColumn<E> column : columns) {
Comparable<?> value = getTpchValue(row, column);
statisticsRecorders.get(column.getColumnName()).record(value);
}
}
}
return new TablePartStatistics(rowCount, statisticsRecorders);
}
private <E extends TpchEntity> Map<String, ColumnStatisticsRecorder> createStatisticsRecorders(List<TpchColumn<E>> columns)
{
return columns.stream()
.collect(toImmutableMap(TpchColumn::getColumnName, (column) -> new ColumnStatisticsRecorder(column.getType())));
}
private <E extends TpchEntity> Comparable<?> getTpchValue(E row, TpchColumn<E> column)
{
TpchColumnType.Base baseType = column.getType().getBase();
switch (baseType) {
case IDENTIFIER:
return column.getIdentifier(row);
case INTEGER:
return column.getInteger(row);
case DATE:
return column.getDate(row);
case DOUBLE:
return column.getDouble(row);
case VARCHAR:
return column.getString(row);
}
throw new UnsupportedOperationException(format("Unsupported TPCH base type [%s]", baseType));
}
private static class TablePartStatistics
{
long rowCount;
Map<String, ColumnStatisticsRecorder> rawColStats;
public TablePartStatistics(long rowCount, Map<String, ColumnStatisticsRecorder> rawColStats)
{
this.rowCount = rowCount;
this.rawColStats = rawColStats;
}
public Map<String, ColumnStatisticsRecorder> getRawColStats()
{
return rawColStats;
}
public void setRawColStats(Map<String, ColumnStatisticsRecorder> rawColStats)
{
this.rawColStats = rawColStats;
}
public long getRowCount()
{
return rowCount;
}
public void setRowCount(long rowCount)
{
this.rowCount = rowCount;
}
}
}