TestDictionaryColumnWriter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.orc.DictionaryCompressionOptimizer.DictionaryColumnManager;
import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.StripeFooter;
import com.facebook.presto.orc.writer.DictionaryColumnWriter;
import com.facebook.presto.orc.writer.SliceDictionaryColumnWriter;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.StandardTypes.ARRAY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcEncoding.ORC;
import static com.facebook.presto.orc.OrcReader.INITIAL_BATCH_SIZE;
import static com.facebook.presto.orc.OrcTester.createCustomOrcSelectiveRecordReader;
import static com.facebook.presto.orc.OrcTester.createOrcWriter;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY_V2;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DIRECT;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DIRECT_V2;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DWRF_DIRECT;
import static com.facebook.presto.orc.metadata.ColumnEncoding.DEFAULT_SEQUENCE_ID;
import static com.facebook.presto.orc.metadata.CompressionKind.SNAPPY;
import static com.facebook.presto.orc.metadata.CompressionKind.ZSTD;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.cycle;
import static com.google.common.collect.Iterables.limit;
import static com.google.common.collect.Lists.newArrayList;
import static io.airlift.slice.SizeOf.SIZE_OF_BYTE;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.Math.toIntExact;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
public class TestDictionaryColumnWriter
{
// Index used to look up the tested column's encoding in a stripe footer and passed
// as the column id when a SliceDictionaryColumnWriter is constructed directly.
private static final int COLUMN_ID = 1;
// Number of rows written to the ORC writer per page (one "batch") in testDictionary.
private static final int BATCH_ROWS = 1_000;
// Stripe row limit handed to the flush policy; also used to compute expected stripe counts.
private static final int STRIPE_MAX_ROWS = 15_000;
// Base added to small repeating offsets to build the repeating test values.
private static final int INTEGER_VALUES_DICTIONARY_BASE = 1_000_000_000;
// Unseeded generator: randomized inputs differ from run to run.
private static final Random RANDOM = new Random();
/**
 * Converts a size expressed in megabytes into a byte count.
 *
 * @param size number of megabytes
 * @return the byte count as an {@code int}; {@code toIntExact} throws
 *         {@link ArithmeticException} if the value does not fit
 */
private static int megabytes(int size)
{
    DataSize dataSize = new DataSize(size, MEGABYTE);
    long byteCount = dataSize.toBytes();
    return toIntExact(byteCount);
}
/**
 * Computes how many stripes are expected for the given number of rows when
 * each stripe holds at most {@code STRIPE_MAX_ROWS} rows.
 *
 * @param size total number of rows written
 * @return 0 for no rows, otherwise the row count divided by the stripe limit, rounded up
 */
private static int getStripeSize(int size)
{
    // Ceiling division; the zero case is special-cased so that no rows means no stripes.
    return size == 0 ? 0 : (size - 1) / STRIPE_MAX_ROWS + 1;
}
/**
 * Writes an empty string column for every string dictionary input and checks
 * that the number of stripes matches the expected count for zero rows.
 */
@Test
public void testStringNoRows()
        throws Exception
{
    List<String> noRows = ImmutableList.of();
    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        List<StripeFooter> footers = testStringDictionary(new DirectConversionTester(), dictionaryInput, noRows);
        assertEquals(footers.size(), getStripeSize(noRows.size()));
    }
}
/**
 * Writes 90,000 null strings, requests a direct conversion (expected to succeed)
 * at batches 7, 14 and 32 with a one megabyte limit, and checks that every
 * stripe ends up direct encoded.
 */
@Test
public void testStringAllNullsWithDirectConversion()
        throws Exception
{
    List<String> allNulls = new ArrayList<>(Collections.<String>nCopies(90_000, null));
    int[] conversionBatchIds = {7, 14, 32};
    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        DirectConversionTester conversionTester = new DirectConversionTester();
        for (int batchId : conversionBatchIds) {
            conversionTester.add(batchId, megabytes(1), true);
        }
        List<StripeFooter> footers = testStringDictionary(conversionTester, dictionaryInput, allNulls);
        verifyDirectEncoding(getStripeSize(allNulls.size()), dictionaryInput.getEncoding(), footers);
    }
}
/**
 * Writes 60,000 rows where each row is randomly either null or a fresh UUID
 * string, and checks that every stripe is direct encoded.
 */
@Test
public void testStringRandomValuesWithNull()
        throws Exception
{
    int rowCount = 60_000;
    List<String> values = new ArrayList<>();
    for (int row = 0; row < rowCount; row++) {
        if (RANDOM.nextBoolean()) {
            values.add(null);
        }
        else {
            values.add(UUID.randomUUID().toString());
        }
    }

    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        List<StripeFooter> footers = testStringDictionary(new DirectConversionTester(), dictionaryInput, values);
        verifyDirectEncoding(getStripeSize(values.size()), dictionaryInput.getEncoding(), footers);
    }
}
/**
 * Builds one stripe worth of rows drawn from 1,000 distinct UUID strings, writes
 * three differently shuffled copies of it, and checks that each of the three
 * stripes is dictionary encoded with a dictionary of 1,000 entries.
 */
@Test
public void testStringRandomRepeatingValues()
        throws Exception
{
    int dictionarySize = 1_000;
    int stripeCount = 3;

    // The first dictionarySize rows are the distinct values; the rest of the stripe repeats them.
    List<String> stripeValues = new ArrayList<>();
    while (stripeValues.size() < dictionarySize) {
        stripeValues.add(UUID.randomUUID().toString());
    }
    while (stripeValues.size() < STRIPE_MAX_ROWS) {
        stripeValues.add(stripeValues.get(RANDOM.nextInt(dictionarySize)));
    }

    // Each stripe receives the same rows in a freshly shuffled order.
    List<String> values = new ArrayList<>();
    for (int stripe = 0; stripe < stripeCount; stripe++) {
        Collections.shuffle(stripeValues);
        values.addAll(stripeValues);
    }

    List<Integer> expectedDictionarySizes = ImmutableList.of(dictionarySize, dictionarySize, dictionarySize);
    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        List<StripeFooter> footers = testStringDictionary(new DirectConversionTester(), dictionaryInput, values);
        verifyDictionaryEncoding(getStripeSize(values.size()), dictionaryInput.getEncoding(), footers, expectedDictionarySizes);
    }
}
/**
 * Writes the decimal strings "0" through "59999" (no repeats) and checks that
 * every stripe is direct encoded.
 */
