// TestDecryption.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc;

import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.orc.cache.StorageOrcFileTailSource;
import com.facebook.presto.orc.metadata.DwrfEncryption;
import com.facebook.presto.orc.metadata.EncryptionGroup;
import com.facebook.presto.orc.metadata.Footer;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.StripeInformation;
import com.facebook.presto.orc.metadata.statistics.ColumnStatistics;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.AbstractOrcRecordReader.getDecryptionKeyMetadata;
import static com.facebook.presto.orc.AbstractTestOrcReader.intsBetween;
import static com.facebook.presto.orc.DwrfEncryptionInfo.createNodeToGroupMap;
import static com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS;
import static com.facebook.presto.orc.NoopOrcAggregatedMemoryContext.NOOP_ORC_AGGREGATED_MEMORY_CONTEXT;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE;
import static com.facebook.presto.orc.OrcReader.validateEncryption;
import static com.facebook.presto.orc.OrcTester.HIVE_STORAGE_TIME_ZONE;
import static com.facebook.presto.orc.OrcTester.MAX_BLOCK_SIZE;
import static com.facebook.presto.orc.OrcTester.assertFileContentsPresto;
import static com.facebook.presto.orc.OrcTester.rowType;
import static com.facebook.presto.orc.OrcTester.writeOrcColumnsPresto;
import static com.facebook.presto.orc.StripeReader.getDiskRanges;
import static com.facebook.presto.orc.metadata.ColumnEncoding.DEFAULT_SEQUENCE_ID;
import static com.facebook.presto.orc.metadata.CompressionKind.ZSTD;
import static com.facebook.presto.orc.metadata.KeyProvider.UNKNOWN;
import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.INT;
import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.LIST;
import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.MAP;
import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.STRUCT;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.DATA;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.ROW_INDEX;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

public class TestDecryption
{
    // Two distinct key metadata fixtures. The charset is given explicitly so the resulting bytes
    // never depend on the JVM's default charset.
    private static final List<byte[]> A_KEYS = ImmutableList.of("key1a".getBytes(UTF_8), "key2a".getBytes(UTF_8));
    private static final List<byte[]> B_KEYS = ImmutableList.of("key1b".getBytes(UTF_8), "key2b".getBytes(UTF_8));

    // Stripe fixtures that are identical except for the key metadata list passed as the last argument.
    private static final StripeInformation A_STRIPE = new StripeInformation(1, 2, 3, 4, 5, OptionalLong.empty(), A_KEYS);
    private static final StripeInformation NO_KEYS_STRIPE = new StripeInformation(1, 2, 3, 4, 5, OptionalLong.empty(), ImmutableList.of());
    private static final StripeInformation B_STRIPE = new StripeInformation(1, 2, 3, 4, 5, OptionalLong.empty(), B_KEYS);

    // Type fixtures the tests assemble into flat type lists. The integer list of each compound type
    // holds the positions of its children inside such a list (the tests read it back through
    // getFieldTypeIndexes()).
    private static final OrcType ROW_TYPE = new OrcType(STRUCT, ImmutableList.of(1, 2, 4, 7), ImmutableList.of("col_int", "col_list", "col_map", "col_row"), Optional.empty(), Optional.empty(), Optional.empty());
    private static final OrcType ROW_TYPE2 = new OrcType(STRUCT, ImmutableList.of(8), ImmutableList.of("sub_row1"), Optional.empty(), Optional.empty(), Optional.empty());
    private static final OrcType ROW_TYPE3 = new OrcType(STRUCT, ImmutableList.of(9, 10), ImmutableList.of("sub_int1", "sub_int2"), Optional.empty(), Optional.empty(), Optional.empty());
    private static final OrcType INT_TYPE = new OrcType(INT, ImmutableList.of(), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty());
    private static final OrcType LIST_TYPE = new OrcType(LIST, ImmutableList.of(3), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty());
    private static final OrcType MAP_TYPE = new OrcType(MAP, ImmutableList.of(5, 6), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty());

    @Test
    public void testValidateEncrypted()
    {
        // Two encryption groups; A_STRIPE, the first stripe of the footer, carries two keys.
        EncryptionGroup firstGroup = new EncryptionGroup(
                ImmutableList.of(1, 3),
                Optional.empty(),
                ImmutableList.of(Slices.EMPTY_SLICE, Slices.EMPTY_SLICE));
        EncryptionGroup secondGroup = new EncryptionGroup(
                ImmutableList.of(4),
                Optional.empty(),
                ImmutableList.of(Slices.EMPTY_SLICE));
        DwrfEncryption dwrfEncryption = new DwrfEncryption(UNKNOWN, ImmutableList.of(firstGroup, secondGroup));

        Footer footer = createFooterWithEncryption(
                ImmutableList.of(A_STRIPE, NO_KEYS_STRIPE),
                Optional.of(dwrfEncryption));

        // the test passes as long as validation does not throw
        validateEncryption(footer, new OrcDataSourceId("1"));
    }

