/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc;

import com.facebook.presto.common.Page;
import com.facebook.presto.common.io.DataOutput;
import com.facebook.presto.common.io.DataSink;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationBuilder;
import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode;
import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.CompressedMetadataWriter;
import com.facebook.presto.orc.metadata.CompressionKind;
import com.facebook.presto.orc.metadata.DwrfEncryption;
import com.facebook.presto.orc.metadata.DwrfStripeCacheData;
import com.facebook.presto.orc.metadata.DwrfStripeCacheWriter;
import com.facebook.presto.orc.metadata.EncryptionGroup;
import com.facebook.presto.orc.metadata.Footer;
import com.facebook.presto.orc.metadata.Metadata;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.StripeEncryptionGroup;
import com.facebook.presto.orc.metadata.StripeFooter;
import com.facebook.presto.orc.metadata.StripeInformation;
import com.facebook.presto.orc.metadata.statistics.ColumnStatistics;
import com.facebook.presto.orc.metadata.statistics.StripeStatistics;
import com.facebook.presto.orc.proto.DwrfProto;
import com.facebook.presto.orc.stream.StreamDataOutput;
import com.facebook.presto.orc.writer.ColumnWriter;
import com.facebook.presto.orc.writer.CompressionBufferPool;
import com.facebook.presto.orc.writer.CompressionBufferPool.LastUsedCompressionBufferPool;
import com.facebook.presto.orc.writer.DictionaryColumnWriter;
import com.facebook.presto.orc.writer.StreamLayout;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2LongMap;
import org.joda.time.DateTimeZone;
import org.openjdk.jol.info.ClassLayout;

import javax.annotation.Nullable;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.common.io.DataOutput.createDataOutput;
import static com.facebook.presto.orc.DwrfEncryptionInfo.UNENCRYPTED;
import static com.facebook.presto.orc.DwrfEncryptionInfo.createNodeToGroupMap;
import static com.facebook.presto.orc.FlushReason.CLOSED;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcReader.validateFile;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DIRECT;
import static com.facebook.presto.orc.metadata.ColumnEncoding.DEFAULT_SEQUENCE_ID;
import static com.facebook.presto.orc.metadata.DwrfMetadataWriter.toFileStatistics;
import static com.facebook.presto.orc.metadata.DwrfMetadataWriter.toStripeEncryptionGroup;
import static com.facebook.presto.orc.metadata.OrcType.createNodeIdToColumnMap;
import static com.facebook.presto.orc.metadata.OrcType.mapColumnToNode;
import static com.facebook.presto.orc.metadata.PostScript.MAGIC;
import static com.facebook.presto.orc.metadata.statistics.ColumnStatistics.mergeColumnStatistics;
import static com.facebook.presto.orc.writer.ColumnWriters.createColumnWriter;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.Integer.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public class OrcWriter
        implements Closeable
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(OrcWriter.class).instanceSize();

    static final String PRESTO_ORC_WRITER_VERSION_METADATA_KEY = "presto.writer.version";
    static final String PRESTO_ORC_WRITER_VERSION;

    static {
        String version = OrcWriter.class.getPackage().getImplementationVersion();
        PRESTO_ORC_WRITER_VERSION = version == null ? "UNKNOWN" : version;
    }

    private final WriterStats stats;
    private final OrcWriterFlushPolicy flushPolicy;
    private final DataSink dataSink;
    private final List<Type> types;
    private final OrcEncoding orcEncoding;
    private final ColumnWriterOptions columnWriterOptions;
    private final int rowGroupMaxRowCount;
    private final StreamLayout streamLayout;
    private final Map<String, String> userMetadata;
    private final CompressedMetadataWriter metadataWriter;
    private final DateTimeZone hiveStorageTimeZone;

    private final DwrfEncryptionProvider dwrfEncryptionProvider;
    private final DwrfEncryptionInfo dwrfEncryptionInfo;
    private final Optional<DwrfWriterEncryption> dwrfWriterEncryption;

    private final List<ClosedStripe> closedStripes = new ArrayList<>();
    private final List<OrcType> orcTypes;

    private final List<ColumnWriter> columnWriters;
    private final Optional<DwrfStripeCacheWriter> dwrfStripeCacheWriter;
    private final int dictionaryMaxMemoryBytes;
    private final DictionaryCompressionOptimizer dictionaryCompressionOptimizer;
    @Nullable
    private final OrcWriteValidationBuilder validationBuilder;
    private final CompressionBufferPool compressionBufferPool;

    private int stripeRowCount;
    private int rowGroupRowCount;
    private int bufferedBytes;
    private long columnWritersRetainedBytes;
    private long closedStripesRetainedBytes;
    private long previouslyRecordedSizeInBytes;
    private boolean closed;

    private long numberOfRows;
    private long stripeRawSize;
    private long rawSize;
    private List<ColumnStatistics> unencryptedStats;
    private final Map<Integer, Integer> nodeIdToColumn;
    private final StreamSizeHelper streamSizeHelper;

    public OrcWriter(
            DataSink dataSink,
            List<String> columnNames,
            List<Type> types,
            OrcEncoding orcEncoding,
            CompressionKind compressionKind,
            Optional<DwrfWriterEncryption> encryption,
            DwrfEncryptionProvider dwrfEncryptionProvider,
            OrcWriterOptions options,
            Map<String, String> userMetadata,
            DateTimeZone hiveStorageTimeZone,
            boolean validate,
            OrcWriteValidationMode validationMode,
            WriterStats stats)
    {
        this(
                dataSink,
                columnNames,
                types,
                Optional.empty(),
                orcEncoding,
                compressionKind,
                encryption,
                dwrfEncryptionProvider,
                options,
                userMetadata,
                hiveStorageTimeZone,
                validate,
                validationMode,
                stats);
    }

    public OrcWriter(
            DataSink dataSink,
            List<String> columnNames,
            List<Type> types,
            Optional<List<OrcType>> inputOrcTypes,
            OrcEncoding orcEncoding,
            CompressionKind compressionKind,
            Optional<DwrfWriterEncryption> encryption,
            DwrfEncryptionProvider dwrfEncryptionProvider,
            OrcWriterOptions options,
            Map<String, String> userMetadata,
            DateTimeZone hiveStorageTimeZone,
            boolean validate,
            OrcWriteValidationMode validationMode,
            WriterStats stats)
    {
        this.validationBuilder = validate
                ? new OrcWriteValidationBuilder(validationMode, types)
                        .setStringStatisticsLimitInBytes(toIntExact(options.getMaxStringStatisticsLimit().toBytes()))
                : null;

        this.dataSink = requireNonNull(dataSink, "dataSink is null");
        this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
        this.orcEncoding = requireNonNull(orcEncoding, "orcEncoding is null");
        this.compressionBufferPool = new LastUsedCompressionBufferPool();

        requireNonNull(columnNames, "columnNames is null");
        requireNonNull(inputOrcTypes, "inputOrcTypes is null");
        this.orcTypes = inputOrcTypes.orElseGet(() -> OrcType.createOrcRowType(0, columnNames, types));
        this.nodeIdToColumn = createNodeIdToColumnMap(this.orcTypes);

        requireNonNull(compressionKind, "compressionKind is null");
        Set<Integer> flattenedNodes = mapColumnToNode(options.getFlattenedColumns(), orcTypes);
        this.columnWriterOptions = ColumnWriterOptions.builder()
                .setCompressionKind(compressionKind)
                .setCompressionLevel(options.getCompressionLevel())
                .setCompressionMaxBufferSize(options.getMaxCompressionBufferSize())
                .setMinOutputBufferChunkSize(options.getMinOutputBufferChunkSize())
                .setMaxOutputBufferChunkSize(options.getMaxOutputBufferChunkSize())
                .setStringStatisticsLimit(options.getMaxStringStatisticsLimit())
                .setIntegerDictionaryEncodingEnabled(options.isIntegerDictionaryEncodingEnabled())
                .setStringDictionarySortingEnabled(options.isStringDictionarySortingEnabled())
                .setStringDictionaryEncodingEnabled(options.isStringDictionaryEncodingEnabled())
                .setIgnoreDictionaryRowGroupSizes(options.isIgnoreDictionaryRowGroupSizes())
                .setPreserveDirectEncodingStripeCount(options.getPreserveDirectEncodingStripeCount())
                .setCompressionBufferPool(compressionBufferPool)
                .setFlattenedNodes(flattenedNodes)
                .setMapStatisticsEnabled(options.isMapStatisticsEnabled())
                .setMaxFlattenedMapKeyCount(options.getMaxFlattenedMapKeyCount())
                .setResetOutputBuffer(options.isResetOutputBuffer())
                .setLazyOutputBuffer(options.isLazyOutputBuffer())
                .build();
        recordValidation(validation -> validation.setCompression(compressionKind));
        recordValidation(validation -> validation.setFlattenedNodes(flattenedNodes));
        recordValidation(validation -> validation.setOrcTypes(orcTypes));

        requireNonNull(options, "options is null");
        this.flushPolicy = requireNonNull(options.getFlushPolicy(), "flushPolicy is null");
        this.rowGroupMaxRowCount = options.getRowGroupMaxRowCount();
        recordValidation(validation -> validation.setRowGroupMaxRowCount(rowGroupMaxRowCount));
        this.streamLayout = requireNonNull(options.getStreamLayoutFactory().create(), "streamLayout is null");

        this.userMetadata = ImmutableMap.<String, String>builder()
                .putAll(requireNonNull(userMetadata, "userMetadata is null"))
                .put(PRESTO_ORC_WRITER_VERSION_METADATA_KEY, PRESTO_ORC_WRITER_VERSION)
                .build();
        this.metadataWriter = new CompressedMetadataWriter(orcEncoding.createMetadataWriter(), columnWriterOptions, Optional.empty());
        this.hiveStorageTimeZone = requireNonNull(hiveStorageTimeZone, "hiveStorageTimeZone is null");
        this.stats = requireNonNull(stats, "stats is null");
        this.streamSizeHelper = new StreamSizeHelper(orcTypes, columnWriterOptions.getFlattenedNodes(), columnWriterOptions.isMapStatisticsEnabled());

        recordValidation(validation -> validation.setColumnNames(columnNames));

        this.dwrfWriterEncryption = requireNonNull(encryption, "encryption is null");
        this.dwrfEncryptionProvider = requireNonNull(dwrfEncryptionProvider, "dwrfEncryptionProvider is null");
        if (dwrfWriterEncryption.isPresent()) {
            List<WriterEncryptionGroup> writerEncryptionGroups = dwrfWriterEncryption.get().getWriterEncryptionGroups();
            Map<Integer, Integer> nodeToGroupMap = createNodeToGroupMap(
                    writerEncryptionGroups
                            .stream()
                            .map(WriterEncryptionGroup::getNodes)
                            .collect(toImmutableList()),
                    orcTypes);
            EncryptionLibrary encryptionLibrary = dwrfEncryptionProvider.getEncryptionLibrary(dwrfWriterEncryption.get().getKeyProvider());
            List<byte[]> dataEncryptionKeys = writerEncryptionGroups.stream()
                    .map(group -> encryptionLibrary.generateDataEncryptionKey(group.getIntermediateKeyMetadata().getBytes()))
                    .collect(toImmutableList());
            Map<Integer, DwrfDataEncryptor> dwrfEncryptors = IntStream.range(0, writerEncryptionGroups.size())
                    .boxed()
                    .collect(toImmutableMap(
                            groupId -> groupId,
                            groupId -> new DwrfDataEncryptor(dataEncryptionKeys.get(groupId), encryptionLibrary)));

            List<byte[]> encryptedKeyMetadatas = IntStream.range(0, writerEncryptionGroups.size())
                    .boxed()
                    .map(groupId -> encryptionLibrary.encryptKey(
                            writerEncryptionGroups.get(groupId).getIntermediateKeyMetadata().getBytes(),
                            dataEncryptionKeys.get(groupId),
                            0,
                            dataEncryptionKeys.get(groupId).length))
                    .collect(toImmutableList());
            this.dwrfEncryptionInfo = new DwrfEncryptionInfo(dwrfEncryptors, encryptedKeyMetadatas, nodeToGroupMap);
        }
        else {
            this.dwrfEncryptionInfo = UNENCRYPTED;
        }

        // create a DwrfStripeCacheWriter for DWRF files if the stripe cache is enabled through the options
        if (orcEncoding == DWRF) {
            this.dwrfStripeCacheWriter = options.getDwrfStripeCacheOptions()
                    .map(dwrfWriterOptions -> new DwrfStripeCacheWriter(
                            dwrfWriterOptions.getStripeCacheMode(),
                            dwrfWriterOptions.getStripeCacheMaxSize()));
        }
        else {
            this.dwrfStripeCacheWriter = Optional.empty();
        }

        // create column writers
        OrcType rootType = orcTypes.get(0);
        checkArgument(rootType.getFieldCount() == types.size(), "root type field count does not match the number of column types");
        ImmutableList.Builder<ColumnWriter> columnWriters = ImmutableList.builder();
        ImmutableSet.Builder<DictionaryColumnWriter> dictionaryColumnWriters = ImmutableSet.builder();
        for (int columnIndex = 0; columnIndex < types.size(); columnIndex++) {
            int nodeIndex = rootType.getFieldTypeIndex(columnIndex);
            Type fieldType = types.get(columnIndex);
            ColumnWriter columnWriter = createColumnWriter(
                    nodeIndex,
                    DEFAULT_SEQUENCE_ID,
                    orcTypes,
                    fieldType,
                    columnWriterOptions,
                    orcEncoding,
                    hiveStorageTimeZone,
                    dwrfEncryptionInfo,
                    orcEncoding.createMetadataWriter());
            columnWriters.add(columnWriter);

            if (columnWriter instanceof DictionaryColumnWriter) {
                dictionaryColumnWriters.add((DictionaryColumnWriter) columnWriter);
            }
            else {
                for (ColumnWriter nestedColumnWriter : columnWriter.getNestedColumnWriters()) {
                    if (nestedColumnWriter instanceof DictionaryColumnWriter) {
                        dictionaryColumnWriters.add((DictionaryColumnWriter) nestedColumnWriter);
                    }
                }
            }
        }
        this.columnWriters = columnWriters.build();
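        // configure the dictionary compression optimizer with the memory and flush thresholds from the writer options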
        this.dictionaryMaxMemoryBytes = toIntExact(options.getDictionaryMaxMemory().toBytes());
        int dictionaryMemoryAlmostFullRangeBytes = toIntExact(options.getDictionaryMemoryAlmostFullRange().toBytes());
        int dictionaryUsefulCheckColumnSizeBytes = toIntExact(options.getDictionaryUsefulCheckColumnSize().toBytes());
        this.dictionaryCompressionOptimizer = new DictionaryCompressionOptimizer(
                dictionaryColumnWriters.build(),
                flushPolicy.getStripeMinBytes(),
                flushPolicy.getStripeMaxBytes(),
                flushPolicy.getStripeMaxRowCount(),
                dictionaryMaxMemoryBytes,
                dictionaryMemoryAlmostFullRangeBytes,
                dictionaryUsefulCheckColumnSizeBytes,
                options.getDictionaryUsefulCheckPerChunkFrequency());

        for (Entry<String, String> entry : this.userMetadata.entrySet()) {
            recordValidation(validation -> validation.addMetadataProperty(entry.getKey(), utf8Slice(entry.getValue())));
        }

        this.previouslyRecordedSizeInBytes = getRetainedBytes();
        stats.updateSizeInBytes(previouslyRecordedSizeInBytes);
    }

    @VisibleForTesting
    List<ColumnWriter> getColumnWriters()
    {
        return columnWriters;
    }

    @VisibleForTesting
    DictionaryCompressionOptimizer getDictionaryCompressionOptimizer()
    {
        return dictionaryCompressionOptimizer;
    }

    /**
     * Number of bytes already flushed to the data sink.
     */
    public long getWrittenBytes()
    {
        return dataSink.size();
    }

    /**
     * Number of pending bytes not yet flushed.
     */
    public int getBufferedBytes()
    {
        return bufferedBytes;
    }

    public long getRetainedBytes()
    {
        return INSTANCE_SIZE +
                columnWritersRetainedBytes +
                closedStripesRetainedBytes +
                dataSink.getRetainedSizeInBytes() +
                compressionBufferPool.getRetainedBytes() +
                (validationBuilder == null ? 0 : validationBuilder.getRetainedSize());
    }

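    /**
     * Writes a page to the file.  The page is split into chunks that respect
     * the row group and stripe row limits, and a stripe is flushed to the
     * data sink when the flush policy requires it.
     */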
    public void write(Page page)
            throws IOException
    {
        requireNonNull(page, "page is null");
        if (page.getPositionCount() == 0) {
            return;
        }

        checkArgument(page.getChannelCount() == columnWriters.size(), "page channel count does not match the number of column writers");

        if (validationBuilder != null) {
            validationBuilder.addPage(page);
        }

        int maxChunkRowCount = flushPolicy.getMaxChunkRowCount(page);

        while (page != null) {
            // limit the chunk to the rows remaining in the current row group and stripe
            int chunkRows = min(maxChunkRowCount, min(rowGroupMaxRowCount - rowGroupRowCount, flushPolicy.getStripeMaxRowCount() - stripeRowCount));

            // the chunk cannot be larger than the remaining page
            chunkRows = min(page.getPositionCount(), chunkRows);

            Page chunk = page.getRegion(0, chunkRows);

            if (chunkRows < page.getPositionCount()) {
                page = page.getRegion(chunkRows, page.getPositionCount() - chunkRows);
            }
            else {
                page = null;
            }
            writeChunk(chunk);
        }

        long recordedSizeInBytes = getRetainedBytes();
        stats.updateSizeInBytes(recordedSizeInBytes - previouslyRecordedSizeInBytes);
        previouslyRecordedSizeInBytes = recordedSizeInBytes;
    }

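    // Writes a single chunk that never crosses a row group or stripe boundary, then flushes the stripe if the flush policy requires it.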
    private void writeChunk(Page chunk)
            throws IOException
    {
        if (rowGroupRowCount == 0) {
            columnWriters.forEach(ColumnWriter::beginRowGroup);
        }

        // write the chunk to each column writer
        bufferedBytes = 0;
        for (int channel = 0; channel < chunk.getChannelCount(); channel++) {
            ColumnWriter writer = columnWriters.get(channel);
            stripeRawSize += writer.writeBlock(chunk.getBlock(channel));
            bufferedBytes += writer.getBufferedBytes();
        }

        // update stats
        rowGroupRowCount += chunk.getPositionCount();
        checkState(rowGroupRowCount <= rowGroupMaxRowCount, "row group row count exceeds the maximum");
        stripeRowCount += chunk.getPositionCount();

        // finish the row group if it has reached the maximum row count
        if (rowGroupRowCount == rowGroupMaxRowCount) {
            finishRowGroup();
        }

        // convert dictionary encoded columns to direct if dictionary memory usage exceeded
        dictionaryCompressionOptimizer.optimize(bufferedBytes, stripeRowCount);

        // flush stripe if necessary
        bufferedBytes = toIntExact(columnWriters.stream().mapToLong(ColumnWriter::getBufferedBytes).sum());
        boolean dictionaryIsFull = dictionaryCompressionOptimizer.isFull(bufferedBytes);
        Optional<FlushReason> flushReason = flushPolicy.shouldFlushStripe(stripeRowCount, bufferedBytes, dictionaryIsFull);
        if (flushReason.isPresent()) {
            flushStripe(flushReason.get());
        }
        columnWritersRetainedBytes = columnWriters.stream().mapToLong(ColumnWriter::getRetainedBytes).sum();
    }

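    // Captures per-column statistics for the completed row group and resets the row group row count.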
    private void finishRowGroup()
    {
        Map<Integer, ColumnStatistics> columnStatistics = new HashMap<>();
        columnWriters.forEach(columnWriter -> columnStatistics.putAll(columnWriter.finishRowGroup()));
        recordValidation(validation -> validation.addRowGroupStatistics(columnStatistics));
        rowGroupRowCount = 0;
    }

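    // Flushes the current stripe to the data sink and, when the file is being closed, appends the file footer.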
    private void flushStripe(FlushReason flushReason)
            throws IOException
    {
        List<DataOutput> outputData = new ArrayList<>();
        long stripeStartOffset = dataSink.size();
        // add the file magic header before the first stripe (not strictly required, but nice to have)
        if (closedStripes.isEmpty()) {
            outputData.add(createDataOutput(MAGIC));
            stripeStartOffset += MAGIC.length();
        }

        flushColumnWriters(flushReason);
        try {
            // add stripe data
            outputData.addAll(bufferStripeData(stripeStartOffset, flushReason));
            rawSize += stripeRawSize;
            // if the file is being closed, add the file footer
            if (flushReason == CLOSED) {
                outputData.addAll(bufferFileFooter());
            }

            // write all data
            dataSink.write(outputData);
        }
        finally {
            // open next stripe
            columnWriters.forEach(ColumnWriter::reset);
            dictionaryCompressionOptimizer.reset();
            rowGroupRowCount = 0;
            stripeRowCount = 0;
            stripeRawSize = 0;
            bufferedBytes = toIntExact(columnWriters.stream().mapToLong(ColumnWriter::getBufferedBytes).sum());
        }
    }

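    // Finishes the open row group, runs the final dictionary optimization pass, and closes the column writers.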
    private void flushColumnWriters(FlushReason flushReason)
    {
        if (stripeRowCount == 0) {
            verify(flushReason == CLOSED, "An empty stripe is not allowed");
        }
        else {
            if (rowGroupRowCount > 0) {
                finishRowGroup();
            }

            // convert any dictionary encoded column with a low compression ratio to direct
            dictionaryCompressionOptimizer.finalOptimize(bufferedBytes);
        }

        columnWriters.forEach(ColumnWriter::close);
    }

    /**
     * Collect the data for the stripe.  This is not the actual data, but
     * rather a list of functions that know how to write the data.
     */
    private List<DataOutput> bufferStripeData(long stripeStartOffset, FlushReason flushReason)
            throws IOException
    {
        if (stripeRowCount == 0) {
            return ImmutableList.of();
        }

        List<Stream> unencryptedStreams = new ArrayList<>(columnWriters.size() * 3);
        Multimap<Integer, Stream> encryptedStreams = ArrayListMultimap.create();
        List<StreamDataOutput> indexStreams = new ArrayList<>(columnWriters.size());

        // get index streams
        long indexLength = 0;
        long offset = 0;
        int previousEncryptionGroup = -1;
        for (ColumnWriter columnWriter : columnWriters) {
            List<StreamDataOutput> streams = columnWriter.getIndexStreams(Optional.empty());
            indexStreams.addAll(streams);
            for (StreamDataOutput indexStream : streams) {
                // The ordering is critical because the stream only contains a length with no offset.
                // If the previous stream was part of a different encryption group, we need to record an offset so that the column order can be reconstructed.
                Optional<Integer> encryptionGroup = dwrfEncryptionInfo.getGroupByNodeId(indexStream.getStream().getColumn());
                if (encryptionGroup.isPresent()) {
                    Stream stream = previousEncryptionGroup == encryptionGroup.get() ? indexStream.getStream() : indexStream.getStream().withOffset(offset);
                    encryptedStreams.put(encryptionGroup.get(), stream);
                    previousEncryptionGroup = encryptionGroup.get();
                }
                else {
                    Stream stream = previousEncryptionGroup == -1 ? indexStream.getStream() : indexStream.getStream().withOffset(offset);
                    unencryptedStreams.add(stream);
                    previousEncryptionGroup = -1;
                }
                offset += indexStream.size();
                indexLength += indexStream.size();
            }
        }

        if (dwrfStripeCacheWriter.isPresent()) {
            dwrfStripeCacheWriter.get().addIndexStreams(ImmutableList.copyOf(indexStreams), indexLength);
        }

        // data streams (sorted by size)
        long dataLength = 0;
        List<StreamDataOutput> dataStreams = new ArrayList<>(columnWriters.size() * 2);
        for (ColumnWriter columnWriter : columnWriters) {
            List<StreamDataOutput> streams = columnWriter.getDataStreams();
            dataStreams.addAll(streams);
            dataLength += streams.stream()
                    .mapToLong(StreamDataOutput::size)
                    .sum();
        }

        ImmutableMap.Builder<Integer, ColumnEncoding> columnEncodingsBuilder = ImmutableMap.builder();
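        // the root struct column (node 0) is always direct encoded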
        columnEncodingsBuilder.put(0, new ColumnEncoding(DIRECT, 0));
        columnWriters.forEach(columnWriter -> columnEncodingsBuilder.putAll(columnWriter.getColumnEncodings()));
        Map<Integer, ColumnEncoding> columnEncodings = columnEncodingsBuilder.build();

        // reorder data streams
        streamLayout.reorder(dataStreams, nodeIdToColumn, columnEncodings);
        streamSizeHelper.collectStreamSizes(Iterables.concat(indexStreams, dataStreams), columnEncodings);

        // add data streams
        for (StreamDataOutput dataStream : dataStreams) {
            // The ordering is critical because the stream only contains a length with no offset.
            // If the previous stream was part of a different encryption group, we need to record an offset so that the column order can be reconstructed.
            Optional<Integer> encryptionGroup = dwrfEncryptionInfo.getGroupByNodeId(dataStream.getStream().getColumn());
            if (encryptionGroup.isPresent()) {
                Stream stream = previousEncryptionGroup == encryptionGroup.get() ? dataStream.getStream() : dataStream.getStream().withOffset(offset);
                encryptedStreams.put(encryptionGroup.get(), stream);
                previousEncryptionGroup = encryptionGroup.get();
            }
            else {
                Stream stream = previousEncryptionGroup == -1 ? dataStream.getStream() : dataStream.getStream().withOffset(offset);
                unencryptedStreams.add(stream);
                previousEncryptionGroup = -1;
            }
            offset += dataStream.size();
        }

        Map<Integer, ColumnStatistics> columnStatistics = new HashMap<>();
        columnWriters.forEach(columnWriter -> columnStatistics.putAll(columnWriter.getColumnStripeStatistics()));

        // the 0th column is a struct column for the whole row
        columnStatistics.put(0, new ColumnStatistics((long) stripeRowCount, null, stripeRawSize, null));

        Map<Integer, ColumnEncoding> unencryptedColumnEncodings = columnEncodings.entrySet().stream()
                .filter(entry -> !dwrfEncryptionInfo.getGroupByNodeId(entry.getKey()).isPresent())
                .collect(toImmutableMap(Entry::getKey, Entry::getValue));

        Map<Integer, ColumnEncoding> encryptedColumnEncodings = columnEncodings.entrySet().stream()
                .filter(entry -> dwrfEncryptionInfo.getGroupByNodeId(entry.getKey()).isPresent())
                .collect(toImmutableMap(Entry::getKey, Entry::getValue));
        List<Slice> encryptedGroups = createEncryptedGroups(encryptedStreams, encryptedColumnEncodings);

        StripeFooter stripeFooter = new StripeFooter(unencryptedStreams, unencryptedColumnEncodings, encryptedGroups);
        Slice footer = metadataWriter.writeStripeFooter(stripeFooter);
        DataOutput footerDataOutput = createDataOutput(footer);
        dwrfStripeCacheWriter.ifPresent(stripeCacheWriter -> stripeCacheWriter.addStripeFooter(createDataOutput(footer)));

        // create final stripe statistics
        StripeStatistics statistics = new StripeStatistics(toDenseList(columnStatistics, orcTypes.size()));

        recordValidation(validation -> validation.addStripeStatistics(stripeStartOffset, statistics));

        StripeInformation stripeInformation = new StripeInformation(stripeRowCount, stripeStartOffset, indexLength, dataLength, footer.length(), OptionalLong.of(stripeRawSize), dwrfEncryptionInfo.getEncryptedKeyMetadatas());
        ClosedStripe closedStripe = new ClosedStripe(stripeInformation, statistics);
        closedStripes.add(closedStripe);
        closedStripesRetainedBytes += closedStripe.getRetainedSizeInBytes();

        recordValidation(validation -> validation.addStripe(stripeInformation.getNumberOfRows()));
        stats.recordStripeWritten(
                flushPolicy.getStripeMinBytes(),
                flushPolicy.getStripeMaxBytes(),
                dictionaryMaxMemoryBytes,
                flushReason,
                dictionaryCompressionOptimizer.getDictionaryMemoryBytes(),
                stripeInformation);

        return ImmutableList.<DataOutput>builder()
                .addAll(indexStreams)
                .addAll(dataStreams)
                .add(footerDataOutput)
                .build();
    }

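    // Serializes each encryption group's streams and column encodings into an encrypted StripeEncryptionGroup slice.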
    private List<Slice> createEncryptedGroups(Multimap<Integer, Stream> encryptedStreams, Map<Integer, ColumnEncoding> encryptedColumnEncodings)
            throws IOException
    {
        ImmutableList.Builder<Slice> encryptedGroups = ImmutableList.builder();
        for (int i = 0; i < encryptedStreams.keySet().size(); i++) {
            int groupId = i;
            Map<Integer, ColumnEncoding> groupColumnEncodings = encryptedColumnEncodings.entrySet().stream()
                    .filter(entry -> dwrfEncryptionInfo.getGroupByNodeId(entry.getKey()).orElseThrow(() -> new VerifyError("missing group for encryptedColumn")) == groupId)
                    .collect(toImmutableMap(Entry::getKey, Entry::getValue));
            DwrfDataEncryptor dwrfDataEncryptor = dwrfEncryptionInfo.getEncryptorByGroupId(i);
            OrcOutputBuffer buffer = new OrcOutputBuffer(columnWriterOptions, Optional.of(dwrfDataEncryptor));
            toStripeEncryptionGroup(
                    new StripeEncryptionGroup(
                            ImmutableList.copyOf(encryptedStreams.get(i)),
                            groupColumnEncodings))
                    .writeTo(buffer);
            buffer.close();
            DynamicSliceOutput output = new DynamicSliceOutput(toIntExact(buffer.getOutputDataSize()));
            buffer.writeDataTo(output);
            encryptedGroups.add(output.slice());
        }
        return encryptedGroups.build();
    }

    @Override
    public void close()
            throws IOException
    {
        if (closed) {
            return;
        }
        closed = true;
        stats.updateSizeInBytes(-previouslyRecordedSizeInBytes);
        previouslyRecordedSizeInBytes = 0;

        flushStripe(CLOSED);

        dataSink.close();
    }

    /**
     * Collect the data for the file footer.  This is not the actual data, but
     * rather a list of functions that know how to write the data.
     */
    private List<DataOutput> bufferFileFooter()
            throws IOException
    {
        List<DataOutput> outputData = new ArrayList<>();

        Metadata metadata = new Metadata(closedStripes.stream()
                .map(ClosedStripe::getStatistics)
                .collect(toList()));
        Slice metadataSlice = metadataWriter.writeMetadata(metadata);
        outputData.add(createDataOutput(metadataSlice));

        numberOfRows = closedStripes.stream()
                .mapToLong(stripe -> stripe.getStripeInformation().getNumberOfRows())
                .sum();

        List<ColumnStatistics> fileStats = toFileStats(
                closedStripes.stream()
                        .map(ClosedStripe::getStatistics)
                        .map(StripeStatistics::getColumnStatistics)
                        .collect(toList()),
                streamSizeHelper.getNodeSizes(),
                streamSizeHelper.getMapKeySizes());
        recordValidation(validation -> validation.setFileStatistics(fileStats));

        Map<String, Slice> userMetadata = this.userMetadata.entrySet().stream()
                .collect(Collectors.toMap(Entry::getKey, entry -> utf8Slice(entry.getValue())));

        unencryptedStats = new ArrayList<>();
        Map<Integer, Map<Integer, Slice>> encryptedStats = new HashMap<>();
        addStatsRecursive(fileStats, 0, new HashMap<>(), unencryptedStats, encryptedStats);
        Optional<DwrfEncryption> dwrfEncryption;
        if (dwrfWriterEncryption.isPresent()) {
            ImmutableList.Builder<EncryptionGroup> encryptionGroupBuilder = ImmutableList.builder();
            List<WriterEncryptionGroup> writerEncryptionGroups = dwrfWriterEncryption.get().getWriterEncryptionGroups();
            for (int i = 0; i < writerEncryptionGroups.size(); i++) {
                WriterEncryptionGroup group = writerEncryptionGroups.get(i);
                Map<Integer, Slice> groupStats = encryptedStats.get(i);
                encryptionGroupBuilder.add(
                        new EncryptionGroup(
                                group.getNodes(),
                                Optional.empty(), // reader will just use key metadata from the stripe
                                group.getNodes().stream()
                                        .map(groupStats::get)
                                        .collect(toList())));
            }
            dwrfEncryption = Optional.of(
                    new DwrfEncryption(
                            dwrfWriterEncryption.get().getKeyProvider(),
                            encryptionGroupBuilder.build()));
        }
        else {
            dwrfEncryption = Optional.empty();
        }

        Optional<DwrfStripeCacheData> dwrfStripeCacheData = dwrfStripeCacheWriter.map(DwrfStripeCacheWriter::getDwrfStripeCacheData);
        Slice dwrfStripeCacheSlice = metadataWriter.writeDwrfStripeCache(dwrfStripeCacheData);
        outputData.add(createDataOutput(dwrfStripeCacheSlice));

        Optional<List<Integer>> dwrfStripeCacheOffsets = dwrfStripeCacheWriter.map(DwrfStripeCacheWriter::getOffsets);
        Footer footer = new Footer(
                numberOfRows,
                rowGroupMaxRowCount,
                OptionalLong.of(rawSize),
                closedStripes.stream()
                        .map(ClosedStripe::getStripeInformation)
                        .collect(toList()),
                orcTypes,
                ImmutableList.copyOf(unencryptedStats),
                userMetadata,
                dwrfEncryption,
                dwrfStripeCacheOffsets);

        closedStripes.clear();
        closedStripesRetainedBytes = 0;

        Slice footerSlice = metadataWriter.writeFooter(footer);
        outputData.add(createDataOutput(footerSlice));

        recordValidation(validation -> validation.setVersion(metadataWriter.getOrcMetadataVersion()));
        Slice postscriptSlice = metadataWriter.writePostscript(
                footerSlice.length(),
                metadataSlice.length(),
                columnWriterOptions.getCompressionKind(),
                columnWriterOptions.getCompressionMaxBufferSize(),
                dwrfStripeCacheData);
        outputData.add(createDataOutput(postscriptSlice));
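        // the final byte of the file holds the postscript length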
        outputData.add(createDataOutput(Slices.wrappedBuffer((byte) postscriptSlice.length())));
        return outputData;
    }

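    // Walks the ORC type tree in node order, routing statistics for encrypted subtrees into per-group
    // encrypted stats (keyed by the group's root node) and emitting redacted placeholders into the
    // unencrypted list.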
    private void addStatsRecursive(List<ColumnStatistics> allStats, int index, Map<Integer, List<ColumnStatistics>> nodeAndSubNodeStats, List<ColumnStatistics> unencryptedStats, Map<Integer, Map<Integer, Slice>> encryptedStats)
            throws IOException
    {
        if (allStats.isEmpty()) {
            return;
        }
        ColumnStatistics columnStatistics = allStats.get(index);
        if (dwrfEncryptionInfo.getGroupByNodeId(index).isPresent()) {
            int group = dwrfEncryptionInfo.getGroupByNodeId(index).get();
            boolean isRootNode = dwrfWriterEncryption.get().getWriterEncryptionGroups().get(group).getNodes().contains(index);
            verify((isRootNode && nodeAndSubNodeStats.isEmpty()) || (nodeAndSubNodeStats.size() == 1 && nodeAndSubNodeStats.get(group) != null),
                    "nodeAndSubNodeStats should only be present for subnodes of a group");
            nodeAndSubNodeStats.computeIfAbsent(group, x -> new ArrayList<>()).add(columnStatistics);
            unencryptedStats.add(new ColumnStatistics(
                    columnStatistics.getNumberOfValues(),
                    null,
                    columnStatistics.hasRawSize() ? columnStatistics.getRawSize() : null,
                    columnStatistics.hasStorageSize() ? columnStatistics.getStorageSize() : null));
            for (Integer fieldIndex : orcTypes.get(index).getFieldTypeIndexes()) {
                addStatsRecursive(allStats, fieldIndex, nodeAndSubNodeStats, unencryptedStats, encryptedStats);
            }
            if (isRootNode) {
                Slice encryptedFileStatistics = toEncryptedFileStatistics(nodeAndSubNodeStats.get(group), group);
                encryptedStats.computeIfAbsent(group, x -> new HashMap<>()).put(index, encryptedFileStatistics);
            }
        }
        else {
            unencryptedStats.add(columnStatistics);
            for (Integer fieldIndex : orcTypes.get(index).getFieldTypeIndexes()) {
                addStatsRecursive(allStats, fieldIndex, new HashMap<>(), unencryptedStats, encryptedStats);
            }
        }
    }

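    // Serializes and encrypts the file statistics for a single encryption group.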
    private Slice toEncryptedFileStatistics(List<ColumnStatistics> statsFromRoot, int groupId)
            throws IOException
    {
        DwrfProto.FileStatistics fileStatistics = toFileStatistics(statsFromRoot);
        DwrfDataEncryptor dwrfDataEncryptor = dwrfEncryptionInfo.getEncryptorByGroupId(groupId);
        OrcOutputBuffer buffer = new OrcOutputBuffer(columnWriterOptions, Optional.of(dwrfDataEncryptor));
        fileStatistics.writeTo(buffer);
        buffer.close();
        DynamicSliceOutput output = new DynamicSliceOutput(toIntExact(buffer.getOutputDataSize()));
        buffer.writeDataTo(output);
        return output.slice();
    }

    private void recordValidation(Consumer<OrcWriteValidationBuilder> task)
    {
        if (validationBuilder != null) {
            task.accept(validationBuilder);
        }
    }

    public void validate(OrcDataSource input)
            throws OrcCorruptionException
    {
        checkState(validationBuilder != null, "validation is not enabled");
        ImmutableMap.Builder<Integer, Slice> intermediateKeyMetadata = ImmutableMap.builder();
        if (dwrfWriterEncryption.isPresent()) {
            List<WriterEncryptionGroup> writerEncryptionGroups = dwrfWriterEncryption.get().getWriterEncryptionGroups();
            for (int i = 0; i < writerEncryptionGroups.size(); i++) {
                for (Integer node : writerEncryptionGroups.get(i).getNodes()) {
                    intermediateKeyMetadata.put(node, writerEncryptionGroups.get(i).getIntermediateKeyMetadata());
                }
            }
        }

        validateFile(
                validationBuilder.build(),
                input,
                types,
                hiveStorageTimeZone,
                orcEncoding,
                OrcReaderOptions.builder()
                        .withMaxMergeDistance(new DataSize(1, MEGABYTE))
                        .withTinyStripeThreshold(new DataSize(8, MEGABYTE))
                        .withMaxBlockSize(new DataSize(16, MEGABYTE))
                        .build(),
                dwrfEncryptionProvider,
                DwrfKeyProvider.of(intermediateKeyMetadata.build()));
    }

    public long getFileRowCount()
    {
        checkState(closed, "File row count is not available until the writing has finished");
        return numberOfRows;
    }

    public List<ColumnStatistics> getFileStats()
    {
        checkState(closed, "File statistics are not available until the writing has finished");
        return unencryptedStats;
    }

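    // Converts a node-id-keyed map into a list ordered by node id.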
    private static <T> List<T> toDenseList(Map<Integer, T> data, int expectedSize)
    {
        checkArgument(data.size() == expectedSize, "data size does not match the expected size");
        if (expectedSize == 0) {
            return ImmutableList.of();
        }

        List<Integer> sortedKeys = new ArrayList<>(data.keySet());
        Collections.sort(sortedKeys);

        ImmutableList.Builder<T> denseList = ImmutableList.builderWithExpectedSize(expectedSize);
        for (Integer key : sortedKeys) {
            denseList.add(data.get(key));
        }
        return denseList.build();
    }

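    // Merges per-stripe column statistics into file-level statistics, folding in the per-node storage
    // sizes and flat map key sizes collected by the StreamSizeHelper.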
    private static List<ColumnStatistics> toFileStats(List<List<ColumnStatistics>> stripes, Int2LongMap nodeSizes, Int2ObjectMap<Object2LongMap<DwrfProto.KeyInfo>> mapKeySizes)
    {
        if (stripes.isEmpty()) {
            return ImmutableList.of();
        }

        int columnCount = stripes.get(0).size();
        checkArgument(stripes.stream().allMatch(stripe -> columnCount == stripe.size()), "all stripes must have the same column count");

        ImmutableList.Builder<ColumnStatistics> fileStats = ImmutableList.builder();
        for (int i = 0; i < columnCount; i++) {
            int column = i;
            List<ColumnStatistics> stripeColumnStats = stripes.stream()
                    .map(stripe -> stripe.get(column))
                    .collect(toList());
            long storageSize = nodeSizes.getOrDefault(column, 0L);
            Object2LongMap<DwrfProto.KeyInfo> keySizes = mapKeySizes.get(column);
            ColumnStatistics columnStats = mergeColumnStatistics(stripeColumnStats, storageSize, keySizes);
            fileStats.add(columnStats);
        }
        return fileStats.build();
    }

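    // Metadata retained for a stripe that has already been flushed to the data sink.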
    private static class ClosedStripe
    {
        private static final int INSTANCE_SIZE = ClassLayout.parseClass(ClosedStripe.class).instanceSize() + ClassLayout.parseClass(StripeInformation.class).instanceSize();

        private final StripeInformation stripeInformation;
        private final StripeStatistics statistics;

        public ClosedStripe(StripeInformation stripeInformation, StripeStatistics statistics)
        {
            this.stripeInformation = requireNonNull(stripeInformation, "stripeInformation is null");
            this.statistics = requireNonNull(statistics, "statistics is null");
        }

        public StripeInformation getStripeInformation()
        {
            return stripeInformation;
        }

        public StripeStatistics getStatistics()
        {
            return statistics;
        }

        public long getRetainedSizeInBytes()
        {
            return INSTANCE_SIZE + statistics.getRetainedSizeInBytes();
        }
    }
}