// TestOrcWriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc;

import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.io.DataOutput;
import com.facebook.presto.common.io.DataSink;
import com.facebook.presto.common.io.OutputStreamDataSink;
import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode;
import com.facebook.presto.orc.metadata.CompressionKind;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.StripeFooter;
import com.facebook.presto.orc.writer.StreamLayoutFactory;
import com.facebook.presto.orc.writer.StreamLayoutFactory.ColumnSizeLayoutFactory;
import com.facebook.presto.orc.writer.StreamLayoutFactory.StreamSizeLayoutFactory;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION;
import static com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcEncoding.ORC;
import static com.facebook.presto.orc.OrcTester.HIVE_STORAGE_TIME_ZONE;
import static com.facebook.presto.orc.StripeReader.isIndexStream;
import static com.facebook.presto.orc.TestingOrcPredicate.ORC_ROW_GROUP_SIZE;
import static com.facebook.presto.orc.TestingOrcPredicate.ORC_STRIPE_SIZE;
import static com.facebook.presto.orc.metadata.CompressionKind.NONE;
import static com.facebook.presto.orc.metadata.CompressionKind.ZLIB;
import static com.facebook.presto.orc.metadata.CompressionKind.ZSTD;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;

/**
 * Tests for {@link OrcWriter}: the order in which data streams are laid out inside
 * the written stripes, and the exception raised when the underlying {@link DataSink} fails.
 */
public class TestOrcWriter
{
    // number of rows written into every test column
    private static final int ENTRY_COUNT = 65536;

    /**
     * Supplies the (encoding, compression kind, compression level) combinations used by the
     * stream order tests. Despite the method name, the combinations also cover
     * {@code NONE} and {@code ZLIB}.
     *
     * @return one row per combination, in the parameter order of the test methods
     */
    @DataProvider(name = "compressionLevels")
    public static Object[][] zstdCompressionLevels()
    {
        return new Object[][] {
                {ORC, NONE, OptionalInt.empty()},
                {DWRF, ZSTD, OptionalInt.of(7)},
                {DWRF, ZSTD, OptionalInt.empty()},
                {DWRF, ZLIB, OptionalInt.of(5)},
                {DWRF, ZLIB, OptionalInt.empty()},
        };
    }

    /**
     * Writes a file using {@link StreamSizeLayoutFactory} and asserts that, within every stripe,
     * each data stream is at least as long as the data stream before it.
     */
    @Test(dataProvider = "compressionLevels")
    public void testWriteOutputStreamsInOrder(OrcEncoding encoding, CompressionKind kind, OptionalInt level)
            throws IOException
    {
        // a fresh consumer is created per stripe, so the running length restarts at zero for each stripe
        testStreamOrder(encoding, kind, level, new StreamSizeLayoutFactory(), () -> new Consumer<Stream>()
        {
            private int previousLength;

            @Override
            public void accept(Stream stream)
            {
                // index streams never reach this consumer; testStreamOrder filters them out
                assertGreaterThanOrEqual(stream.getLength(), previousLength, stream.toString());
                previousLength = stream.getLength();
            }
        });
    }

    /**
     * Writes a file using {@link ColumnSizeLayoutFactory} and asserts that the total data stream
     * length per column, taken in the order the columns first appear, is non-increasing.
     */
    @Test(dataProvider = "compressionLevels")
    public void testOutputStreamsByColumnSize(OrcEncoding encoding, CompressionKind kind, OptionalInt level)
            throws IOException
    {
        // insertion-ordered so that values() reflects the order in which columns were first seen;
        // the map is shared by every stripe and validation mode, so the lengths accumulate across all of them
        Map<Integer, Integer> nodeSizes = new LinkedHashMap<>();
        testStreamOrder(
                encoding,
                kind,
                level,
                new ColumnSizeLayoutFactory(),
                () -> stream -> nodeSizes.merge(stream.getColumn(), stream.getLength(), Integer::sum));

        List<Integer> actual = ImmutableList.copyOf(nodeSizes.values());
        List<Integer> expected = actual.stream().sorted(Comparator.reverseOrder()).collect(Collectors.toList());
        assertEquals(actual, expected);
    }

    /**
     * Writes five VARCHAR columns once per {@link OrcWriteValidationMode}, reads the stripe footers back
     * and feeds every non-index stream of each stripe, in file order, to a consumer obtained from
     * {@code streamConsumerFactory}. Also asserts that no index stream follows a data stream.
     *
     * @param encoding file encoding passed to the writer
     * @param kind compression kind passed to the writer
     * @param level compression level passed to the writer options
     * @param streamLayoutFactory layout factory under test
     * @param streamConsumerFactory invoked once per stripe to obtain the consumer for that stripe's data streams
     * @throws IOException if writing or reading the file fails
     */
    private void testStreamOrder(OrcEncoding encoding, CompressionKind kind, OptionalInt level, StreamLayoutFactory streamLayoutFactory, Supplier<Consumer<Stream>> streamConsumerFactory)
            throws IOException
    {
        OrcWriterOptions orcWriterOptions = OrcWriterOptions.builder()
                .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder()
                        .withStripeMinSize(new DataSize(0, MEGABYTE))
                        .withStripeMaxSize(new DataSize(32, MEGABYTE))
                        .withStripeMaxRowCount(ORC_STRIPE_SIZE)
                        .build())
                .withRowGroupMaxRowCount(ORC_ROW_GROUP_SIZE)
                .withDictionaryMaxMemory(new DataSize(32, MEGABYTE))
                .withCompressionLevel(level)
                .withStreamLayoutFactory(streamLayoutFactory)
                .build();