@Test
public void testStringNonRepeatingValues()
        throws Exception
{
    ImmutableList.Builder<String> distinctValues = ImmutableList.builder();
    for (int number = 0; number < 60_000; number++) {
        distinctValues.add(String.valueOf(number));
    }
    List<String> values = distinctValues.build();
    int expectedStripes = getStripeSize(values.size());

    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        List<StripeFooter> footers = testStringDictionary(new DirectConversionTester(), dictionaryInput, values);
        verifyDirectEncoding(expectedStripes, dictionaryInput.getEncoding(), footers);
    }
}
/**
 * Writes 60,000 distinct strings with a row group size of 14,876 rows, which does
 * not divide the stripe row limit evenly, and relies on the round-trip checks
 * inside {@code testDictionary}.
 */
@Test
public void testStringIncreasedStrideSize()
        throws Exception
{
    ImmutableList.Builder<String> builder = ImmutableList.builder();
    for (int i = 0; i < 60_000; i++) {
        builder.add(Integer.toString(i));
    }
    List<String> values = builder.build();
    for (StringDictionaryInput input : StringDictionaryInput.values()) {
        DirectConversionTester directConversionTester = new DirectConversionTester();
        OrcWriterOptions writerOptions = OrcWriterOptions.builder()
                .withRowGroupMaxRowCount(14_876)
                .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
                // Honor the sorting flag of the input; without it the unsorted
                // input would silently run with the writer's default setting.
                .withStringDictionarySortingEnabled(input.isSortStringDictionaryKeys())
                .build();
        testDictionary(VARCHAR, input.getEncoding(), writerOptions, directConversionTester, values);
    }
}
/**
 * Writes 60,000 strings that cycle through 1,000 distinct values and checks that
 * each of the four stripes is dictionary encoded with 1,000 dictionary entries.
 */
@Test
public void testStringRepeatingValues()
        throws Exception
{
    ImmutableList.Builder<String> builder = ImmutableList.builder();
    for (int i = 0; i < 60_000; i++) {
        // Adding INTEGER_VALUES_DICTIONARY_BASE (one billion) makes every value a
        // 10 character string; the longer strings are meant to force dictionary encoding.
        builder.add(Integer.toString((i % 1000) + INTEGER_VALUES_DICTIONARY_BASE));
    }
    // Build the list once; it is the same for every input.
    List<String> values = builder.build();
    for (StringDictionaryInput input : StringDictionaryInput.values()) {
        DirectConversionTester directConversionTester = new DirectConversionTester();
        List<StripeFooter> stripeFooters = testStringDictionary(directConversionTester, input, values);
        verifyDictionaryEncoding(getStripeSize(values.size()), input.getEncoding(), stripeFooters, ImmutableList.of(1000, 1000, 1000, 1000));
    }
}
/**
 * Writes 60,000 rows cycling through 2,000 distinct strings plus a null, forces
 * direct conversions at batches 0 and 16, and checks that stripes 0 and 1 are
 * direct encoded while stripes 2 and 3 are dictionary encoded with 2,000 entries.
 */
@Test
public void testStringRepeatingValuesWithDirectConversion()
        throws Exception
{
    List<String> values = new ArrayList<>(60_000);
    for (int i = 0; i < 60_000; i++) {
        int offset = i % 2001;
        if (offset > 0) {
            values.add(Integer.toString(offset + INTEGER_VALUES_DICTIONARY_BASE));
        }
        else {
            // Offset 0 contributes the null entry of each cycle.
            values.add(null);
        }
    }
    for (StringDictionaryInput input : StringDictionaryInput.values()) {
        DirectConversionTester directConversionTester = new DirectConversionTester();
        directConversionTester.add(0, megabytes(1), true);
        // At batch 16 a 5,000 byte limit is expected to fail, then 1 MB is expected to succeed.
        directConversionTester.add(16, 5_000, false);
        directConversionTester.add(16, megabytes(1), true);
        List<StripeFooter> stripeFooters = testStringDictionary(directConversionTester, input, values);
        // TestNG argument order is (actual, expected).
        assertEquals(stripeFooters.size(), getStripeSize(values.size()));
        verifyDirectEncoding(stripeFooters, input.getEncoding(), 0);
        verifyDirectEncoding(stripeFooters, input.getEncoding(), 1);
        verifyDictionaryEncoding(stripeFooters, input.getEncoding(), 2, 2000);
        verifyDictionaryEncoding(stripeFooters, input.getEncoding(), 3, 2000);
    }
}
/**
 * Writes one stripe of distinct strings followed by 100,000 rows cycling through
 * 1,500 distinct strings, with {@code preserveDirectEncodingStripeCount} set to 2.
 * Stripes 0 through 2 are expected to be direct encoded and all later stripes
 * dictionary encoded with 1,500 entries.
 */
@Test
public void testStringPreserveDirectEncoding()
        throws IOException
{
    ImmutableList.Builder<String> builder = ImmutableList.builder();
    // First stripe: no repeated values.
    for (long i = 0; i < STRIPE_MAX_ROWS; i++) {
        builder.add(Long.toString(Integer.MAX_VALUE + i));
    }
    // Remaining rows: values repeat every repeatInterval rows.
    int repeatInterval = 1500;
    for (long i = 0; i < 100_000; i++) {
        builder.add(Long.toString(Integer.MAX_VALUE + i % repeatInterval));
    }
    int preserveDirectEncodingStripeCount = 2;
    OrcWriterOptions.Builder orcWriterOptionsBuilder = OrcWriterOptions.builder()
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
            .withIntegerDictionaryEncodingEnabled(true)
            .withPreserveDirectEncodingStripeCount(preserveDirectEncodingStripeCount);
    List<String> values = builder.build();
    for (StringDictionaryInput input : StringDictionaryInput.values()) {
        DirectConversionTester tester = new DirectConversionTester();
        OrcWriterOptions orcWriterOptions = orcWriterOptionsBuilder
                .withStringDictionarySortingEnabled(input.isSortStringDictionaryKeys())
                .build();
        List<StripeFooter> stripeFooters = testDictionary(VARCHAR, input.getEncoding(), orcWriterOptions, tester, values);
        // TestNG argument order is (actual, expected).
        assertEquals(stripeFooters.size(), getStripeSize(values.size()));
        for (int i = 0; i <= preserveDirectEncodingStripeCount; i++) {
            verifyDirectEncoding(stripeFooters, input.getEncoding(), i);
        }
        for (int i = preserveDirectEncodingStripeCount + 1; i < stripeFooters.size(); i++) {
            verifyDictionaryEncoding(stripeFooters, input.getEncoding(), i, repeatInterval);
        }
    }
}
/**
 * Writes three stripes of the constant string "0" with string dictionary
 * encoding disabled and expects direct encoding throughout.
 */
