AbstractOrcRecordReader.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.FixedWidthType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.orc.metadata.MetadataReader;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.PostScript;
import com.facebook.presto.orc.metadata.StripeInformation;
import com.facebook.presto.orc.metadata.statistics.ColumnStatistics;
import com.facebook.presto.orc.metadata.statistics.StripeStatistics;
import com.facebook.presto.orc.reader.StreamReader;
import com.facebook.presto.orc.stream.InputStreamSources;
import com.facebook.presto.orc.stream.SharedBuffer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.io.Closer;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.joda.time.DateTimeZone;
import org.openjdk.jol.info.ClassLayout;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Collectors;
import static com.facebook.presto.orc.AbstractOrcRecordReader.LinearProbeRangeFinder.createTinyStripesRangeFinder;
import static com.facebook.presto.orc.DwrfEncryptionInfo.createDwrfEncryptionInfo;
import static com.facebook.presto.orc.OrcDataSourceUtils.mergeAdjacentDiskRanges;
import static com.facebook.presto.orc.OrcReader.BATCH_SIZE_GROWTH_FACTOR;
import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE;
import static com.facebook.presto.orc.OrcWriteValidation.WriteChecksumBuilder.createWriteChecksumBuilder;
import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.STRUCT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Comparator.comparingLong;
import static java.util.Objects.requireNonNull;
abstract class AbstractOrcRecordReader<T extends StreamReader>
implements Closeable
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(AbstractOrcRecordReader.class).instanceSize();

    protected final OrcAggregatedMemoryContext systemMemoryUsage;
private final OrcDataSource orcDataSource;
private final T[] streamReaders;
private final long totalRowCount;
private final long splitLength;
private final Set<Integer> presentColumns;
private final long maxBlockBytes;
private final Optional<EncryptionLibrary> encryptionLibrary;
private final Map<Integer, Integer> dwrfEncryptionGroupMap;
private final Map<Integer, Slice> intermediateKeyMetadata;
private long currentPosition;
private long currentStripePosition;
private int currentBatchSize;
private int nextBatchSize;
private int maxBatchSize = MAX_BATCH_SIZE;
private final List<StripeInformation> stripes;
private final StripeReader stripeReader;
private int currentStripe = -1;
private OrcAggregatedMemoryContext currentStripeSystemMemoryContext;
private Optional<DwrfEncryptionInfo> dwrfEncryptionInfo = Optional.empty();
private final long fileRowCount;
private final List<Long> stripeFilePositions;
private long filePosition;
private Iterator<RowGroup> rowGroups = ImmutableList.<RowGroup>of().iterator();
private int currentRowGroup = -1;
private int currentGroupRowCount;
private int nextRowInGroup;
private final long[] maxBytesPerCell;
private long maxCombinedBytesPerRow;
private final Map<String, Slice> userMetadata;
private final Optional<OrcWriteValidation> writeValidation;
private final Optional<OrcWriteValidation.WriteChecksumBuilder> writeChecksumBuilder;
private final Optional<OrcWriteValidation.StatisticsValidation> rowGroupStatisticsValidation;
private final Optional<OrcWriteValidation.StatisticsValidation> stripeStatisticsValidation;
private final Optional<OrcWriteValidation.StatisticsValidation> fileStatisticsValidation;
private final Optional<OrcFileIntrospector> fileIntrospector;
private final RuntimeStats runtimeStats;
public AbstractOrcRecordReader(
Map<Integer, Type> includedColumns,
Map<Integer, List<Subfield>> requiredSubfields,
T[] streamReaders,
OrcPredicate predicate,
long numberOfRows,
List<StripeInformation> fileStripes,
List<ColumnStatistics> fileStats,
List<StripeStatistics> stripeStats,
OrcDataSource orcDataSource,
long splitOffset,
long splitLength,
List<OrcType> types,
Optional<OrcDecompressor> decompressor,
Optional<EncryptionLibrary> encryptionLibrary,
Map<Integer, Integer> dwrfEncryptionGroupMap,
Map<Integer, Slice> columnToIntermediateKeyMap,
int rowsInRowGroup,
DateTimeZone hiveStorageTimeZone,
PostScript.HiveWriterVersion hiveWriterVersion,
MetadataReader metadataReader,
DataSize maxMergeDistance,
DataSize tinyStripeThreshold,
DataSize maxBlockSize,
Map<String, Slice> userMetadata,
OrcAggregatedMemoryContext systemMemoryUsage,
Optional<OrcWriteValidation> writeValidation,
int initialBatchSize,
StripeMetadataSource stripeMetadataSource,
boolean cacheable,
RuntimeStats runtimeStats,
Optional<OrcFileIntrospector> fileIntrospector,
long fileModificationTime)
{
requireNonNull(includedColumns, "includedColumns is null");
requireNonNull(predicate, "predicate is null");
requireNonNull(fileStripes, "fileStripes is null");
requireNonNull(stripeStats, "stripeStats is null");
requireNonNull(orcDataSource, "orcDataSource is null");
requireNonNull(types, "types is null");
requireNonNull(decompressor, "decompressor is null");
requireNonNull(encryptionLibrary, "encryptionLibrary is null");
requireNonNull(dwrfEncryptionGroupMap, "dwrfEncryptionGroupMap is null");
requireNonNull(columnToIntermediateKeyMap, "columnToIntermediateKeyMap is null");
requireNonNull(hiveStorageTimeZone, "hiveStorageTimeZone is null");
requireNonNull(userMetadata, "userMetadata is null");
requireNonNull(systemMemoryUsage, "systemMemoryUsage is null");
this.writeValidation = requireNonNull(writeValidation, "writeValidation is null");
this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(includedColumns));
this.rowGroupStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns));
this.stripeStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns));
this.fileStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns));
this.systemMemoryUsage = systemMemoryUsage;
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.fileIntrospector = requireNonNull(fileIntrospector, "fileIntrospector is null");
        // reduce the included columns to the set that is also present in the file
ImmutableSet.Builder<Integer> presentColumns = ImmutableSet.builder();
OrcType root = types.get(0);
for (int column : includedColumns.keySet()) {
// an old file can have fewer columns since columns can be added
// after the file was written
if (column >= 0 && column < root.getFieldCount()) {
presentColumns.add(column);
}
}
this.presentColumns = presentColumns.build();
this.maxBlockBytes = requireNonNull(maxBlockSize, "maxBlockSize is null").toBytes();
// it is possible that old versions of orc use 0 to mean there are no row groups
checkArgument(rowsInRowGroup > 0, "rowsInRowGroup must be greater than zero");
// sort stripes by file position
List<StripeInfo> stripeInfos = new ArrayList<>();
for (int i = 0; i < fileStripes.size(); i++) {
Optional<StripeStatistics> stats = Optional.empty();
            // ignore all stripe stats if their count does not match the number of stripes
if (stripeStats.size() == fileStripes.size()) {
stats = Optional.of(stripeStats.get(i));
}
stripeInfos.add(new StripeInfo(fileStripes.get(i), stats));
}
Collections.sort(stripeInfos, comparingLong(info -> info.getStripe().getOffset()));
long totalRowCount = 0;
long fileRowCount = 0;
ImmutableList.Builder<StripeInformation> stripes = ImmutableList.builder();
ImmutableList.Builder<Long> stripeFilePositions = ImmutableList.builder();
if (predicate.matches(numberOfRows, getStatisticsByColumnOrdinal(root, fileStats))) {
// select stripes that start within the specified split
for (StripeInfo info : stripeInfos) {
StripeInformation stripe = info.getStripe();
if (splitContainsStripe(splitOffset, splitLength, stripe) && isStripeIncluded(root, stripe, info.getStats(), predicate)) {
stripes.add(stripe);
stripeFilePositions.add(fileRowCount);
totalRowCount += stripe.getNumberOfRows();
}
fileRowCount += stripe.getNumberOfRows();
}
}
this.totalRowCount = totalRowCount;
this.stripes = stripes.build();
this.stripeFilePositions = stripeFilePositions.build();
orcDataSource = wrapWithCacheIfTinyStripes(orcDataSource, this.stripes, maxMergeDistance, tinyStripeThreshold, systemMemoryUsage);
this.orcDataSource = orcDataSource;
this.splitLength = splitLength;
this.fileRowCount = stripeInfos.stream()
.map(StripeInfo::getStripe)
.mapToLong(StripeInformation::getNumberOfRows)
.sum();
this.userMetadata = ImmutableMap.copyOf(Maps.transformValues(userMetadata, Slices::copyOf));
this.currentStripeSystemMemoryContext = this.systemMemoryUsage.newOrcAggregatedMemoryContext();
Set<Integer> includedOrcColumns = getIncludedOrcColumns(types, this.presentColumns, requireNonNull(requiredSubfields, "requiredSubfields is null"));
this.encryptionLibrary = encryptionLibrary;
this.dwrfEncryptionGroupMap = ImmutableMap.copyOf(dwrfEncryptionGroupMap);
this.intermediateKeyMetadata = createIntermediateKeysMap(columnToIntermediateKeyMap, dwrfEncryptionGroupMap, orcDataSource.getId());
checkPermissionsForEncryptedColumns(includedOrcColumns, dwrfEncryptionGroupMap, intermediateKeyMetadata);
stripeReader = new StripeReader(
orcDataSource,
decompressor,
types,
includedOrcColumns,
rowsInRowGroup,
predicate,
hiveWriterVersion,
metadataReader,
writeValidation,
stripeMetadataSource,
cacheable,
this.dwrfEncryptionGroupMap,
runtimeStats,
fileIntrospector,
fileModificationTime);
this.streamReaders = requireNonNull(streamReaders, "streamReaders is null");
for (int columnId = 0; columnId < root.getFieldCount(); columnId++) {
if (includedColumns.containsKey(columnId)) {
                checkArgument(streamReaders[columnId] != null, "Missing stream reader for column %s", columnId);
}
}
maxBytesPerCell = new long[streamReaders.length];
OptionalInt fixedWidthRowSize = getFixedWidthRowSize(this.presentColumns, includedColumns);
if (fixedWidthRowSize.isPresent()) {
if (fixedWidthRowSize.getAsInt() == 0) {
nextBatchSize = MAX_BATCH_SIZE;
}
else {
nextBatchSize = adjustMaxBatchSize(MAX_BATCH_SIZE, maxBlockBytes, fixedWidthRowSize.getAsInt());
}
}
else {
nextBatchSize = initialBatchSize;
}
}
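
    // Illustrative sketch for the helper below (hypothetical schema, not from the original
    // source): for an included column "a" of type struct<b:int, c:map<int,int>> with required
    // subfield a.b, the recursion adds the type ids of both "a" and "b" to the include set but
    // skips "c" entirely, since "c" is not among the required fields of the struct.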
private static Set<Integer> getIncludedOrcColumns(List<OrcType> types, Set<Integer> includedColumns, Map<Integer, List<Subfield>> requiredSubfields)
{
Set<Integer> includes = new LinkedHashSet<>();
OrcType root = types.get(0);
for (int includedColumn : includedColumns) {
List<Subfield> subfields = Optional.ofNullable(requiredSubfields.get(includedColumn)).orElse(ImmutableList.of());
includeOrcColumnsRecursive(types, includes, root.getFieldTypeIndex(includedColumn), subfields);
}
return includes;
}
private static void includeOrcColumnsRecursive(List<OrcType> types, Set<Integer> result, int typeId, List<Subfield> requiredSubfields)
{
result.add(typeId);
OrcType type = types.get(typeId);
Optional<Map<String, List<Subfield>>> requiredFields = Optional.empty();
if (type.getOrcTypeKind() == STRUCT) {
requiredFields = getRequiredFields(requiredSubfields);
}
int children = type.getFieldCount();
for (int i = 0; i < children; ++i) {
List<Subfield> subfields = ImmutableList.of();
if (requiredFields.isPresent()) {
String fieldName = type.getFieldNames().get(i).toLowerCase(Locale.ENGLISH);
if (!requiredFields.get().containsKey(fieldName)) {
continue;
}
subfields = requiredFields.get().get(fieldName);
}
includeOrcColumnsRecursive(types, result, type.getFieldTypeIndex(i), subfields);
}
}
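
    // Sketch of the transformation below (hypothetical field names, not from the original
    // source): required subfield paths [b], [b.x] and [d] over one struct column produce
    // {"b" -> [c.x], "d" -> []}; "c" is just a placeholder root name, since only the path below
    // the first-level field matters, and an empty subfield list means the whole field is read.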
private static Optional<Map<String, List<Subfield>>> getRequiredFields(List<Subfield> requiredSubfields)
{
if (requiredSubfields.isEmpty()) {
return Optional.empty();
}
Map<String, List<Subfield>> fields = new HashMap<>();
for (Subfield subfield : requiredSubfields) {
List<Subfield.PathElement> path = subfield.getPath();
if (path.size() == 1 && path.get(0) instanceof Subfield.NoSubfield) {
continue;
}
String name = ((Subfield.NestedField) path.get(0)).getName().toLowerCase(Locale.ENGLISH);
fields.computeIfAbsent(name, k -> new ArrayList<>());
if (path.size() > 1) {
fields.get(name).add(new Subfield("c", path.subList(1, path.size())));
}
}
return Optional.of(ImmutableMap.copyOf(fields));
}
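
    // Illustrative example for the helper below (hypothetical column ids): with
    // dwrfEncryptionGroupMap {3 -> 0, 5 -> 0}, columns 3 and 5 share encryption group 0, so both
    // must supply the same intermediate key; differing keys for one group are treated as file
    // corruption.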
private static Map<Integer, Slice> createIntermediateKeysMap(
Map<Integer, Slice> columnsToKeys,
Map<Integer, Integer> dwrfEncryptionGroupMap,
OrcDataSourceId dataSourceId)
{
Map<Integer, Slice> intermediateKeys = new HashMap<>(dwrfEncryptionGroupMap.values().size());
for (Map.Entry<Integer, Slice> entry : columnsToKeys.entrySet()) {
Slice key = entry.getValue();
int orcColumn = entry.getKey();
if (!dwrfEncryptionGroupMap.containsKey(orcColumn)) {
// ignore columns that don't have encryption groups
continue;
}
int group = dwrfEncryptionGroupMap.get(orcColumn);
Slice previous = intermediateKeys.putIfAbsent(group, key);
if (previous != null && !key.equals(previous)) {
throw new OrcCorruptionException(dataSourceId, "intermediate keys mapping does not match encryption groups");
}
}
return ImmutableMap.copyOf(intermediateKeys);
}
private void checkPermissionsForEncryptedColumns(Set<Integer> includedOrcColumns, Map<Integer, Integer> dwrfEncryptionGroupMap, Map<Integer, Slice> intermediateKeyMetadata)
{
for (Integer column : includedOrcColumns) {
if (dwrfEncryptionGroupMap.containsKey(column) && !intermediateKeyMetadata.containsKey(dwrfEncryptionGroupMap.get(column))) {
                throw new OrcPermissionsException(orcDataSource.getId(), "No IEK provided to decrypt column number %s", column);
}
}
}
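
    // Worked example for the helper below (illustrative): for columns of types BIGINT, INTEGER
    // and BOOLEAN, the fixed row size is (8 + 1) + (4 + 1) + (1 + 1) = 16 bytes, where each +1
    // is the null flag; any variable-width column (e.g. VARCHAR) makes the result empty.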
private static OptionalInt getFixedWidthRowSize(Set<Integer> columns, Map<Integer, Type> columnTypes)
{
int totalFixedWidth = 0;
for (int column : columns) {
Type type = columnTypes.get(column);
if (type instanceof FixedWidthType) {
// add 1 byte for null flag
totalFixedWidth += ((FixedWidthType) type).getFixedSize() + 1;
}
else {
return OptionalInt.empty();
}
}
return OptionalInt.of(totalFixedWidth);
}
    /**
     * Returns the sum, over all columns, of the largest average cell size observed for each column
     */
public long getMaxCombinedBytesPerRow()
{
return maxCombinedBytesPerRow;
}
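
    // Illustrative example (hypothetical numbers): if column 0 previously averaged at most 100
    // bytes per cell and a new block averages 150, maxCombinedBytesPerRow grows by 50 and the
    // batch size cap is recomputed so that rows-per-batch times bytes-per-row stays near
    // maxBlockBytes.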
protected void updateMaxCombinedBytesPerRow(int columnIndex, Block block)
{
if (block.getPositionCount() > 0) {
long bytesPerCell = block.getSizeInBytes() / block.getPositionCount();
if (maxBytesPerCell[columnIndex] < bytesPerCell) {
maxCombinedBytesPerRow = maxCombinedBytesPerRow - maxBytesPerCell[columnIndex] + bytesPerCell;
maxBytesPerCell[columnIndex] = bytesPerCell;
maxBatchSize = adjustMaxBatchSize(maxBatchSize, maxBlockBytes, maxCombinedBytesPerRow);
}
}
}
protected T[] getStreamReaders()
{
return streamReaders;
}
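
    // A stripe belongs to the split whose byte range contains the stripe's starting offset, so
    // a stripe straddling a split boundary is read by exactly one reader (illustrative: a split
    // covering [0, 64MB) owns a stripe starting at 63MB even if the stripe ends past 64MB).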
private static boolean splitContainsStripe(long splitOffset, long splitLength, StripeInformation stripe)
{
long splitEndOffset = splitOffset + splitLength;
return splitOffset <= stripe.getOffset() && stripe.getOffset() < splitEndOffset;
}
private static boolean isStripeIncluded(
OrcType rootStructType,
StripeInformation stripe,
Optional<StripeStatistics> stripeStats,
OrcPredicate predicate)
{
        // if there are no stats, include the stripe
if (!stripeStats.isPresent()) {
return true;
}
return predicate.matches(stripe.getNumberOfRows(), getStatisticsByColumnOrdinal(rootStructType, stripeStats.get().getColumnStatistics()));
}
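
    // Illustrative note (hypothetical sizes): if every selected stripe is at most
    // tinyStripeThreshold, say many 1MB stripes under an 8MB threshold, the merged stripe ranges
    // are served through a caching data source instead of many small reads; a single stripe over
    // the threshold disables the wrapping.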
@VisibleForTesting
static OrcDataSource wrapWithCacheIfTinyStripes(OrcDataSource dataSource, List<StripeInformation> stripes, DataSize maxMergeDistance, DataSize tinyStripeThreshold, OrcAggregatedMemoryContext systemMemoryContext)
{
if (dataSource instanceof CachingOrcDataSource) {
return dataSource;
}
for (StripeInformation stripe : stripes) {
if (stripe.getTotalLength() > tinyStripeThreshold.toBytes()) {
return dataSource;
}
}
return new CachingOrcDataSource(dataSource, createTinyStripesRangeFinder(stripes, maxMergeDistance, tinyStripeThreshold), systemMemoryContext.newOrcLocalMemoryContext(CachingOrcDataSource.class.getSimpleName()));
}
/**
     * Returns the row position relative to the start of the file.
*/
public long getFilePosition()
{
return filePosition;
}
/**
* Returns the total number of rows in the file. This count includes rows
* for stripes that were completely excluded due to stripe statistics.
*/
public long getFileRowCount()
{
return fileRowCount;
}
    /**
     * Returns the row position within the stripes being read by this reader.
     * This position includes rows that were never read because their row groups
     * were excluded by row group statistics, so it advances faster than the
     * number of rows actually read.
     */
public long getReaderPosition()
{
return currentPosition;
}
    /**
     * Returns the total number of rows that can possibly be read by this reader.
     * This count may be less than the number of rows in the file if some stripes
     * were excluded due to stripe statistics, but may be more than the number of
     * rows actually read if some row groups are excluded due to row group statistics.
     */
public long getReaderRowCount()
{
return totalRowCount;
}
public long getSplitLength()
{
return splitLength;
}
@Override
public void close()
throws IOException
{
try (Closer closer = Closer.create()) {
closer.register(orcDataSource);
for (StreamReader column : streamReaders) {
if (column != null) {
closer.register(column::close);
}
}
}
rowGroups = null;
if (writeChecksumBuilder.isPresent()) {
OrcWriteValidation.WriteChecksum actualChecksum = writeChecksumBuilder.get().build();
validateWrite(validation -> validation.getChecksum().getTotalRowCount() == actualChecksum.getTotalRowCount(), "Invalid row count");
List<Long> columnHashes = actualChecksum.getColumnHashes();
for (int i = 0; i < columnHashes.size(); i++) {
int columnIndex = i;
validateWrite(validation -> validation.getChecksum().getColumnHashes().get(columnIndex).equals(columnHashes.get(columnIndex)),
"Invalid checksum for column %s", columnIndex);
}
validateWrite(validation -> validation.getChecksum().getStripeHash() == actualChecksum.getStripeHash(), "Invalid stripes checksum");
}
if (fileStatisticsValidation.isPresent()) {
List<ColumnStatistics> columnStatistics = fileStatisticsValidation.get().build();
writeValidation.get().validateFileStatistics(orcDataSource.getId(), columnStatistics);
}
}
public boolean isColumnPresent(int hiveColumnIndex)
{
return presentColumns.contains(hiveColumnIndex);
}
public Map<String, Slice> getUserMetadata()
{
return ImmutableMap.copyOf(Maps.transformValues(userMetadata, Slices::copyOf));
}
private boolean advanceToNextRowGroup()
throws IOException
{
nextRowInGroup = 0;
if (currentRowGroup >= 0) {
if (rowGroupStatisticsValidation.isPresent()) {
OrcWriteValidation.StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get();
long offset = stripes.get(currentStripe).getOffset();
writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), offset, currentRowGroup, statisticsValidation.build());
statisticsValidation.reset();
}
}
while (!rowGroups.hasNext() && currentStripe < stripes.size()) {
advanceToNextStripe();
currentRowGroup = -1;
}
if (!rowGroups.hasNext()) {
currentGroupRowCount = 0;
return false;
}
currentRowGroup++;
        RowGroup rowGroup = rowGroups.next();
        currentGroupRowCount = toIntExact(rowGroup.getRowCount());
        if (rowGroup.getMinAverageRowBytes() > 0) {
            maxBatchSize = adjustMaxBatchSize(maxBatchSize, maxBlockBytes, rowGroup.getMinAverageRowBytes());
        }
        currentPosition = currentStripePosition + rowGroup.getRowOffset();
        filePosition = stripeFilePositions.get(currentStripe) + rowGroup.getRowOffset();
        // give the stream readers their data streams from this row group
        InputStreamSources rowGroupStreamSources = rowGroup.getStreamSources();
for (StreamReader column : streamReaders) {
if (column != null) {
column.startRowGroup(rowGroupStreamSources);
}
}
return true;
}
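
    // Worked example for the helper below (hypothetical numbers): with maxBatchSize = 1024,
    // maxBlockBytes = 16MB and averageRowBytes = 32KB, the result is
    // min(1024, max(1, 16MB / 32KB)) = min(1024, 512) = 512 rows, keeping one batch near the
    // block size budget.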
private static int adjustMaxBatchSize(int maxBatchSize, long maxBlockBytes, long averageRowBytes)
{
return toIntExact(min(maxBatchSize, max(1, maxBlockBytes / averageRowBytes)));
}
protected int getNextRowInGroup()
{
return nextRowInGroup;
}
protected void batchRead(int batchSize)
{
nextRowInGroup += batchSize;
}
protected int prepareNextBatch()
throws IOException
{
        // update positions for the current row group (advancing to a new group resets them)
filePosition += currentBatchSize;
currentPosition += currentBatchSize;
        // if the current row group is exhausted, attempt to advance to the next one
        if (nextRowInGroup >= currentGroupRowCount) {
if (!advanceToNextRowGroup()) {
filePosition = fileRowCount;
currentPosition = totalRowCount;
return -1;
}
}
        // We grow currentBatchSize by BATCH_SIZE_GROWTH_FACTOR, starting from initialBatchSize, up
        // to maxBatchSize or the number of rows left in this row group, whichever is smaller.
        // maxBatchSize is adjusted according to the block size for every batch and never exceeds
        // MAX_BATCH_SIZE. But when the last batch in the current row group is smaller than
        // min(nextBatchSize, maxBatchSize), the nextBatchSize for the first batch in the new row
        // group should be grown from min(nextBatchSize, maxBatchSize), not from the size of that
        // last batch, i.e. currentGroupRowCount - nextRowInGroup. For example, if the batch sizes
        // read for a single fixed-width column are 1, 16, 256, 1024, 1024, ..., 1024, 256, where
        // the final 256 occurred only because 256 rows remained in the row group, then the
        // nextBatchSize should be 1024 rather than 512. So we grow nextBatchSize before limiting
        // currentBatchSize by currentGroupRowCount - nextRowInGroup.
currentBatchSize = toIntExact(min(nextBatchSize, maxBatchSize));
nextBatchSize = min(currentBatchSize * BATCH_SIZE_GROWTH_FACTOR, MAX_BATCH_SIZE);
currentBatchSize = toIntExact(min(currentBatchSize, currentGroupRowCount - nextRowInGroup));
return currentBatchSize;
}
private void advanceToNextStripe()
throws IOException
{
currentStripeSystemMemoryContext.close();
currentStripeSystemMemoryContext = systemMemoryUsage.newOrcAggregatedMemoryContext();
rowGroups = ImmutableList.<RowGroup>of().iterator();
if (currentStripe >= 0) {
if (stripeStatisticsValidation.isPresent()) {
OrcWriteValidation.StatisticsValidation statisticsValidation = stripeStatisticsValidation.get();
long offset = stripes.get(currentStripe).getOffset();
writeValidation.get().validateStripeStatistics(orcDataSource.getId(), offset, statisticsValidation.build());
statisticsValidation.reset();
}
}
currentStripe++;
if (currentStripe >= stripes.size()) {
return;
}
if (currentStripe > 0) {
currentStripePosition += stripes.get(currentStripe - 1).getNumberOfRows();
}
StripeInformation stripeInformation = stripes.get(currentStripe);
validateWriteStripe(stripeInformation.getNumberOfRows());
List<byte[]> stripeDecryptionKeyMetadata = getDecryptionKeyMetadata(currentStripe, stripes);
        // set dwrfEncryptionInfo if there are encrypted columns and it hasn't been set yet,
        // or if it has been set but this stripe carries new decryption keys
if ((!stripeDecryptionKeyMetadata.isEmpty() && !dwrfEncryptionInfo.isPresent())
|| (dwrfEncryptionInfo.isPresent() && !stripeDecryptionKeyMetadata.equals(dwrfEncryptionInfo.get().getEncryptedKeyMetadatas()))) {
verify(encryptionLibrary.isPresent(), "encryptionLibrary is absent");
dwrfEncryptionInfo = Optional.of(createDwrfEncryptionInfo(encryptionLibrary.get(), stripeDecryptionKeyMetadata, intermediateKeyMetadata, dwrfEncryptionGroupMap));
}
SharedBuffer sharedDecompressionBuffer = new SharedBuffer(currentStripeSystemMemoryContext.newOrcLocalMemoryContext("sharedDecompressionBuffer"));
Stripe stripe = stripeReader.readStripe(stripeInformation, currentStripeSystemMemoryContext, dwrfEncryptionInfo, sharedDecompressionBuffer);
if (stripe != null) {
for (StreamReader column : streamReaders) {
if (column != null) {
column.startStripe(stripe);
}
}
rowGroups = stripe.getRowGroups().iterator();
}
fileIntrospector.ifPresent(introspector -> introspector.onStripe(stripeInformation, stripe));
}
@VisibleForTesting
public static List<byte[]> getDecryptionKeyMetadata(int currentStripe, List<StripeInformation> stripes)
{
        // if this stripe has encryption keys, those are used;
        // otherwise, use the nearest prior stripe that specifies encryption keys.
        // if the first stripe has no encryption information, then there are no encrypted columns.
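        // illustrative example (hypothetical keys): for stripes with key metadata
        // [K1, [], [], K2, []], stripe 2 resolves to K1 (the nearest prior stripe with keys)
        // and stripe 4 resolves to K2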
for (int i = currentStripe; i >= 0; i--) {
if (!stripes.get(i).getKeyMetadata().isEmpty()) {
return stripes.get(i).getKeyMetadata();
}
}
return ImmutableList.of();
}
private void validateWrite(Predicate<OrcWriteValidation> test, String messageFormat, Object... args)
throws OrcCorruptionException
{
if (writeValidation.isPresent() && !test.apply(writeValidation.get())) {
throw new OrcCorruptionException(orcDataSource.getId(), "Write validation failed: " + messageFormat, args);
}
}
private void validateWriteStripe(long rowCount)
{
if (writeChecksumBuilder.isPresent()) {
writeChecksumBuilder.get().addStripe(rowCount);
}
}
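
    // Illustrative note: for a root struct with fields (a, b), ordinal 0 maps to the statistics
    // entry at a's type id and ordinal 1 to b's; a missing or null entry simply leaves that
    // ordinal out of the result map.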
private static Map<Integer, ColumnStatistics> getStatisticsByColumnOrdinal(OrcType rootStructType, List<ColumnStatistics> fileStats)
{
requireNonNull(rootStructType, "rootStructType is null");
        checkArgument(rootStructType.getOrcTypeKind() == STRUCT, "rootStructType must be a STRUCT");
requireNonNull(fileStats, "fileStats is null");
ImmutableMap.Builder<Integer, ColumnStatistics> statistics = ImmutableMap.builder();
for (int ordinal = 0; ordinal < rootStructType.getFieldCount(); ordinal++) {
if (fileStats.size() > ordinal) {
ColumnStatistics element = fileStats.get(rootStructType.getFieldTypeIndex(ordinal));
if (element != null) {
statistics.put(ordinal, element);
}
}
}
return statistics.build();
}
/**
* @return The size of memory retained by all the stream readers (local buffers + object overhead)
*/
@VisibleForTesting
long getStreamReaderRetainedSizeInBytes()
{
long totalRetainedSizeInBytes = 0;
for (StreamReader column : streamReaders) {
if (column != null) {
totalRetainedSizeInBytes += column.getRetainedSizeInBytes();
}
}
return totalRetainedSizeInBytes;
}
/**
* @return The size of memory retained by the current stripe (excludes object overheads)
*/
@VisibleForTesting
long getCurrentStripeRetainedSizeInBytes()
{
return currentStripeSystemMemoryContext.getBytes();
}
protected long getRetainedSizeInBytes()
{
return INSTANCE_SIZE + getStreamReaderRetainedSizeInBytes() + getCurrentStripeRetainedSizeInBytes();
}
/**
* @return The system memory reserved by this OrcRecordReader. It does not include non-leaf level StreamReaders'
* instance sizes.
*/
@VisibleForTesting
long getSystemMemoryUsage()
{
return systemMemoryUsage.getBytes();
}
protected boolean shouldValidateWritePageChecksum()
{
return writeChecksumBuilder.isPresent();
}
protected void validateWritePageChecksum(Page page)
{
if (writeChecksumBuilder.isPresent()) {
writeChecksumBuilder.get().addPage(page);
rowGroupStatisticsValidation.get().addPage(page);
stripeStatisticsValidation.get().addPage(page);
fileStatisticsValidation.get().addPage(page);
}
}
protected OrcDataSourceId getOrcDataSourceId()
{
return orcDataSource.getId();
}
private static class StripeInfo
{
private final StripeInformation stripe;
private final Optional<StripeStatistics> stats;
public StripeInfo(StripeInformation stripe, Optional<StripeStatistics> stats)
{
this.stripe = requireNonNull(stripe, "stripe is null");
            this.stats = requireNonNull(stats, "stats is null");
}
public StripeInformation getStripe()
{
return stripe;
}
public Optional<StripeStatistics> getStats()
{
return stats;
}
}
@VisibleForTesting
static class LinearProbeRangeFinder
implements CachingOrcDataSource.RegionFinder
{
private final List<DiskRange> diskRanges;
private int index;
public LinearProbeRangeFinder(List<DiskRange> diskRanges)
{
this.diskRanges = diskRanges;
}
@Override
public DiskRange getRangeFor(long desiredOffset)
{
            // Assumption: ranges are always read in order
// Assumption: bytes that are not part of any range are never read
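            // illustrative walk-through (hypothetical ranges): given [0, 100) and [150, 250),
            // probes at 10 and 40 both return the first range without advancing past it; a probe
            // at 160 advances to the second; a later probe back at 10 then fails, because the
            // finder never moves backwards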
for (; index < diskRanges.size(); index++) {
DiskRange range = diskRanges.get(index);
if (range.getEnd() > desiredOffset) {
checkArgument(range.getOffset() <= desiredOffset);
return range;
}
}
throw new IllegalArgumentException("Invalid desiredOffset " + desiredOffset);
}
        public static LinearProbeRangeFinder createTinyStripesRangeFinder(List<StripeInformation> stripes, DataSize maxMergeDistance, DataSize tinyStripeThreshold)
        {
            if (stripes.isEmpty()) {
                return new LinearProbeRangeFinder(ImmutableList.of());
            }
            List<DiskRange> scratchDiskRanges = stripes.stream()
                    .map(stripe -> new DiskRange(stripe.getOffset(), toIntExact(stripe.getTotalLength())))
                    .collect(Collectors.toList());
            List<DiskRange> diskRanges = mergeAdjacentDiskRanges(scratchDiskRanges, maxMergeDistance, tinyStripeThreshold);
            return new LinearProbeRangeFinder(diskRanges);
}
}
}