        // columns of different value lengths, deliberately not sorted by size;
        // the input is the same for every validation mode, so it is built only once
        String[] data = new String[] {"a", "bbbbb", "ccc", "dd", "eeee"};
        Block[] blocks = new Block[data.length];
        for (int i = 0; i < data.length; i++) {
            blocks[i] = createVarcharBlock(data[i]);
        }
        Page page = new Page(blocks);

        for (OrcWriteValidationMode validationMode : OrcWriteValidationMode.values()) {
            // close the temp file even when writing or an assertion fails
            try (TempFile tempFile = new TempFile()) {
                // the writer must be closed before the stripes are read back
                try (OrcWriter writer = new OrcWriter(
                        new OutputStreamDataSink(new FileOutputStream(tempFile.getFile())),
                        ImmutableList.of("test1", "test2", "test3", "test4", "test5"),
                        ImmutableList.of(VARCHAR, VARCHAR, VARCHAR, VARCHAR, VARCHAR),
                        encoding,
                        kind,
                        Optional.empty(),
                        NO_ENCRYPTION,
                        orcWriterOptions,
                        ImmutableMap.of(),
                        HIVE_STORAGE_TIME_ZONE,
                        true,
                        validationMode,
                        NOOP_WRITER_STATS)) {
                    writer.write(page);
                }

                for (StripeFooter stripeFooter : OrcTester.getStripes(tempFile.getFile(), encoding)) {
                    Consumer<Stream> streamConsumer = streamConsumerFactory.get();
                    boolean dataStreamStarted = false;
                    for (Stream stream : stripeFooter.getStreams()) {
                        if (isIndexStream(stream)) {
                            // index streams are only allowed ahead of the data streams
                            assertFalse(dataStreamStarted);
                            continue;
                        }
                        dataStreamStarted = true;
                        streamConsumer.accept(stream);
                    }
                }
            }
        }
    }

    /**
     * Writes a page into an {@link OrcWriter} backed by {@link MockDataSink}, whose {@code write} always
     * throws. An {@link IOException} thrown by {@code write} is discarded and the writer is then closed;
     * the test passes only if an {@link IOException} carrying the sink's message propagates, which is
     * expected to come from {@code close()} rather than some other exception type.
     */
    @Test(expectedExceptions = IOException.class, expectedExceptionsMessageRegExp = "Dummy exception from mocked instance")
    public void testVerifyNoIllegalStateException()
            throws IOException
    {
        OrcWriter writer = new OrcWriter(
                new MockDataSink(),
                ImmutableList.of("test1"),
                ImmutableList.of(VARCHAR),
                ORC,
                NONE,
                Optional.empty(),
                NO_ENCRYPTION,
                OrcWriterOptions.builder()
                        .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder()
                                .withStripeMinSize(new DataSize(0, MEGABYTE))
                                .withStripeMaxSize(new DataSize(32, MEGABYTE))
                                .withStripeMaxRowCount(10)
                                .build())
                        .withRowGroupMaxRowCount(ORC_ROW_GROUP_SIZE)
                        .withDictionaryMaxMemory(new DataSize(32, MEGABYTE))
                        .build(),
                ImmutableMap.of(),
                HIVE_STORAGE_TIME_ZONE,
                false,
                null,
                NOOP_WRITER_STATS);

        Block[] blocks = new Block[] {createVarcharBlock("dummyString")};

        try {
            // Throw IOException after first flush
            writer.write(new Page(blocks));
        }
        catch (IOException ignored) {
            // the failure of write is intentionally dropped: the exception under test is the one raised by close()
            writer.close();
        }
    }

    /**
     * Builds a VARCHAR block of {@link #ENTRY_COUNT} rows derived from {@code seed}. Before each row is
     * written the first byte is advanced by one, wrapping at 128, so consecutive rows differ.
     *
     * @param seed non-empty text whose UTF-8 bytes form the template of every row
     * @return the populated block
     */
    private static Block createVarcharBlock(String seed)
    {
        // explicit charset: the bare getBytes() would depend on the platform default
        byte[] bytes = seed.getBytes(UTF_8);
        BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, ENTRY_COUNT);
        for (int row = 0; row < ENTRY_COUNT; row++) {
            // force to write different data
            bytes[0] = (byte) ((bytes[0] + 1) % 128);
            blockBuilder.writeBytes(Slices.wrappedBuffer(bytes, 0, bytes.length), 0, bytes.length);
            blockBuilder.closeEntry();
        }
        return blockBuilder.build();
    }

    /**
     * {@link DataSink} stub whose {@link #write(List)} always fails with an {@link IOException};
     * both size accessors return {@code -1} and {@link #close()} does nothing.
     */
    public static class MockDataSink
            implements DataSink
    {
        public MockDataSink()
        {
        }

        @Override
        public long size()
        {
            return -1L;
        }

        @Override
        public long getRetainedSizeInBytes()
        {
            return -1L;
        }

        @Override
        public void write(List<DataOutput> outputData)
                throws IOException
        {
            throw new IOException("Dummy exception from mocked instance");
        }

        @Override
        public void close()
        {
        }
    }
}