TestSliceDictionaryColumnWriter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.orc.writer;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.orc.ColumnWriterOptions;
import com.facebook.presto.orc.OrcCorruptionException;
import com.facebook.presto.orc.OrcDataSourceId;
import com.facebook.presto.orc.OrcDecompressor;
import com.facebook.presto.orc.OrcEncoding;
import com.facebook.presto.orc.TestingHiveOrcAggregatedMemoryContext;
import com.facebook.presto.orc.metadata.Stream.StreamKind;
import com.facebook.presto.orc.stream.ByteArrayInputStream;
import com.facebook.presto.orc.stream.LongInputStream;
import com.facebook.presto.orc.stream.LongInputStreamV1;
import com.facebook.presto.orc.stream.LongInputStreamV2;
import com.facebook.presto.orc.stream.OrcInputStream;
import com.facebook.presto.orc.stream.SharedBuffer;
import com.facebook.presto.orc.stream.StreamDataOutput;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;

import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.OrcDecompressor.createOrcDecompressor;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static com.facebook.presto.orc.OrcEncoding.ORC;
import static com.facebook.presto.orc.metadata.ColumnEncoding.DEFAULT_SEQUENCE_ID;
import static com.facebook.presto.orc.metadata.CompressionKind.SNAPPY;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.DICTIONARY_DATA;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.LENGTH;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.Math.toIntExact;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;

public class TestSliceDictionaryColumnWriter
{
private static final int COLUMN_ID = 1;
private static final OrcDataSourceId ORC_DATA_SOURCE_ID = new OrcDataSourceId("test");
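
    // Returns the single stream of the given kind; onlyElement() fails the test if the
    // writer emitted zero or more than one stream of that kind.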
private StreamDataOutput getStreamKind(List<StreamDataOutput> streams, StreamKind streamKind)
{
return streams.stream()
.filter(e -> e.getStream().getStreamKind() == streamKind)
.collect(onlyElement());
}
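
    // Streams are written with SNAPPY compression, so reading them back requires a
    // SNAPPY decompressor; 256KB caps the decompressed buffer size.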
private Optional<OrcDecompressor> getOrcDecompressor()
{
int compressionBlockSize = toIntExact(new DataSize(256, KILOBYTE).toBytes());
return createOrcDecompressor(ORC_DATA_SOURCE_ID, SNAPPY, compressionBlockSize);
}
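
    // Wraps a serialized stream back into an OrcInputStream, mirroring what a reader
    // would do with data flushed by the writer.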
private OrcInputStream convertSliceToInputStream(Slice slice)
{
TestingHiveOrcAggregatedMemoryContext aggregatedMemoryContext = new TestingHiveOrcAggregatedMemoryContext();
return new OrcInputStream(
ORC_DATA_SOURCE_ID,
new SharedBuffer(aggregatedMemoryContext.newOrcLocalMemoryContext("sharedDecompressionBuffer")),
slice.getInput(),
getOrcDecompressor(),
Optional.empty(),
aggregatedMemoryContext,
slice.getRetainedSize());
}
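
    // Serializes a finished stream into a slice so it can be read back.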
private Slice convertStreamToSlice(StreamDataOutput streamDataOutput)
throws OrcCorruptionException
{
DynamicSliceOutput sliceOutput = new DynamicSliceOutput(toIntExact(streamDataOutput.size()));
streamDataOutput.writeData(sliceOutput);
return sliceOutput.slice();
}
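
    // Locates the stream of the given kind and round-trips it into a readable OrcInputStream.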
private OrcInputStream getOrcInputStream(List<StreamDataOutput> streams, StreamKind streamKind)
throws OrcCorruptionException
{
StreamDataOutput stream = getStreamKind(streams, streamKind);
Slice slice = convertStreamToSlice(stream);
return convertSliceToInputStream(slice);
}
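
    // DWRF encodes dictionary lengths with RLE v1; ORC uses RLE v2.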
    private LongInputStream getDictionaryLengthStream(List<StreamDataOutput> streams, OrcEncoding orcEncoding)
            throws OrcCorruptionException
{
if (orcEncoding == DWRF) {
return new LongInputStreamV1(getOrcInputStream(streams, LENGTH), false);
}
return new LongInputStreamV2(getOrcInputStream(streams, LENGTH), false, false);
}
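
    // Writes the values through a dictionary writer in row groups of up to 10,000 rows,
    // closes it, and decodes the dictionary keys back out of the DICTIONARY_DATA and
    // LENGTH streams.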
private List<String> getDictionaryKeys(List<String> values, OrcEncoding orcEncoding, boolean sortDictionaryKeys)
throws IOException
{
DictionaryColumnWriter writer = getDictionaryColumnWriter(orcEncoding, sortDictionaryKeys);
for (int index = 0; index < values.size(); ) {
int endIndex = Math.min(index + 10_000, values.size());
BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 10_000);
while (index < endIndex) {
VARCHAR.writeSlice(blockBuilder, utf8Slice(values.get(index++)));
}
writer.beginRowGroup();
writer.writeBlock(blockBuilder);
writer.finishRowGroup();
}
writer.close();
List<StreamDataOutput> streams = writer.getDataStreams();
int dictionarySize = writer.getColumnEncodings().get(COLUMN_ID).getDictionarySize();
ByteArrayInputStream dictionaryDataStream = new ByteArrayInputStream(getOrcInputStream(streams, DICTIONARY_DATA));
LongInputStream dictionaryLengthStream = getDictionaryLengthStream(streams, orcEncoding);
List<String> dictionaryKeys = new ArrayList<>(dictionarySize);
for (int i = 0; i < dictionarySize; i++) {
int length = toIntExact(dictionaryLengthStream.next());
String dictionaryKey = new String(dictionaryDataStream.next(length), UTF_8);
dictionaryKeys.add(dictionaryKey);
}
return dictionaryKeys;
}
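
    // Creates a SNAPPY-compressed VARCHAR dictionary writer with key sorting toggled
    // by the caller.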
private DictionaryColumnWriter getDictionaryColumnWriter(OrcEncoding orcEncoding, boolean sortDictionaryKeys)
{
ColumnWriterOptions columnWriterOptions = ColumnWriterOptions.builder()
.setCompressionKind(SNAPPY)
.setStringDictionarySortingEnabled(sortDictionaryKeys)
.build();
        return new SliceDictionaryColumnWriter(
                COLUMN_ID,
                DEFAULT_SEQUENCE_ID,
                VARCHAR,
                columnWriterOptions,
                Optional.empty(),
                orcEncoding,
                orcEncoding.createMetadataWriter());
}
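
    // With sorting enabled, both ORC and DWRF should produce deduplicated keys in sorted order.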
@Test
public void testSortedDictionaryKeys()
throws IOException
{
for (OrcEncoding orcEncoding : OrcEncoding.values()) {
List<String> sortedKeys = getDictionaryKeys(ImmutableList.of("b", "a", "c"), orcEncoding, true);
assertEquals(sortedKeys, ImmutableList.of("a", "b", "c"));
sortedKeys = getDictionaryKeys(ImmutableList.of("b", "b", "a"), orcEncoding, true);
assertEquals(sortedKeys, ImmutableList.of("a", "b"));
}
}
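
    // Unsorted dictionaries are a DWRF-only feature; keys should come back in insertion order.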
@Test
public void testUnsortedDictionaryKeys()
throws IOException
{
List<String> sortedKeys = getDictionaryKeys(ImmutableList.of("b", "a", "c"), DWRF, false);
assertEquals(sortedKeys, ImmutableList.of("b", "a", "c"));
sortedKeys = getDictionaryKeys(ImmutableList.of("b", "b", "a"), DWRF, false);
assertEquals(sortedKeys, ImmutableList.of("b", "a"));
}
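
    // ORC requires sorted dictionary keys, so constructing a writer with sorting disabled
    // should fail fast.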
@Test(expectedExceptions = IllegalStateException.class)
public void testOrcStringSortingDisabledThrows()
{
getDictionaryColumnWriter(ORC, false);
}
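
    // Conversion to direct encoding should be rejected, rather than overflow, when the
    // converted row group would be too large.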
@Test
public void testStringDirectConversion()
{
        // 3000 repetitions of a 1MB value put a single row group near 3GB after direct
        // conversion, exceeding the 2GB limit, so the conversion attempt must fail
byte[] value = new byte[megabytes(1)];
ThreadLocalRandom.current().nextBytes(value);
Block data = RunLengthEncodedBlock.create(VARCHAR, Slices.wrappedBuffer(value), 3000);
for (OrcEncoding orcEncoding : OrcEncoding.values()) {
DictionaryColumnWriter writer = getDictionaryColumnWriter(orcEncoding, true);
writer.beginRowGroup();
writer.writeBlock(data);
writer.finishRowGroup();
assertFalse(writer.tryConvertToDirect(megabytes(64)).isPresent());
}
}
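
    // Converts a size in megabytes to an exact byte count.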
private static int megabytes(int size)
{
return toIntExact(new DataSize(size, MEGABYTE).toBytes());
}
}