TestOrcWriter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.io.DataOutput;
import com.facebook.presto.common.io.DataSink;
import com.facebook.presto.common.io.OutputStreamDataSink;
import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode;
import com.facebook.presto.orc.metadata.CompressionKind;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.StripeFooter;
import com.facebook.presto.orc.writer.StreamLayoutFactory;
import com.facebook.presto.orc.writer.StreamLayoutFactory.ColumnSizeLayoutFactory;
import com.facebook.presto.orc.writer.StreamLayoutFactory.StreamSizeLayoutFactory;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION;
import static com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcEncoding.ORC;
import static com.facebook.presto.orc.OrcTester.HIVE_STORAGE_TIME_ZONE;
import static com.facebook.presto.orc.StripeReader.isIndexStream;
import static com.facebook.presto.orc.TestingOrcPredicate.ORC_ROW_GROUP_SIZE;
import static com.facebook.presto.orc.TestingOrcPredicate.ORC_STRIPE_SIZE;
import static com.facebook.presto.orc.metadata.CompressionKind.NONE;
import static com.facebook.presto.orc.metadata.CompressionKind.ZLIB;
import static com.facebook.presto.orc.metadata.CompressionKind.ZSTD;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
public class TestOrcWriter
{
/**
 * Supplies the parameter rows for the tests bound to the {@code compressionLevels} provider.
 *
 * @return one row per combination, each laid out as {encoding, compression kind, optional compression level}
 */
@DataProvider(name = "compressionLevels")
public static Object[][] zstdCompressionLevels()
{
    return new Object[][] {
            {ORC, NONE, OptionalInt.empty()},
            {DWRF, ZSTD, OptionalInt.of(7)},
            {DWRF, ZSTD, OptionalInt.empty()},
            {DWRF, ZLIB, OptionalInt.of(5)},
            {DWRF, ZLIB, OptionalInt.empty()}
    };
}
/**
 * Checks that, when the writer uses {@link StreamSizeLayoutFactory}, the non-index streams
 * handed to each consumer arrive in non-decreasing order of length.
 */
@Test(dataProvider = "compressionLevels")
public void testWriteOutputStreamsInOrder(OrcEncoding encoding, CompressionKind kind, OptionalInt level)
        throws IOException
{
    Supplier<Consumer<Stream>> ascendingLengthChecks = () -> {
        // one holder per supplied consumer, so the running length starts from zero each time
        int[] previousLength = {0};
        return stream -> {
            if (isIndexStream(stream)) {
                return;
            }
            assertGreaterThanOrEqual(stream.getLength(), previousLength[0], stream.toString());
            previousLength[0] = stream.getLength();
        };
    };
    testStreamOrder(encoding, kind, level, new StreamSizeLayoutFactory(), ascendingLengthChecks);
}
/**
 * Checks that, when the writer uses {@link ColumnSizeLayoutFactory}, columns are encountered
 * in descending order of their total non-index stream length.
 */
@Test(dataProvider = "compressionLevels")
public void testOutputStreamsByColumnSize(OrcEncoding encoding, CompressionKind kind, OptionalInt level)
        throws IOException
{
    // Insertion-ordered, so values() reflects the order in which columns were first seen.
    // The map is created once here, so every consumer handed out below adds to the same totals.
    Map<Integer, Integer> bytesPerColumn = new LinkedHashMap<>();
    Supplier<Consumer<Stream>> columnSizeRecorders = () -> stream -> {
        if (isIndexStream(stream)) {
            return;
        }
        bytesPerColumn.merge(stream.getColumn(), stream.getLength(), Integer::sum);
    };
    testStreamOrder(encoding, kind, level, new ColumnSizeLayoutFactory(), columnSizeRecorders);

    List<Integer> actual = ImmutableList.copyOf(bytesPerColumn.values());
    List<Integer> expected = actual.stream()
            .sorted(Comparator.reverseOrder())
            .collect(Collectors.toList());
    // the observed order must already be the descending order
    assertEquals(actual, expected);
}
/**
 * Writes a five-column VARCHAR page with the given encoding, compression and stream layout,
 * once per {@link OrcWriteValidationMode}, then reads back every stripe footer and checks
 * stream placement.
 *
 * <p>Within each stripe, no index stream may appear after the first non-index stream; every
 * non-index stream is passed, in footer order, to a fresh consumer obtained from
 * {@code streamConsumerFactory}.
 *
 * @param encoding file encoding to write and read back
 * @param kind compression kind passed to the writer
 * @param level compression level forwarded to the writer options
 * @param streamLayoutFactory stream layout under test
 * @param streamConsumerFactory supplies one consumer per stripe
 * @throws IOException if writing, reading back, or closing the temporary file fails
 */
private void testStreamOrder(OrcEncoding encoding, CompressionKind kind, OptionalInt level, StreamLayoutFactory streamLayoutFactory, Supplier<Consumer<Stream>> streamConsumerFactory)
        throws IOException
{
    OrcWriterOptions orcWriterOptions = OrcWriterOptions.builder()
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder()
                    .withStripeMinSize(new DataSize(0, MEGABYTE))
                    .withStripeMaxSize(new DataSize(32, MEGABYTE))
                    .withStripeMaxRowCount(ORC_STRIPE_SIZE)
                    .build())
            .withRowGroupMaxRowCount(ORC_ROW_GROUP_SIZE)
            .withDictionaryMaxMemory(new DataSize(32, MEGABYTE))
            .withCompressionLevel(level)
            .withStreamLayoutFactory(streamLayoutFactory)
            .build();