@Test
public void testDisableStringDictionaryEncoding()
        throws IOException
{
    int rowCount = STRIPE_MAX_ROWS * 3;
    List<String> values = ImmutableList.copyOf(Collections.nCopies(rowCount, Long.toString(0)));
    testStringDirectColumn(values);
}
/**
 * Writes three stripes consisting only of nulls with string dictionary encoding
 * disabled and expects direct encoding throughout.
 */
@Test
public void testDisableStringOnlyNulls()
        throws IOException
{
    int rowCount = 3 * STRIPE_MAX_ROWS;
    List<String> onlyNulls = new ArrayList<>(Collections.<String>nCopies(rowCount, null));
    testStringDirectColumn(onlyNulls);
}
/**
 * Writes the decimal strings "0" through "49999", inserting an extra null before
 * every value whose index is 4 modulo 5, with string dictionary encoding disabled.
 */
@Test
public void testDisableStringMixedNulls()
        throws IOException
{
    List<String> values = new ArrayList<>();
    for (int number = 0; number < 50_000; number++) {
        boolean insertNull = number % 5 == 4;
        if (insertNull) {
            values.add(null);
        }
        // The value itself is always added, including after an inserted null.
        values.add(Integer.toString(number));
    }
    testStringDirectColumn(values);
}
/**
 * Writes {@code values} as a VARCHAR column with string dictionary encoding
 * disabled, once per string dictionary input, and checks that every stripe is
 * direct encoded and that the stripe count matches the row count.
 *
 * @param values rows to write; may contain nulls
 */
private void testStringDirectColumn(List<String> values)
        throws IOException
{
    long totalRows = values.size();
    for (StringDictionaryInput dictionaryInput : StringDictionaryInput.values()) {
        OrcWriterOptions writerOptions = OrcWriterOptions.builder()
                .withStringDictionaryEncodingEnabled(false)
                .withStringDictionarySortingEnabled(dictionaryInput.isSortStringDictionaryKeys())
                .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
                .build();
        List<StripeFooter> footers = testDictionary(VARCHAR, dictionaryInput.getEncoding(), writerOptions, new DirectConversionTester(), values);

        // Walk the expected stripes one STRIPE_MAX_ROWS chunk at a time.
        int stripeId = 0;
        for (long coveredRows = 0; coveredRows < totalRows; coveredRows += STRIPE_MAX_ROWS) {
            verifyDirectEncoding(footers, dictionaryInput.getEncoding(), stripeId);
            stripeId++;
        }
        assertEquals(footers.size(), stripeId);
    }
}
/**
 * Writes an empty integer column and checks that the stripe count matches the
 * expected count for zero rows.
 */
@Test
public void testIntegerNoRows()
        throws IOException
{
    List<Integer> noRows = ImmutableList.of();
    List<StripeFooter> footers = testIntegerDictionary(new DirectConversionTester(), noRows);
    assertEquals(footers.size(), getStripeSize(noRows.size()));
}
/**
 * Writes 60,000 null integers, requests a direct conversion (expected to
 * succeed) at batches 7, 14 and 32, and checks that every stripe uses DWRF
 * direct encoding.
 */
@Test
public void testIntegerDictionaryAllNulls()
        throws IOException
{
    DirectConversionTester conversionTester = new DirectConversionTester();
    for (int batchId : new int[] {7, 14, 32}) {
        conversionTester.add(batchId, megabytes(1), true);
    }
    List<Integer> allNulls = new ArrayList<>(Collections.<Integer>nCopies(60_000, null));
    List<StripeFooter> footers = testIntegerDictionary(conversionTester, allNulls);
    verifyDwrfDirectEncoding(getStripeSize(allNulls.size()), footers);
}
/**
 * Writes 60,000 integers cycling through MAX_VALUE, null and MIN_VALUE and
 * checks that each of the four stripes is dictionary encoded with two entries.
 */
@Test
public void testIntegerDictionaryAlternatingNulls()
        throws IOException
{
    Integer[] pattern = {Integer.MAX_VALUE, null, Integer.MIN_VALUE};
    int rowCount = 60_000;
    List<Integer> values = new ArrayList<>(rowCount);
    for (int row = 0; row < rowCount; row++) {
        values.add(pattern[row % pattern.length]);
    }

    List<StripeFooter> footers = testIntegerDictionary(new DirectConversionTester(), values);
    verifyDictionaryEncoding(getStripeSize(values.size()), DWRF, footers, ImmutableList.of(2, 2, 2, 2));
}
/**
 * Round-trips 70,000 random nullable integers; the checks are the ones
 * performed inside the shared dictionary test helper.
 */
@Test
public void testIntegerRandomValues()
        throws IOException
{
    List<Integer> randomValues = generateRandomIntegers(70_000);
    testIntegerDictionary(new DirectConversionTester(), randomValues);
}
/**
 * Round-trips 90,000 random nullable integers with integer dictionary encoding
 * enabled and a row group size of 14,998 rows.
 */
@Test
public void testIntegerIncreasedStrideSize()
        throws IOException
{
    List<Integer> randomValues = generateRandomIntegers(90_000);
    OrcWriterOptions options = OrcWriterOptions.builder()
            .withRowGroupMaxRowCount(14_998)
            .withIntegerDictionaryEncodingEnabled(true)
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
            .build();
    testDictionary(INTEGER, DWRF, options, new DirectConversionTester(), randomValues);
}
/**
 * Produces a list of random integers in which each entry is independently
 * replaced by null depending on a random boolean.
 *
 * @param maxSize number of entries to generate
 * @return a mutable list with exactly {@code maxSize} entries
 */
private List<Integer> generateRandomIntegers(int maxSize)
{
    List<Integer> result = new ArrayList<>();
    while (result.size() < maxSize) {
        // The boolean is drawn first; an int is only drawn for non-null entries.
        if (RANDOM.nextBoolean()) {
            result.add(null);
        }
        else {
            result.add(RANDOM.nextInt());
        }
    }
    return result;
}
/**
 * Writes the same ten blocks of 10,000 identical strings to two writers, one
 * created with {@code ignoreDictionaryRowGroupSizes} set and one without, and
 * checks that the writer without it reports at least one additional byte per
 * written row of row group retained size.
 */