    @Test
    public void testValidateUnencrypted()
    {
        // no encryption section and a stripe without keys: validation must not throw
        List<StripeInformation> stripes = ImmutableList.of(NO_KEYS_STRIPE);
        Optional<DwrfEncryption> noEncryption = Optional.empty();
        validateEncryption(createFooterWithEncryption(stripes, noEncryption), new OrcDataSourceId("1"));
    }

    @Test(expectedExceptions = OrcCorruptionException.class)
    public void testValidateMissingStripeKeys()
    {
        // the footer declares two encryption groups ...
        ImmutableList.Builder<EncryptionGroup> groups = ImmutableList.builder();
        groups.add(new EncryptionGroup(
                ImmutableList.of(1, 3),
                Optional.empty(),
                ImmutableList.of(Slices.EMPTY_SLICE, Slices.EMPTY_SLICE)));
        groups.add(new EncryptionGroup(
                ImmutableList.of(4),
                Optional.empty(),
                ImmutableList.of(Slices.EMPTY_SLICE)));
        Optional<DwrfEncryption> encryption = Optional.of(new DwrfEncryption(UNKNOWN, groups.build()));

        // ... while its only stripe carries no keys at all
        List<StripeInformation> stripes = ImmutableList.of(NO_KEYS_STRIPE);
        validateEncryption(createFooterWithEncryption(stripes, encryption), new OrcDataSourceId("1"));
    }

    @Test(expectedExceptions = OrcCorruptionException.class)
    public void testValidateMismatchedGroups()
    {
        // a single encryption group is declared, whereas A_STRIPE carries two keys
        EncryptionGroup onlyGroup = new EncryptionGroup(
                ImmutableList.of(1, 3),
                Optional.empty(),
                ImmutableList.of(Slices.EMPTY_SLICE, Slices.EMPTY_SLICE));
        DwrfEncryption dwrfEncryption = new DwrfEncryption(UNKNOWN, ImmutableList.of(onlyGroup));

        Footer footer = createFooterWithEncryption(ImmutableList.of(A_STRIPE), Optional.of(dwrfEncryption));
        validateEncryption(footer, new OrcDataSourceId("1"));
    }

    /**
     * Builds a footer over a fixed seven-entry type list using the given stripes and encryption section.
     * The two collection arguments between the type list and the encryption section are passed empty,
     * and the trailing optional is left absent.
     */
    private Footer createFooterWithEncryption(List<StripeInformation> stripes, Optional<DwrfEncryption> encryption)
    {
        List<OrcType> footerTypes = ImmutableList.<OrcType>builder()
                .add(ROW_TYPE)
                .add(INT_TYPE)
                .add(LIST_TYPE)
                .add(INT_TYPE)
                .add(MAP_TYPE)
                .add(INT_TYPE)
                .add(INT_TYPE)
                .build();

        Footer footer = new Footer(
                1,
                2,
                OptionalLong.empty(),
                stripes,
                footerTypes,
                ImmutableList.of(),
                ImmutableMap.of(),
                encryption,
                Optional.empty());
        return footer;
    }

    @Test
    public void testCreateNodeToGroupMap()
    {
        // position in this list is the node id: 4 is the map (children 5, 6),
        // 7 is a struct wrapping struct 8, which in turn holds ints 9 and 10
        List<OrcType> nodeTypes = ImmutableList.of(ROW_TYPE, INT_TYPE, LIST_TYPE, INT_TYPE, MAP_TYPE, INT_TYPE, INT_TYPE, ROW_TYPE2, ROW_TYPE3, INT_TYPE, INT_TYPE);

        List<Integer> groupZero = ImmutableList.of(1, 3);
        List<Integer> groupOne = ImmutableList.of(4);
        List<Integer> groupTwo = ImmutableList.of(7);
        List<List<Integer>> groups = ImmutableList.of(groupZero, groupOne, groupTwo);

        // every listed node and all of its descendants are expected to map to the index of their group
        ImmutableMap.Builder<Integer, Integer> expectedBuilder = ImmutableMap.builder();
        expectedBuilder.put(1, 0);
        expectedBuilder.put(3, 0);
        expectedBuilder.put(4, 1);
        expectedBuilder.put(5, 1);
        expectedBuilder.put(6, 1);
        expectedBuilder.put(7, 2);
        expectedBuilder.put(8, 2);
        expectedBuilder.put(9, 2);
        expectedBuilder.put(10, 2);
        Map<Integer, Integer> expected = expectedBuilder.build();

        Map<Integer, Integer> actual = createNodeToGroupMap(groups, nodeTypes);
        assertEquals(actual, expected);
    }