    for (OrcWriteValidationMode validationMode : OrcWriteValidationMode.values()) {
        // Close the TempFile on every path, including assertion failures
        // (presumably this deletes the backing file - confirm against TempFile.close()).
        try (TempFile tempFile = new TempFile()) {
            // The writer is closed before the stripes are read back, and also when write() fails.
            try (OrcWriter writer = new OrcWriter(
                    new OutputStreamDataSink(new FileOutputStream(tempFile.getFile())),
                    ImmutableList.of("test1", "test2", "test3", "test4", "test5"),
                    ImmutableList.of(VARCHAR, VARCHAR, VARCHAR, VARCHAR, VARCHAR),
                    encoding,
                    kind,
                    Optional.empty(),
                    NO_ENCRYPTION,
                    orcWriterOptions,
                    ImmutableMap.of(),
                    HIVE_STORAGE_TIME_ZONE,
                    true,
                    validationMode,
                    NOOP_WRITER_STATS)) {
                writer.write(createUnsortedVarcharPage());
            }

            for (StripeFooter stripeFooter : OrcTester.getStripes(tempFile.getFile(), encoding)) {
                Consumer<Stream> streamConsumer = streamConsumerFactory.get();
                boolean dataStreamStarted = false;
                for (Stream stream : stripeFooter.getStreams()) {
                    if (isIndexStream(stream)) {
                        // an index stream after a data stream is a layout violation
                        assertFalse(dataStreamStarted);
                        continue;
                    }
                    dataStreamStarted = true;
                    streamConsumer.accept(stream);
                }
            }
        }
    }
}

/**
 * Builds a page of five VARCHAR columns with 65536 entries each. The columns use values of
 * 1, 5, 3, 2 and 4 bytes respectively, so their sizes are not in column order.
 */
private static Page createUnsortedVarcharPage()
{
    String[] seeds = {"a", "bbbbb", "ccc", "dd", "eeee"};
    int entries = 65536;
    Block[] blocks = new Block[seeds.length];
    BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, entries);
    for (int column = 0; column < seeds.length; column++) {
        // explicit charset: the bytes must not depend on the platform default
        byte[] bytes = seeds[column].getBytes(StandardCharsets.UTF_8);
        for (int row = 0; row < entries; row++) {
            // force to write different data: step the first byte through 0..127
            bytes[0] = (byte) ((bytes[0] + 1) % 128);
            blockBuilder.writeBytes(Slices.wrappedBuffer(bytes, 0, bytes.length), 0, bytes.length);
            blockBuilder.closeEntry();
        }
        blocks[column] = blockBuilder.build();
        blockBuilder = blockBuilder.newBlockBuilderLike(null);
    }
    return new Page(blocks);
}
/**
 * Verifies that closing a writer after its sink has failed surfaces the sink's
 * {@link IOException} (and, as the test name suggests, not an {@code IllegalStateException}).
 *
 * <p>{@link MockDataSink#write} always throws, so {@code writer.write} is expected to fail on
 * its first flush. That failure is deliberately swallowed; the exception matched by the
 * {@code @Test} annotation has to come from the subsequent {@code writer.close()}. If
 * {@code writer.write} does not throw at all, no exception escapes and the test fails.
 */
@Test(expectedExceptions = IOException.class, expectedExceptionsMessageRegExp = "Dummy exception from mocked instance")
public void testVerifyNoIllegalStateException()
        throws IOException
{
    OrcWriterOptions writerOptions = OrcWriterOptions.builder()
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder()
                    .withStripeMinSize(new DataSize(0, MEGABYTE))
                    .withStripeMaxSize(new DataSize(32, MEGABYTE))
                    .withStripeMaxRowCount(10)
                    .build())
            .withRowGroupMaxRowCount(ORC_ROW_GROUP_SIZE)
            .withDictionaryMaxMemory(new DataSize(32, MEGABYTE))
            .build();
    OrcWriter writer = new OrcWriter(
            new MockDataSink(),
            ImmutableList.of("test1"),
            ImmutableList.of(VARCHAR),
            ORC,
            NONE,
            Optional.empty(),
            NO_ENCRYPTION,
            writerOptions,
            ImmutableMap.of(),
            HIVE_STORAGE_TIME_ZONE,
            false,
            null,
            NOOP_WRITER_STATS);

    int entries = 65536;
    BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, entries);
    // explicit charset: the bytes must not depend on the platform default
    byte[] bytes = "dummyString".getBytes(StandardCharsets.UTF_8);
    for (int row = 0; row < entries; row++) {
        // force to write different data: step the first byte through 0..127
        bytes[0] = (byte) ((bytes[0] + 1) % 128);
        blockBuilder.writeBytes(Slices.wrappedBuffer(bytes, 0, bytes.length), 0, bytes.length);
        blockBuilder.closeEntry();
    }
    Block[] blocks = new Block[] {blockBuilder.build()};

    try {
        // Throw IOException after first flush
        writer.write(new Page(blocks));
    }
    catch (IOException ignored) {
        // Intentionally swallowed: the behavior under test is what close() throws afterwards.
        writer.close();
    }
}
/**
 * {@link DataSink} stub for failure tests: every {@link #write} call throws an
 * {@link IOException}, both size accessors return {@code -1}, and {@link #close} does nothing.
 */
public static class MockDataSink
        implements DataSink
{
    private static final String FAILURE_MESSAGE = "Dummy exception from mocked instance";

    /**
     * Always fails; {@code outputData} is never inspected.
     *
     * @throws IOException on every call
     */
    @Override
    public void write(List<DataOutput> outputData)
            throws IOException
    {
        throw new IOException(FAILURE_MESSAGE);
    }

    /** No-op. */
    @Override
    public void close()
    {
    }

    /** @return always {@code -1} */
    @Override
    public long size()
    {
        return -1L;
    }

    /** @return always {@code -1} */
    @Override
    public long getRetainedSizeInBytes()
    {
        return -1L;
    }
}
}