@Test
public void testDictionaryRetainedSizeWithDifferentSettings()
{
    int numEntries = 10_000;
    int numBlocks = 10;

    DictionaryColumnWriter ignoringWriter = getStringDictionaryColumnWriter(true);
    DictionaryColumnWriter trackingWriter = getStringDictionaryColumnWriter(false);

    // One block holding numEntries copies of the same string.
    Slice repeatedValue = utf8Slice("SomeString");
    BlockBuilder builder = VARCHAR.createBlockBuilder(null, numEntries);
    for (int entry = 0; entry < numEntries; entry++) {
        VARCHAR.writeSlice(builder, repeatedValue);
    }
    Block block = builder.build();

    for (int written = 0; written < numBlocks; written++) {
        writeBlock(ignoringWriter, block);
        writeBlock(trackingWriter, block);
    }

    long ignoredRowGroupBytes = ignoringWriter.getRowGroupRetainedSizeInBytes();
    long withRowGroupBytes = trackingWriter.getRowGroupRetainedSizeInBytes();
    long expectedDictionaryIndexSize = numBlocks * numEntries * SIZE_OF_BYTE;
    assertTrue(
            ignoredRowGroupBytes + expectedDictionaryIndexSize <= withRowGroupBytes,
            String.format("Ignored bytes %s With bytes %s", ignoredRowGroupBytes, withRowGroupBytes));
}
/**
 * Writes {@code block} to {@code writer} as a row group of its own.
 *
 * @param writer the column writer receiving the block
 * @param block the block to write
 */
private void writeBlock(DictionaryColumnWriter writer, Block block)
{
    // Open a row group, write the block, then close the row group again.
    writer.beginRowGroup();
    writer.writeBlock(block);
    writer.finishRowGroup();
}
/**
 * Round-trips 70,000 random nullable longs; the checks are the ones performed
 * inside the shared dictionary test helper.
 */
@Test
public void testLongRandomValues()
        throws IOException
{
    List<Long> randomValues = generateRandomLongs(70_000);
    testLongDictionary(new DirectConversionTester(), randomValues);
}
/**
 * Round-trips 80,000 random nullable longs with integer dictionary encoding
 * enabled and a row group size of 14,998 rows.
 */
@Test
public void testLongIncreasedStrideSize()
        throws IOException
{
    List<Long> randomValues = generateRandomLongs(80_000);
    OrcWriterOptions options = OrcWriterOptions.builder()
            .withRowGroupMaxRowCount(14_998)
            .withIntegerDictionaryEncodingEnabled(true)
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
            .build();
    testDictionary(BIGINT, DWRF, options, new DirectConversionTester(), randomValues);
}
/**
 * Creates a DWRF VARCHAR dictionary column writer with SNAPPY compression.
 *
 * @param ignoreRowGroupSizes value passed to {@code setIgnoreDictionaryRowGroupSizes}
 * @return a new {@link SliceDictionaryColumnWriter} for column {@code COLUMN_ID}
 */
private static DictionaryColumnWriter getStringDictionaryColumnWriter(boolean ignoreRowGroupSizes)
{
    ColumnWriterOptions options = ColumnWriterOptions.builder()
            .setIgnoreDictionaryRowGroupSizes(ignoreRowGroupSizes)
            .setCompressionKind(SNAPPY)
            .build();
    return new SliceDictionaryColumnWriter(
            COLUMN_ID,
            DEFAULT_SEQUENCE_ID,
            VARCHAR,
            options,
            Optional.empty(),
            DWRF,
            DWRF.createMetadataWriter());
}
/**
 * Produces a list of random longs in which each entry is independently
 * replaced by null depending on a random boolean.
 *
 * @param maxSize number of entries to generate
 * @return a mutable list with exactly {@code maxSize} entries
 */
private List<Long> generateRandomLongs(int maxSize)
{
    List<Long> result = new ArrayList<>();
    while (result.size() < maxSize) {
        // The boolean is drawn first; a long is only drawn for non-null entries.
        if (RANDOM.nextBoolean()) {
            result.add(null);
        }
        else {
            result.add(RANDOM.nextLong());
        }
    }
    return result;
}
/**
 * Writes MAX_VALUE and MIN_VALUE followed by integers repeating every 1,500
 * rows (45,000 rows in total). With integer dictionary encoding enabled every
 * stripe must be dictionary encoded; with it disabled every stripe must use
 * DWRF direct encoding.
 */
@Test
public void testIntegerDictionaryRepeatingValues()
        throws IOException
{
    int repeatInterval = 1500;
    ImmutableList<Integer> extremes = ImmutableList.of(Integer.MAX_VALUE, Integer.MIN_VALUE);
    ImmutableList.Builder<Integer> valuesBuilder = ImmutableList.<Integer>builder().addAll(extremes);
    for (int row = extremes.size(); row < 45_000; row++) {
        valuesBuilder.add(INTEGER_VALUES_DICTIONARY_BASE + row % repeatInterval);
    }
    List<Integer> values = valuesBuilder.build();
    int expectedStripes = getStripeSize(values.size());

    // The first stripe additionally contains the two extreme values.
    List<Integer> expectedDictionarySizes = ImmutableList.of(repeatInterval + extremes.size(), repeatInterval, repeatInterval);
    List<StripeFooter> dictionaryFooters = testIntegerDictionary(new DirectConversionTester(), values);
    verifyDictionaryEncoding(expectedStripes, DWRF, dictionaryFooters, expectedDictionarySizes);

    // Same data with integer dictionary encoding switched off.
    List<StripeFooter> directFooters = testDictionary(INTEGER, DWRF, false, true, new DirectConversionTester(), values);
    verifyDwrfDirectEncoding(expectedStripes, directFooters);
}
/**
 * Writes MAX_VALUE and MIN_VALUE followed by the distinct integers 2 through
 * 59,999 and checks that every stripe uses DWRF direct encoding.
 */
@Test
public void testIntegerDictionaryNonRepeating()
        throws IOException
{
    ImmutableList<Integer> extremes = ImmutableList.of(Integer.MAX_VALUE, Integer.MIN_VALUE);
    ImmutableList.Builder<Integer> valuesBuilder = ImmutableList.<Integer>builder().addAll(extremes);
    for (int number = extremes.size(); number < 60_000; number++) {
        valuesBuilder.add(number);
    }
    List<Integer> values = valuesBuilder.build();

    List<StripeFooter> footers = testIntegerDictionary(new DirectConversionTester(), values);
    verifyDwrfDirectEncoding(getStripeSize(values.size()), footers);
}
/**
 * Writes 60,000 integers that repeat every 1,500 rows (after MAX_VALUE,
 * MIN_VALUE and a null), forces direct conversions at batches 0 and 16, and
 * checks that stripes 0 and 1 use DWRF direct encoding while stripes 2 and 3
 * are dictionary encoded with 1,500 entries.
 */