    @Test
    public void testGetStripeDecryptionKeys()
    {
        List<StripeInformation> stripes = ImmutableList.of(A_STRIPE, NO_KEYS_STRIPE, B_STRIPE, NO_KEYS_STRIPE);

        // a stripe without keys of its own is expected to resolve to the keys of the closest preceding stripe that has some
        List<List<byte[]>> expectedKeysPerStripe = ImmutableList.of(A_KEYS, A_KEYS, B_KEYS, B_KEYS);

        for (int stripeIndex = 0; stripeIndex < stripes.size(); stripeIndex++) {
            assertEquals(getDecryptionKeyMetadata(stripeIndex, stripes), expectedKeysPerStripe.get(stripeIndex));
        }
    }

    @Test
    public void testGetStripeDecryptionKeysUnencrypted()
    {
        // when no stripe carries keys, every lookup is expected to yield an empty list
        List<StripeInformation> stripes = ImmutableList.of(NO_KEYS_STRIPE, NO_KEYS_STRIPE);
        for (int stripeIndex = 0; stripeIndex < stripes.size(); stripeIndex++) {
            assertEquals(getDecryptionKeyMetadata(stripeIndex, stripes), ImmutableList.of());
        }
    }

    @Test
    public void testGetDiskRanges()
    {
        // Expected layout: twelve 5-byte ranges, ROW_INDEX streams of columns 0-5 first, DATA streams after them.
        ImmutableMap.Builder<StreamId, DiskRange> expectedRanges = ImmutableMap.builder();
        expectedRanges.put(new StreamId(0, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(0, 5));
        expectedRanges.put(new StreamId(1, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(5, 5));
        expectedRanges.put(new StreamId(2, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(10, 5));
        expectedRanges.put(new StreamId(3, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(15, 5));
        expectedRanges.put(new StreamId(4, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(20, 5));
        expectedRanges.put(new StreamId(5, DEFAULT_SEQUENCE_ID, ROW_INDEX), new DiskRange(25, 5));
        expectedRanges.put(new StreamId(0, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(30, 5));
        expectedRanges.put(new StreamId(1, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(35, 5));
        expectedRanges.put(new StreamId(2, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(40, 5));
        expectedRanges.put(new StreamId(3, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(45, 5));
        expectedRanges.put(new StreamId(4, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(50, 5));
        expectedRanges.put(new StreamId(5, DEFAULT_SEQUENCE_ID, DATA), new DiskRange(55, 5));
        Map<StreamId, DiskRange> expected = expectedRanges.build();

        // Streams built with the shorter constructor carry no explicit offset; the expected ranges place
        // each of them at 0 or directly behind the stream preceding it in the same list.
        List<Stream> plainStreams = ImmutableList.of(
                new Stream(3, ROW_INDEX, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(15L)),
                new Stream(4, DEFAULT_SEQUENCE_ID, ROW_INDEX, 5, true),
                new Stream(3, DATA, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(45L)),
                new Stream(4, DEFAULT_SEQUENCE_ID, DATA, 5, true));

        List<Stream> firstGroupStreams = ImmutableList.of(
                new Stream(0, DEFAULT_SEQUENCE_ID, ROW_INDEX, 5, true),
                new Stream(5, ROW_INDEX, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(25L)),
                new Stream(0, DATA, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(30L)),
                new Stream(5, DATA, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(55L)));

        List<Stream> secondGroupStreams = ImmutableList.of(
                new Stream(1, ROW_INDEX, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(5L)),
                new Stream(2, DEFAULT_SEQUENCE_ID, ROW_INDEX, 5, true),
                new Stream(1, DATA, 5, true, DEFAULT_SEQUENCE_ID, Optional.of(35L)),
                new Stream(2, DEFAULT_SEQUENCE_ID, DATA, 5, true));

        Map<StreamId, DiskRange> actual = getDiskRanges(ImmutableList.of(plainStreams, firstGroupStreams, secondGroupStreams));
        assertEquals(actual, expected);
    }

    @Test
    public void testMultipleEncryptionGroupsRowType()
            throws Exception
    {
        // a single row column with three BIGINT fields
        Type structType = rowType(BIGINT, BIGINT, BIGINT);
        List<Type> types = ImmutableList.of(structType);
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // nodes 2 and 3 are encrypted separately, each with its own intermediate key
        Slice firstKey = Slices.utf8Slice("iek1");
        Slice secondKey = Slices.utf8Slice("iek2");
        WriterEncryptionGroup firstGroup = new WriterEncryptionGroup(ImmutableList.of(2), firstKey);
        WriterEncryptionGroup secondGroup = new WriterEncryptionGroup(ImmutableList.of(3), secondKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(firstGroup, secondGroup));

        List<Long> longs = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());
        List<?> structValues = longs.stream()
                .map(OrcTester::toHiveStruct)
                .collect(Collectors.toList());
        List<List<?>> values = ImmutableList.of(structValues);

        // the reader holds both keys, so what was written must be read back unchanged
        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(2, firstKey, 3, secondKey),
                ImmutableMap.of(0, structType),
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testSingleEncryptionGroupRowType()
            throws Exception
    {
        // a single row column with three BIGINT fields
        Type structType = rowType(BIGINT, BIGINT, BIGINT);
        List<Type> types = ImmutableList.of(structType);
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // one encryption group, declared on node 1 only
        Slice intermediateKey = Slices.utf8Slice("iek1");
        WriterEncryptionGroup onlyGroup = new WriterEncryptionGroup(ImmutableList.of(1), intermediateKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(onlyGroup));

        List<Long> longs = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());
        List<?> structValues = longs.stream()
                .map(OrcTester::toHiveStruct)
                .collect(Collectors.toList());
        List<List<?>> values = ImmutableList.of(structValues);

        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(1, intermediateKey),
                ImmutableMap.of(0, structType),
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testEncryptionGroupWithMultipleTypes()
            throws Exception
    {
        List<Type> types = ImmutableList.of(BIGINT, VARCHAR);
        Map<Integer, Type> includedColumns = ImmutableMap.of(0, BIGINT, 1, VARCHAR);
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // a single group spanning nodes 1 and 2, i.e. columns of two different types share one key
        Slice sharedKey = Slices.utf8Slice("iek1");
        WriterEncryptionGroup sharedGroup = new WriterEncryptionGroup(ImmutableList.of(1, 2), sharedKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(sharedGroup));

        List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(toImmutableList());
        List<String> varcharColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(String::valueOf)
                .collect(toImmutableList());
        List<List<?>> values = ImmutableList.of(bigintColumn, varcharColumn);

        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(1, sharedKey, 2, sharedKey),
                includedColumns,
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testEncryptionGroupWithReversedOrderNodes()
            throws Exception
    {
        List<Type> types = ImmutableList.of(BIGINT, VARCHAR);
        Map<Integer, Type> includedColumns = ImmutableMap.of(0, BIGINT, 1, VARCHAR);
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // same nodes as a (1, 2) group, but deliberately listed in descending order
        Slice sharedKey = Slices.utf8Slice("iek1");
        WriterEncryptionGroup reversedGroup = new WriterEncryptionGroup(ImmutableList.of(2, 1), sharedKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(reversedGroup));

        List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(toImmutableList());
        List<String> varcharColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(String::valueOf)
                .collect(toImmutableList());
        List<List<?>> values = ImmutableList.of(bigintColumn, varcharColumn);

        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(1, sharedKey, 2, sharedKey),
                includedColumns,
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testMultipleEncryptionGroupsMultipleColumns()
            throws Exception
    {
        // eight columns: one row column followed by a mix of BIGINT and VARCHAR columns
        Type structType = rowType(BIGINT, BIGINT, BIGINT);
        List<Type> types = ImmutableList.of(structType, BIGINT, VARCHAR, VARCHAR, BIGINT, VARCHAR, BIGINT, BIGINT);
        Map<Integer, Type> includedColumns = ImmutableMap.<Integer, Type>builder()
                .put(0, structType)
                .put(1, BIGINT)
                .put(2, VARCHAR)
                .put(3, VARCHAR)
                .put(4, BIGINT)
                .put(5, VARCHAR)
                .put(6, BIGINT)
                .put(7, BIGINT)
                .build();
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // two groups, each covering several nodes
        Slice firstKey = Slices.utf8Slice("iek1");
        Slice secondKey = Slices.utf8Slice("iek2");
        WriterEncryptionGroup firstGroup = new WriterEncryptionGroup(ImmutableList.of(1, 8), firstKey);
        WriterEncryptionGroup secondGroup = new WriterEncryptionGroup(ImmutableList.of(5, 9, 10), secondKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(firstGroup, secondGroup));

        List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());
        List<String> varcharColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(String::valueOf)
                .collect(Collectors.toList());
        List<?> structColumn = bigintColumn.stream()
                .map(OrcTester::toHiveStruct)
                .collect(Collectors.toList());

        // column data in the same order as the type list
        List<List<?>> values = ImmutableList.of(
                structColumn,
                bigintColumn,
                varcharColumn,
                varcharColumn,
                bigintColumn,
                varcharColumn,
                bigintColumn,
                bigintColumn);

        // the reader is handed a key for every node named in either group
        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(1, firstKey, 5, secondKey, 8, firstKey, 9, secondKey, 10, secondKey),
                includedColumns,
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testEncryptionMultipleColumns()
            throws Exception
    {
        List<Type> types = ImmutableList.of(BIGINT, BIGINT);
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        // both columns hold the same data
        List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());
        List<List<?>> values = ImmutableList.of(bigintColumn, bigintColumn);

        // node 1 and node 2 are each encrypted in a group of their own
        Slice firstKey = Slices.utf8Slice("iek1");
        Slice secondKey = Slices.utf8Slice("iek2");
        WriterEncryptionGroup firstGroup = new WriterEncryptionGroup(ImmutableList.of(1), firstKey);
        WriterEncryptionGroup secondGroup = new WriterEncryptionGroup(ImmutableList.of(2), secondKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(firstGroup, secondGroup));

        Map<Integer, Slice> readerKeys = ImmutableMap.of(1, firstKey, 2, secondKey);
        Map<Integer, Type> includedColumns = ImmutableMap.of(0, BIGINT, 1, BIGINT);
        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                readerKeys,
                includedColumns,
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testSkipFirstStripe()
            throws Exception
    {
        // every size knob of the data source and of the reader options is set to 1MB
        DataSize oneMegabyte = new DataSize(1, MEGABYTE);

        OrcDataSource orcDataSource = new FileOrcDataSource(
                OrcReaderTestingUtils.getResourceFile("encrypted_2splits.dwrf"),
                oneMegabyte,
                oneMegabyte,
                oneMegabyte,
                true);

        OrcReaderOptions readerOptions = OrcReaderOptions.builder()
                .withMaxMergeDistance(oneMegabyte)
                .withTinyStripeThreshold(oneMegabyte)
                .withMaxBlockSize(MAX_BLOCK_SIZE)
                .build();

        OrcReader orcReader = new OrcReader(
                orcDataSource,
                DWRF,
                new StorageOrcFileTailSource(),
                new StorageStripeMetadataSource(),
                NOOP_ORC_AGGREGATED_MEMORY_CONTEXT,
                readerOptions,
                false,
                new DwrfEncryptionProvider(new UnsupportedEncryptionLibrary(), new TestingPlainKeyEncryptionLibrary()),
                DwrfKeyProvider.of(ImmutableMap.of(0, Slices.utf8Slice("key"))),
                new RuntimeStats());

        // the read range starts at offset 10 rather than 0; going by the test name this is meant to
        // leave the first stripe out, and the only row expected back is the value 1
        int offset = 10;
        List<Type> expectedTypes = ImmutableList.of(BIGINT);
        try (OrcSelectiveRecordReader recordReader = getSelectiveRecordReader(orcDataSource, orcReader, offset)) {
            assertFileContentsPresto(
                    expectedTypes,
                    recordReader,
                    ImmutableList.of(ImmutableList.of(1L)),
                    ImmutableList.of(0));
        }
    }

    @Test(expectedExceptions = OrcPermissionsException.class)
    public void testPermissionErrorForEncryptedWithoutKeys()
            throws Exception
    {
        List<Type> types = ImmutableList.of(rowType(BIGINT, BIGINT, BIGINT));
        List<Integer> outputColumns = IntStream.range(0, types.size())
                .boxed()
                .collect(toImmutableList());

        List<Long> longs = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());
        List<?> structValues = longs.stream()
                .map(OrcTester::toHiveStruct)
                .collect(Collectors.toList());
        List<List<?>> values = ImmutableList.of(structValues);

        // the encryption group is declared on node 0
        Slice writerKey = Slices.utf8Slice("iek");
        WriterEncryptionGroup group = new WriterEncryptionGroup(ImmutableList.of(0), writerKey);
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(group));

        // the reader is given no keys at all, so the round trip is expected to fail with a permissions error
        testDecryptionRoundTrip(
                types,
                values,
                values,
                Optional.of(writerEncryption),
                ImmutableMap.of(),
                ImmutableMap.of(0, rowType(BIGINT, BIGINT, BIGINT)),
                ImmutableMap.of(),
                outputColumns);
    }

    @Test
    public void testReadPermittedColumnsWithoutAllKeys()
            throws Exception
    {
        Type structType = rowType(BIGINT, BIGINT, BIGINT);
        List<Type> types = ImmutableList.of(structType, BIGINT, BIGINT, BIGINT);

        // three groups are written, but the reader below only receives the keys for nodes 2 and 5
        Slice firstKey = Slices.utf8Slice("iek1");
        Slice secondKey = Slices.utf8Slice("iek2");
        Slice thirdKey = Slices.utf8Slice("iek3");
        DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(
                UNKNOWN,
                ImmutableList.of(
                        new WriterEncryptionGroup(ImmutableList.of(2), firstKey),
                        new WriterEncryptionGroup(ImmutableList.of(3, 6), secondKey),
                        new WriterEncryptionGroup(ImmutableList.of(5), thirdKey)));
        Map<Integer, Slice> readerKeys = ImmutableMap.of(2, firstKey, 5, thirdKey);

        List<Long> longs = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
                .map(Number::longValue)
                .collect(Collectors.toList());

        // all four columns are written in full
        List<?> writtenStructs = longs.stream()
                .map(OrcTester::toHiveStruct)
                .collect(Collectors.toList());
        List<List<?>> writtenValues = ImmutableList.of(writtenStructs, longs, longs, longs);

        // only columns 0, 1 and 3 are read back, and the middle field of each struct is expected to be null
        List<?> readStructs = longs.stream()
                .map(value -> asList(value, null, value))
                .collect(Collectors.toList());
        List<List<?>> readValues = ImmutableList.of(readStructs, longs, longs);

        // restrict the struct column to its first and last field
        List<Subfield> requestedSubfields = ImmutableList.of(new Subfield("c.field_0"), new Subfield("c.field_2"));

        testDecryptionRoundTrip(
                types,
                writtenValues,
                readValues,
                Optional.of(writerEncryption),
                readerKeys,
                ImmutableMap.of(0, structType, 1, BIGINT, 3, BIGINT),
                ImmutableMap.of(0, requestedSubfields),
                ImmutableList.of(0, 1, 3));
    }

    @Test
    public void testReadsEmptyFile()
            throws Exception
    {
        // two BIGINT columns without a single row
        List<Type> types = ImmutableList.of(BIGINT, BIGINT);
        List<List<?>> noRows = ImmutableList.of(ImmutableList.of(), ImmutableList.of());
        List<Integer> outputColumns = ImmutableList.of(0, 1);

        Map<Integer, Slice> readerKeys = ImmutableMap.of(
                1, Slices.utf8Slice("iek1"),
                2, Slices.utf8Slice("iek2"));
        Map<Integer, Type> includedColumns = ImmutableMap.of(0, BIGINT, 1, BIGINT);

        // empty files don't specify encryption groups
        Optional<DwrfWriterEncryption> noWriterEncryption = Optional.empty();
        testDecryptionRoundTrip(
                types,
                noRows,
                noRows,
                noWriterEncryption,
                readerKeys,
                includedColumns,
                ImmutableMap.of(),
                outputColumns);
    }

    /**
     * Writes the given columns to a temporary ZSTD-compressed DWRF file, reads the file back with the
     * supplied reader keys and column selection, compares the result with {@code readValues}, and
     * finally validates the file-level statistics.
     *
     * @param types types of the written columns
     * @param writtenValues per-column values handed to the writer
     * @param readValues per-column values expected from the reader
     * @param dwrfWriterEncryption encryption settings for the writer, if any
     * @param readerIntermediateKeys intermediate keys made available to the reader
     * @param includedColumns columns requested from the reader, with their types
     * @param requiredSubfields subfields requested per column
     * @param outputColumns columns the reader should output
     */
    private static void testDecryptionRoundTrip(
            List<Type> types,
            List<List<?>> writtenValues,
            List<List<?>> readValues,
            Optional<DwrfWriterEncryption> dwrfWriterEncryption,
            Map<Integer, Slice> readerIntermediateKeys,
            Map<Integer, Type> includedColumns,
            Map<Integer, List<Subfield>> requiredSubfields,
            List<Integer> outputColumns)
            throws Exception
    {
        OrcTester.Format writeFormat = OrcTester.Format.DWRF;
        OrcEncoding readEncoding = DWRF;

        try (TempFile tempFile = new TempFile()) {
            // write
            writeOrcColumnsPresto(
                    tempFile.getFile(),
                    writeFormat,
                    ZSTD,
                    dwrfWriterEncryption,
                    types,
                    writtenValues,
                    NOOP_WRITER_STATS);

            // read back and compare
            assertFileContentsPresto(
                    types,
                    tempFile.getFile(),
                    readValues,
                    readEncoding,
                    OrcPredicate.TRUE,
                    Optional.empty(),
                    ImmutableList.of(),
                    ImmutableMap.of(),
                    requiredSubfields,
                    readerIntermediateKeys,
                    includedColumns,
                    outputColumns);

            // check the statistics stored in the footer
            validateFileStatistics(tempFile, dwrfWriterEncryption, readerIntermediateKeys);
        }
    }

    /**
     * Checks the file-level column statistics exposed by readers that hold different sets of intermediate keys.
     * <p>
     * A file without stripes must have no file statistics. Otherwise, when the file was written with
     * encryption, a reader is created for every subset of {@code readerIntermediateKeys} and compared,
     * node by node, against a reader that holds no keys:
     * <ul>
     * <li>encrypted nodes must expose no type-specific statistics to the key-less reader;</li>
     * <li>unencrypted nodes must expose identical, type-matching statistics to both readers;</li>
     * <li>nodes under a key the reader holds must expose statistics matching their type.</li>
     * </ul>
     *
     * @param tempFile file previously written by the test
     * @param dwrfWriterEncryption encryption settings used by the writer; nothing beyond the stripe check is verified when absent
     * @param readerIntermediateKeys node to intermediate key mapping; every subset of it is exercised
     * @throws IOException declared for the reader factory calls
     */
    private static void validateFileStatistics(TempFile tempFile, Optional<DwrfWriterEncryption> dwrfWriterEncryption, Map<Integer, Slice> readerIntermediateKeys)
            throws IOException
    {
        OrcReader readerNoKeys = OrcTester.createCustomOrcReader(tempFile, DWRF, false, ImmutableMap.of());
        if (readerNoKeys.getFooter().getStripes().isEmpty()) {
            // files w/o stripes don't have stats
            assertEquals(readerNoKeys.getFooter().getFileStats().size(), 0);
            return;
        }

        if (!dwrfWriterEncryption.isPresent()) {
            return;
        }

        List<OrcType> types = readerNoKeys.getTypes();
        List<ColumnStatistics> fileStatsNoKey = readerNoKeys.getFooter().getFileStats();
        assertEquals(fileStatsNoKey.size(), types.size());

        // every node named in an encryption group, together with its whole subtree
        Set<Integer> allEncryptedNodes = dwrfWriterEncryption.get().getWriterEncryptionGroups().stream()
                .flatMap(group -> group.getNodes().stream())
                .flatMap(node -> collectNodeTree(types, node).stream())
                .collect(Collectors.toSet());

        for (Set<Integer> readerKeyNodes : Sets.powerSet(readerIntermediateKeys.keySet())) {
            Map<Integer, Slice> readerKeys = new HashMap<>();
            readerKeyNodes.forEach(node -> readerKeys.put(node, readerIntermediateKeys.get(node)));

            // nodes that are supposed to be decrypted by the reader
            Set<Integer> decryptedNodes = readerKeys.keySet().stream()
                    .flatMap(node -> collectNodeTree(types, node).stream())
                    .collect(Collectors.toSet());

            // decryptedNodes should be a subset of encrypted nodes
            assertTrue(allEncryptedNodes.containsAll(decryptedNodes));

            // The reader must be built from the current key subset. Handing it the complete key map
            // would make every iteration of this loop verify the very same reader.
            OrcReader readerWithKeys = OrcTester.createCustomOrcReader(tempFile, DWRF, false, readerKeys);
            List<ColumnStatistics> fileStatsWithKey = readerWithKeys.getFooter().getFileStats();
            assertEquals(fileStatsWithKey.size(), types.size());

            for (int node = 0; node < types.size(); node++) {
                ColumnStatistics statsWithKey = fileStatsWithKey.get(node);
                ColumnStatistics statsNoKey = fileStatsNoKey.get(node);
                OrcType type = types.get(node);

                // encrypted nodes should have no type info
                if (allEncryptedNodes.contains(node)) {
                    assertTrue(hasNoTypeStats(statsNoKey));
                }
                else {
                    // unencrypted nodes look the same regardless of the keys a reader holds
                    assertStatsTypeMatch(statsNoKey, type);
                    assertStatsTypeMatch(statsWithKey, type);
                    assertEquals(statsNoKey, statsWithKey);
                }

                if (decryptedNodes.contains(node)) {
                    assertStatsTypeMatch(statsWithKey, type);
                }
            }
        }
    }

    /**
     * Asserts that the statistics carry the type-specific section that corresponds to the kind of
     * the given type. Kinds other than binary, boolean, the integer kinds, the floating point kinds
     * and string are expected to carry no type-specific section at all.
     */
    private static void assertStatsTypeMatch(ColumnStatistics stats, OrcType type)
    {
        OrcType.OrcTypeKind kind = type.getOrcTypeKind();

        if (kind == OrcType.OrcTypeKind.BINARY) {
            assertNotNull(stats.getBinaryStatistics());
            return;
        }
        if (kind == OrcType.OrcTypeKind.BOOLEAN) {
            assertNotNull(stats.getBooleanStatistics());
            return;
        }

        boolean integerKind = kind == OrcType.OrcTypeKind.BYTE
                || kind == OrcType.OrcTypeKind.SHORT
                || kind == OrcType.OrcTypeKind.INT
                || kind == OrcType.OrcTypeKind.LONG;
        if (integerKind) {
            assertNotNull(stats.getIntegerStatistics());
            return;
        }

        boolean floatingPointKind = kind == OrcType.OrcTypeKind.FLOAT || kind == OrcType.OrcTypeKind.DOUBLE;
        if (floatingPointKind) {
            assertNotNull(stats.getDoubleStatistics());
            return;
        }

        if (kind == OrcType.OrcTypeKind.STRING) {
            assertNotNull(stats.getStringStatistics());
            return;
        }

        // every remaining kind must come without type-specific statistics
        assertTrue(hasNoTypeStats(stats));
    }

    /**
     * Returns true when none of the boolean, integer, double, string, date, decimal or binary
     * statistics sections is present on the given column statistics.
     */
    private static boolean hasNoTypeStats(ColumnStatistics columnStatistics)
    {
        boolean anyTypeStatsPresent = columnStatistics.getBooleanStatistics() != null
                || columnStatistics.getIntegerStatistics() != null
                || columnStatistics.getDoubleStatistics() != null
                || columnStatistics.getStringStatistics() != null
                || columnStatistics.getDateStatistics() != null
                || columnStatistics.getDecimalStatistics() != null
                || columnStatistics.getBinaryStatistics() != null;
        return !anyTypeStatsPresent;
    }

    /**
     * Returns a new mutable set holding {@code node} and every node reachable from it through the
     * field type indexes of {@code types}.
     */
    private static Set<Integer> collectNodeTree(List<OrcType> types, int node)
    {
        Set<Integer> subtree = new HashSet<>();
        collectNodeTree(subtree, types, node);
        return subtree;
    }

    /**
     * Adds {@code node} to {@code nodes}, then recurses into each of the field type indexes of the
     * type stored at position {@code node}.
     */
    private static void collectNodeTree(Set<Integer> nodes, List<OrcType> types, int node)
    {
        nodes.add(node);
        types.get(node).getFieldTypeIndexes()
                .forEach(child -> collectNodeTree(nodes, types, child));
    }

    /**
     * Creates a selective record reader over column 0 (read as BIGINT) for the range that starts at
     * {@code offset} and has the size of the data source as its length. No filters or other
     * per-column options are supplied, and the predicate accepts everything.
     */
    private static OrcSelectiveRecordReader getSelectiveRecordReader(OrcDataSource orcDataSource, OrcReader orcReader, int offset)
    {
        Map<Integer, Type> includedColumns = ImmutableMap.of(0, BIGINT);
        List<Integer> outputColumns = ImmutableList.of(0);
        long dataSourceSize = orcDataSource.getSize();

        return orcReader.createSelectiveRecordReader(
                includedColumns,
                outputColumns,
                ImmutableMap.of(),
                ImmutableList.of(),
                ImmutableMap.of(),
                ImmutableMap.of(),
                ImmutableMap.of(),
                ImmutableMap.of(),
                OrcPredicate.TRUE,
                offset,
                dataSourceSize,
                HIVE_STORAGE_TIME_ZONE,
                new TestingHiveOrcAggregatedMemoryContext(),
                Optional.empty(),
                MAX_BATCH_SIZE);
    }
}