@Test
public void testIntegerDictionaryRepeatingValuesDirect()
        throws IOException
{
    DirectConversionTester directConversionTester = new DirectConversionTester();
    // At batch 0 a 1,000 byte limit is expected to fail, then 4,000 bytes is expected to succeed.
    directConversionTester.add(0, 1000, false);
    directConversionTester.add(0, 4000, true);
    directConversionTester.add(16, 10_000, true);
    List<Integer> values = new ArrayList<>();
    values.addAll(Arrays.asList(Integer.MAX_VALUE, Integer.MIN_VALUE));
    values.add(null);
    int repeatInterval = 1500;
    for (int i = values.size(); i < 60_000; i++) {
        values.add(INTEGER_VALUES_DICTIONARY_BASE + i % repeatInterval);
    }
    List<StripeFooter> stripeFooters = testIntegerDictionary(directConversionTester, values);
    // TestNG argument order is (actual, expected).
    assertEquals(stripeFooters.size(), getStripeSize(values.size()));
    verifyDwrfDirectEncoding(stripeFooters, 0);
    verifyDwrfDirectEncoding(stripeFooters, 1);
    verifyDictionaryEncoding(stripeFooters, DWRF, 2, repeatInterval);
    verifyDictionaryEncoding(stripeFooters, DWRF, 3, repeatInterval);
}
/**
 * Writes MAX_VALUE and MIN_VALUE followed by the distinct longs 2 through
 * 99,999 and checks that every stripe uses DWRF direct encoding.
 */
@Test
public void testLongDictionaryNonRepeating()
        throws IOException
{
    ImmutableList<Long> extremes = ImmutableList.of(Long.MAX_VALUE, Long.MIN_VALUE);
    ImmutableList.Builder<Long> valuesBuilder = ImmutableList.<Long>builder().addAll(extremes);
    for (long number = extremes.size(); number < 100_000; number++) {
        valuesBuilder.add(number);
    }
    List<Long> values = valuesBuilder.build();

    List<StripeFooter> footers = testLongDictionary(new DirectConversionTester(), values);
    verifyDwrfDirectEncoding(getStripeSize(values.size()), footers);
}
/**
 * Writes 50,000 longs that repeat every 1,500 rows (after MAX_VALUE and
 * MIN_VALUE), forces direct conversions at batches 0 and 16, and checks that
 * stripes 0, 1 and 3 use DWRF direct encoding while stripe 2 is dictionary
 * encoded with 1,500 entries.
 */
@Test
public void testLongDictionaryRepeatingValuesDirect()
        throws IOException
{
    DirectConversionTester directConversionTester = new DirectConversionTester();
    // At batch 0 a 1,000 byte limit is expected to fail, then 4,000 bytes is expected to succeed.
    directConversionTester.add(0, 1000, false);
    directConversionTester.add(0, 4000, true);
    directConversionTester.add(16, 10_000, true);
    ImmutableList<Long> baseList = ImmutableList.of(Long.MAX_VALUE, Long.MIN_VALUE);
    ImmutableList.Builder<Long> builder = ImmutableList.builder();
    builder.addAll(baseList);
    int repeatInterval = 1500;
    for (long i = baseList.size(); i < 50_000; i++) {
        builder.add(INTEGER_VALUES_DICTIONARY_BASE + i % repeatInterval);
    }
    List<Long> values = builder.build();
    List<StripeFooter> stripeFooters = testLongDictionary(directConversionTester, values);
    // TestNG argument order is (actual, expected).
    assertEquals(stripeFooters.size(), getStripeSize(values.size()));
    verifyDwrfDirectEncoding(stripeFooters, 0);
    verifyDwrfDirectEncoding(stripeFooters, 1);
    verifyDictionaryEncoding(stripeFooters, DWRF, 2, repeatInterval);
    // The last stripe holds only the remaining 5,000 rows and is expected to be direct encoded.
    verifyDwrfDirectEncoding(stripeFooters, 3);
}
/**
 * Writes one stripe of distinct longs followed by 100,000 rows cycling through
 * 1,500 distinct values, with {@code preserveDirectEncodingStripeCount} set to 2.
 * Stripes 0 through 2 are expected to use DWRF direct encoding and all later
 * stripes dictionary encoding with 1,500 entries.
 */
@Test
public void testLongPreserveDirectEncoding()
        throws IOException
{
    ImmutableList.Builder<Long> builder = ImmutableList.builder();
    // First stripe: no repeated values.
    for (long i = 0; i < STRIPE_MAX_ROWS; i++) {
        builder.add(i);
    }
    // Remaining rows: values repeat every repeatInterval rows.
    int repeatInterval = 1500;
    for (long i = 0; i < 100_000; i++) {
        builder.add(INTEGER_VALUES_DICTIONARY_BASE + i % repeatInterval);
    }
    DirectConversionTester tester = new DirectConversionTester();
    int preserveDirectEncodingStripeCount = 2;
    OrcWriterOptions orcWriterOptions = OrcWriterOptions.builder()
            .withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build())
            .withIntegerDictionaryEncodingEnabled(true)
            .withPreserveDirectEncodingStripeCount(preserveDirectEncodingStripeCount)
            .build();
    List<Long> values = builder.build();
    List<StripeFooter> stripeFooters = testDictionary(BIGINT, DWRF, orcWriterOptions, tester, values);
    // TestNG argument order is (actual, expected).
    assertEquals(stripeFooters.size(), getStripeSize(values.size()));
    for (int i = 0; i <= preserveDirectEncodingStripeCount; i++) {
        verifyDwrfDirectEncoding(stripeFooters, i);
    }
    for (int i = preserveDirectEncodingStripeCount + 1; i < stripeFooters.size(); i++) {
        verifyDictionaryEncoding(stripeFooters, DWRF, i, repeatInterval);
    }
}
/**
 * Round-trips 50,000 rows of {@code ARRAY(INTEGER)} in which every fifth row is
 * null and the others hold between zero and three integers.
 */
@Test
public void verifyIntegerInList()
        throws IOException
{
    List<List<Integer>> rows = new ArrayList<>();
    for (int row = 0; row < 50_000; row++) {
        int elementCount = row % 5;
        if (elementCount == 4) {
            rows.add(null);
            continue;
        }
        List<Integer> elements = new ArrayList<>();
        for (int element = 0; element < elementCount; element++) {
            elements.add(row + element);
        }
        rows.add(elements);
    }

    testDictionary(new ArrayType(INTEGER), DWRF, true, true, new DirectConversionTester(), rows);
}
/**
 * Round-trips 50,000 rows of {@code ARRAY(INTEGER)} that alternate between a
 * null row and an empty list, so the element column never receives a value.
 */
@Test
public void verifyChildElementEmptyOrMissingInList()
        throws IOException
{
    List<Integer> emptyList = new ArrayList<>();
    List<List<Integer>> rows = new ArrayList<>();
    for (int row = 0; row < 50_000; row++) {
        if (row % 2 == 0) {
            rows.add(null);
        }
        else {
            rows.add(emptyList);
        }
    }

    testDictionary(new ArrayType(INTEGER), DWRF, true, true, new DirectConversionTester(), rows);
}
/**
 * Round-trips 50,000 rows of {@code ARRAY(VARCHAR)} in which every fifth row is
 * null and the others hold between zero and three decimal strings.
 */
@Test
public void verifyStringInList()
        throws IOException
{
    List<List<String>> rows = new ArrayList<>();
    for (int row = 0; row < 50_000; row++) {
        int elementCount = row % 5;
        if (elementCount == 4) {
            rows.add(null);
            continue;
        }
        List<String> elements = new ArrayList<>();
        for (int element = 0; element < elementCount; element++) {
            elements.add(Integer.toString(row + element));
        }
        rows.add(elements);
    }

    testDictionary(new ArrayType(VARCHAR), DWRF, true, true, new DirectConversionTester(), rows);
}
/**
 * Round-trips 50,000 rows of {@code ARRAY(VARCHAR)} that alternate between a
 * null row and an empty list, so the element column never receives a value.
 */
@Test
public void verifyStringEmptyOrMissingInList()
        throws IOException
{
    List<String> emptyList = new ArrayList<>();
    List<List<String>> rows = new ArrayList<>();
    for (int row = 0; row < 50_000; row++) {
        if (row % 2 == 0) {
            rows.add(null);
        }
        else {
            rows.add(emptyList);
        }
    }

    testDictionary(new ArrayType(VARCHAR), DWRF, true, true, new DirectConversionTester(), rows);
}
/**
 * Looks up the encoding of the tested column in one stripe footer.
 *
 * @param stripeFooters footers of all stripes in the file
 * @param stripeId zero-based index into {@code stripeFooters}
 * @return the column encoding stored under {@code COLUMN_ID} in that footer
 */
private ColumnEncoding getColumnEncoding(List<StripeFooter> stripeFooters, int stripeId)
{
    return stripeFooters.get(stripeId).getColumnEncodings().get(COLUMN_ID);
}
/**
 * Asserts that the tested column uses {@code DWRF_DIRECT} encoding in one stripe.
 *
 * @param stripeFooters footers of all stripes in the file
 * @param stripeId zero-based index of the stripe to check
 */
private void verifyDwrfDirectEncoding(List<StripeFooter> stripeFooters, int stripeId)
{
    ColumnEncoding columnEncoding = getColumnEncoding(stripeFooters, stripeId);
    assertEquals(columnEncoding.getColumnEncodingKind(), DWRF_DIRECT, "StripeId " + stripeId);
}
/**
 * Asserts that the tested column is direct encoded in one stripe: {@code DIRECT}
 * for DWRF files and {@code DIRECT_V2} for ORC files.
 *
 * @param stripeFooters footers of all stripes in the file
 * @param encoding file format that was written; must be DWRF or ORC
 * @param stripeId zero-based index of the stripe to check
 */
private void verifyDirectEncoding(List<StripeFooter> stripeFooters, OrcEncoding encoding, int stripeId)
{
    ColumnEncoding columnEncoding = getColumnEncoding(stripeFooters, stripeId);
    if (encoding.equals(DWRF)) {
        assertEquals(columnEncoding.getColumnEncodingKind(), DIRECT, "Encoding " + encoding + " StripeId " + stripeId);
        return;
    }
    // Anything that is not DWRF has to be ORC.
    assertEquals(encoding, ORC);
    assertEquals(columnEncoding.getColumnEncodingKind(), DIRECT_V2, "Encoding " + encoding + " StripeId " + stripeId);
}
/**
 * Asserts that the tested column is dictionary encoded in one stripe
 * ({@code DICTIONARY} for DWRF, {@code DICTIONARY_V2} for ORC) and that the
 * dictionary has the expected number of entries.
 *
 * @param stripeFooters footers of all stripes in the file
 * @param encoding file format that was written; must be DWRF or ORC
 * @param stripeId zero-based index of the stripe to check
 * @param dictionarySize expected dictionary size recorded in the column encoding
 */
private void verifyDictionaryEncoding(List<StripeFooter> stripeFooters, OrcEncoding encoding, int stripeId, int dictionarySize)
{
    ColumnEncoding columnEncoding = getColumnEncoding(stripeFooters, stripeId);
    if (!encoding.equals(DWRF)) {
        // Anything that is not DWRF has to be ORC.
        assertEquals(encoding, ORC);
        assertEquals(columnEncoding.getColumnEncodingKind(), DICTIONARY_V2, "Encoding " + encoding + " StripeId " + stripeId);
    }
    else {
        assertEquals(columnEncoding.getColumnEncodingKind(), DICTIONARY, "Encoding " + encoding + " StripeId " + stripeId);
    }
    assertEquals(columnEncoding.getDictionarySize(), dictionarySize, "Encoding " + encoding + " StripeId " + stripeId);
}
/**
 * Asserts the stripe count and that every stripe is dictionary encoded with the
 * dictionary size given for it.
 *
 * @param stripeCount expected number of stripes
 * @param encoding file format that was written
 * @param stripeFooters footers of all stripes in the file
 * @param dictionarySizes expected dictionary size per stripe, indexed by stripe
 */
private void verifyDictionaryEncoding(int stripeCount, OrcEncoding encoding, List<StripeFooter> stripeFooters, List<Integer> dictionarySizes)
{
    int actualStripeCount = stripeFooters.size();
    assertEquals(actualStripeCount, stripeCount);
    for (int stripeId = 0; stripeId < actualStripeCount; stripeId++) {
        int expectedDictionarySize = dictionarySizes.get(stripeId);
        verifyDictionaryEncoding(stripeFooters, encoding, stripeId, expectedDictionarySize);
    }
}
/**
 * Asserts the stripe count and that every stripe is direct encoded.
 *
 * @param stripeCount expected number of stripes
 * @param encoding file format that was written
 * @param stripeFooters footers of all stripes in the file
 */
private void verifyDirectEncoding(int stripeCount, OrcEncoding encoding, List<StripeFooter> stripeFooters)
{
    assertEquals(stripeFooters.size(), stripeCount);
    int stripeId = 0;
    while (stripeId < stripeCount) {
        verifyDirectEncoding(stripeFooters, encoding, stripeId);
        stripeId++;
    }
}
/**
 * Asserts the stripe count and that the tested column uses {@code DWRF_DIRECT}
 * encoding in every stripe.
 *
 * @param stripeCount expected number of stripes
 * @param stripeFooters footers of all stripes in the file
 */
private void verifyDwrfDirectEncoding(int stripeCount, List<StripeFooter> stripeFooters)
{
    assertEquals(stripeFooters.size(), stripeCount);
    // Delegate to the per-stripe check so a failure reports the offending stripe id.
    for (int stripeId = 0; stripeId < stripeFooters.size(); stripeId++) {
        verifyDwrfDirectEncoding(stripeFooters, stripeId);
    }
}
/**
 * Round-trips {@code values} as a BIGINT column in a DWRF file with integer
 * dictionary encoding enabled.
 *
 * @param directConversionTester scripted direct conversion checks run after each batch
 * @param values rows to write; may contain nulls
 * @return the stripe footers of the written file
 */
private List<StripeFooter> testLongDictionary(DirectConversionTester directConversionTester, List<?> values)
        throws IOException
{
    boolean enableIntDictionary = true;
    boolean sortStringDictionaryKeys = true;
    return testDictionary(BIGINT, DWRF, enableIntDictionary, sortStringDictionaryKeys, directConversionTester, values);
}
/**
 * Round-trips {@code values} as an INTEGER column in a DWRF file with integer
 * dictionary encoding enabled.
 *
 * @param directConversionTester scripted direct conversion checks run after each batch
 * @param values rows to write; may contain nulls
 * @return the stripe footers of the written file
 */
private List<StripeFooter> testIntegerDictionary(DirectConversionTester directConversionTester, List<?> values)
        throws IOException
{
    boolean enableIntDictionary = true;
    boolean sortStringDictionaryKeys = true;
    return testDictionary(INTEGER, DWRF, enableIntDictionary, sortStringDictionaryKeys, directConversionTester, values);
}
/**
 * Round-trips {@code values} as a VARCHAR column using the file format and
 * dictionary sorting flag of {@code dictionaryInput}; integer dictionary
 * encoding is left disabled.
 *
 * @param directConversionTester scripted direct conversion checks run after each batch
 * @param dictionaryInput file format and sorting combination to test
 * @param values rows to write; may contain nulls
 * @return the stripe footers of the written file
 */
private List<StripeFooter> testStringDictionary(DirectConversionTester directConversionTester, StringDictionaryInput dictionaryInput, List<String> values)
        throws IOException
{
    OrcEncoding encoding = dictionaryInput.getEncoding();
    boolean sortKeys = dictionaryInput.isSortStringDictionaryKeys();
    return testDictionary(VARCHAR, encoding, false, sortKeys, directConversionTester, values);
}
/**
 * Builds writer options from the two flags (with the stripe row limit set to
 * {@code STRIPE_MAX_ROWS}) and runs the round-trip dictionary test.
 *
 * @param type column type to write
 * @param encoding file format to write
 * @param enableIntDictionary whether integer dictionary encoding is enabled
 * @param sortStringDictionaryKeys whether string dictionary sorting is enabled
 * @param directConversionTester scripted direct conversion checks run after each batch
 * @param values rows to write; may contain nulls
 * @return the stripe footers of the written file
 */
private List<StripeFooter> testDictionary(Type type, OrcEncoding encoding, boolean enableIntDictionary, boolean sortStringDictionaryKeys, DirectConversionTester directConversionTester, List<?> values)
        throws IOException
{
    OrcWriterOptions.Builder optionsBuilder = OrcWriterOptions.builder();
    optionsBuilder.withFlushPolicy(DefaultOrcWriterFlushPolicy.builder().withStripeMaxRowCount(STRIPE_MAX_ROWS).build());
    optionsBuilder.withIntegerDictionaryEncodingEnabled(enableIntDictionary);
    optionsBuilder.withStringDictionarySortingEnabled(sortStringDictionaryKeys);
    return testDictionary(type, encoding, optionsBuilder.build(), directConversionTester, values);
}
/**
 * Tells whether {@code type} is an array type, judged by the base name of its
 * type signature.
 *
 * @param type the type to inspect
 * @return {@code true} if the signature base equals {@code ARRAY}
 */
private static boolean isArrayType(Type type)
{
    String baseName = type.getTypeSignature().getBase();
    return baseName.equals(ARRAY);
}
/**
 * Appends {@code values[startIndex, endIndex)} to {@code blockBuilder}. Nulls
 * become null positions, lists are written recursively as array entries,
 * strings are written as slices for VARCHAR, and everything else is written as
 * a long.
 *
 * @param type type of the values being appended
 * @param values source values; elements are null, {@code List}, {@code String} or {@code Number}
 * @param blockBuilder destination builder
 * @param startIndex index of the first value to append (inclusive)
 * @param endIndex index after the last value to append (exclusive)
 */
private void appendListToBlock(Type type, List<?> values, BlockBuilder blockBuilder, int startIndex, int endIndex)
{
    for (int position = startIndex; position < endIndex; position++) {
        Object value = values.get(position);
        if (value == null) {
            blockBuilder.appendNull();
            continue;
        }

        if (isArrayType(type)) {
            // Write the nested list into a child entry using the element type.
            List<?> elements = (List<?>) value;
            Type elementType = type.getTypeParameters().get(0);
            BlockBuilder entryBuilder = blockBuilder.beginBlockEntry();
            appendListToBlock(elementType, elements, entryBuilder, 0, elements.size());
            blockBuilder.closeEntry();
        }
        else if (type.equals(VARCHAR)) {
            type.writeSlice(blockBuilder, utf8Slice((String) value));
        }
        else {
            type.writeLong(blockBuilder, ((Number) value).longValue());
        }
    }
}
/**
 * Writes {@code values} as a single column to a temporary file in pages of
 * {@code BATCH_ROWS} rows, running the direct conversion tester after every
 * page; then validates the file through the writer, reads it back with a
 * selective record reader, compares every value, and returns the stripe footers.
 *
 * @param type column type to write
 * @param encoding file format to write
 * @param orcWriterOptions options for the ORC writer
 * @param directConversionTester scripted direct conversion checks run after each batch
 * @param values rows to write; may contain nulls
 * @return the stripe footers of the written file
 */
private List<StripeFooter> testDictionary(Type type, OrcEncoding encoding, OrcWriterOptions orcWriterOptions, DirectConversionTester directConversionTester, List<?> values)
        throws IOException
{
    List<Type> columnTypes = ImmutableList.of(type);
    int totalRows = values.size();
    try (TempFile tempFile = new TempFile()) {
        OrcWriter writer = createOrcWriter(tempFile.getFile(), encoding, ZSTD, Optional.empty(), columnTypes, orcWriterOptions, NOOP_WRITER_STATS);

        // Write phase: one page per batch, followed by the scripted conversion checks.
        int batchId = 0;
        for (int start = 0; start < totalRows; start += BATCH_ROWS) {
            int end = Math.min(start + BATCH_ROWS, totalRows);
            BlockBuilder blockBuilder = type.createBlockBuilder(null, end - start);
            appendListToBlock(type, values, blockBuilder, start, end);
            Block[] blocks = {blockBuilder.build()};
            writer.write(new Page(blocks));
            directConversionTester.validate(batchId, writer);
            batchId++;
        }
        writer.close();

        DataSize oneMegabyte = new DataSize(1, MEGABYTE);
        writer.validate(new FileOrcDataSource(tempFile.getFile(), oneMegabyte, oneMegabyte, oneMegabyte, true));

        // Read phase: every page read back must match the corresponding input values.
        int verifiedRows = 0;
        try (OrcSelectiveRecordReader reader = createCustomOrcSelectiveRecordReader(
                tempFile,
                encoding,
                OrcPredicate.TRUE,
                type,
                INITIAL_BATCH_SIZE,
                true,
                false)) {
            Page page;
            while (verifiedRows < totalRows && (page = reader.getNextPage()) != null) {
                Block block = page.getBlock(0).getLoadedBlock();
                verifiedRows = verifyBlock(type, values, verifiedRows, block);
            }
            assertEquals(verifiedRows, totalRows);
        }
        return OrcTester.getStripes(tempFile.getFile(), encoding);
    }
}
/**
 * Compares every position of {@code block} with the expected values starting at
 * {@code index}; array values are compared recursively.
 *
 * @param type type of the block
 * @param values expected values
 * @param index index in {@code values} of the value expected at block position 0
 * @param block block read back from the file
 * @return the index in {@code values} following the last compared value
 */
private int verifyBlock(Type type, List<?> values, int index, Block block)
{
    int valueIndex = index;
    for (int position = 0; position < block.getPositionCount(); position++) {
        Object expected = values.get(valueIndex);
        valueIndex++;
        assertEquals(block.isNull(position), expected == null);
        if (expected == null) {
            continue;
        }

        if (type.equals(VARCHAR)) {
            assertEquals(block.getSlice(position, 0, block.getSliceLength(position)), utf8Slice((String) expected));
        }
        else if (isArrayType(type)) {
            List<?> expectedElements = (List<?>) expected;
            Block elementBlock = block.getBlock(position);
            int verifiedElements = verifyBlock(type.getTypeParameters().get(0), expectedElements, 0, elementBlock);
            // The nested block must hold exactly as many elements as the expected list.
            assertEquals(verifiedElements, expectedElements.size());
        }
        else {
            assertEquals(type.getLong(block, position), ((Number) expected).longValue());
        }
    }
    return valueIndex;
}
/**
 * A combination of file format and string dictionary sorting flag that the
 * string tests iterate over.
 */
private static class StringDictionaryInput
{
    private final OrcEncoding encoding;
    private final boolean sortStringDictionaryKeys;

    StringDictionaryInput(OrcEncoding encoding, boolean sortStringDictionaryKeys)
    {
        this.encoding = encoding;
        this.sortStringDictionaryKeys = sortStringDictionaryKeys;
    }

    /**
     * Lists the tested combinations: ORC sorted, DWRF sorted and DWRF unsorted.
     */
    static List<StringDictionaryInput> values()
    {
        StringDictionaryInput orcSorted = new StringDictionaryInput(ORC, true);
        StringDictionaryInput dwrfSorted = new StringDictionaryInput(DWRF, true);
        StringDictionaryInput dwrfUnsorted = new StringDictionaryInput(DWRF, false);
        return ImmutableList.of(orcSorted, dwrfSorted, dwrfUnsorted);
    }

    /** Returns the file format to write. */
    public OrcEncoding getEncoding()
    {
        return encoding;
    }

    /** Returns whether string dictionary sorting should be enabled. */
    public boolean isSortStringDictionaryKeys()
    {
        return sortStringDictionaryKeys;
    }
}
/**
 * A script of direct conversion attempts. Each entry names the batch after
 * which to attempt a conversion, the byte limit to pass, and whether the
 * conversion is expected to succeed. Entries are consumed in insertion order.
 */
private static class DirectConversionTester
{
    // Three parallel lists; position i of each describes the i-th scripted attempt.
    private final List<Integer> batchIds = new ArrayList<>();
    private final List<Integer> maxDirectBytes = new ArrayList<>();
    private final List<Boolean> expectedResults = new ArrayList<>();
    // Position of the next scripted attempt that has not been run yet.
    private int index;
    private int lastBatchId = -1;

    /**
     * Appends a scripted attempt.
     *
     * @param batchId batch after which the attempt is made
     * @param maxBytes byte limit passed to the conversion
     * @param expectedResult whether the conversion is expected to succeed
     */
    void add(int batchId, int maxBytes, boolean expectedResult)
    {
        batchIds.add(batchId);
        maxDirectBytes.add(maxBytes);
        expectedResults.add(expectedResult);
    }

    /**
     * Runs every pending attempt registered for {@code batchId} against the
     * first column writer of {@code writer}. Must be called with strictly
     * increasing batch ids.
     *
     * @param batchId id of the batch that was just written
     * @param writer the ORC writer under test
     */
    void validate(int batchId, OrcWriter writer)
    {
        checkState(batchId > lastBatchId);
        lastBatchId = batchId;
        // Consume consecutive script entries for this batch; stop at the first entry for another batch.
        while (index < batchIds.size() && batchIds.get(index) == batchId) {
            DictionaryColumnWriter columnWriter = (DictionaryColumnWriter) writer.getColumnWriters().get(0);
            // The column must still be dictionary encoded before the attempt.
            assertFalse(columnWriter.isDirectEncoded(), "BatchId " + batchId + " is Direct encoded");
            int bufferedBytes = maxDirectBytes.get(index);
            if (!expectedResults.get(index)) {
                // Failed Conversion to direct, can be invoked on column writer, as the dictionary
                // compression optimizer state does not change.
                assertFalse(columnWriter.tryConvertToDirect(bufferedBytes).isPresent(), "BatchId " + batchId + " bytes " + bufferedBytes);
            }
            else {
                // Successful conversion to direct, changes the state of the dictionary compression
                // optimizer and it should go only via dictionary compression optimizer.
                List<DictionaryColumnManager> directConversionCandidates = writer.getDictionaryCompressionOptimizer().getDirectConversionCandidates();
                boolean contains = directConversionCandidates.stream().anyMatch(x -> x.getDictionaryColumn() == columnWriter);
                assertTrue(contains);
                writer.getDictionaryCompressionOptimizer().convertLowCompressionStreams(true, bufferedBytes);
                assertTrue(columnWriter.isDirectEncoded(), "BatchId " + batchId + " bytes " + bufferedBytes);
                // The same list instance is re-checked; this presumably relies on it being a live view of the optimizer's candidates.
                contains = directConversionCandidates.stream().anyMatch(x -> x.getDictionaryColumn() == columnWriter);
                assertFalse(contains);
            }
            index++;
        }
    }
}
}