TestMapFlatBatchStreamReader.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.NamedTypeSignature;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.RowFieldName;
import com.facebook.presto.common.type.SmallintType;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.TinyintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.common.type.VarbinaryType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.orc.cache.StorageOrcFileTailSource;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION;
import static com.facebook.presto.orc.NoopOrcAggregatedMemoryContext.NOOP_ORC_AGGREGATED_MEMORY_CONTEXT;
import static com.facebook.presto.orc.OrcTester.HIVE_STORAGE_TIME_ZONE;
import static com.facebook.presto.orc.OrcTester.mapType;
import static com.facebook.presto.orc.TestMapFlatBatchStreamReader.ExpectedValuesBuilder.Frequency.ALL;
import static com.facebook.presto.orc.TestMapFlatBatchStreamReader.ExpectedValuesBuilder.Frequency.ALL_EXCEPT_FIRST;
import static com.facebook.presto.orc.TestMapFlatBatchStreamReader.ExpectedValuesBuilder.Frequency.NONE;
import static com.facebook.presto.orc.TestMapFlatBatchStreamReader.ExpectedValuesBuilder.Frequency.SOME;
import static com.facebook.presto.orc.TestingOrcPredicate.createOrcPredicate;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.google.common.collect.Iterators.advance;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.Math.toIntExact;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
public class TestMapFlatBatchStreamReader
{
    // TODO: Add tests for timestamp as value type
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = createTestFunctionAndTypeManager();
    private static final int NUM_ROWS = 31_234;
    // Batches that start before this row are skipped (never decoded) when skipFirstStripe is requested.
    // NOTE(review): presumably this is the stripe row count the test files were written with -- confirm.
    // Files with fewer rows than this (e.g. the 2048-row shared dictionary files) decode no blocks in that mode.
    private static final int FIRST_STRIPE_ROW_COUNT = 10_000;
    private static final Type LIST_TYPE = FUNCTION_AND_TYPE_MANAGER.getParameterizedType(
            StandardTypes.ARRAY,
            ImmutableList.of(TypeSignatureParameter.of(IntegerType.INTEGER.getTypeSignature())));
    private static final Type MAP_TYPE = FUNCTION_AND_TYPE_MANAGER.getParameterizedType(
            StandardTypes.MAP,
            ImmutableList.of(TypeSignatureParameter.of(VarcharType.VARCHAR.getTypeSignature()), TypeSignatureParameter.of(RealType.REAL.getTypeSignature())));
    private static final Type STRUCT_TYPE = FUNCTION_AND_TYPE_MANAGER.getParameterizedType(
            StandardTypes.ROW,
            ImmutableList.of(
                    TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("value1", false)), IntegerType.INTEGER.getTypeSignature())),
                    TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("value2", false)), IntegerType.INTEGER.getTypeSignature())),
                    TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("value3", false)), IntegerType.INTEGER.getTypeSignature()))));

    @Test
    public void testByte()
            throws Exception
    {
        runTest("test_flat_map/flat_map_byte.dwrf",
                TinyintType.TINYINT,
                ExpectedValuesBuilder.get(Integer::byteValue));
    }

    @Test
    public void testByteWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_byte_with_null.dwrf",
                TinyintType.TINYINT,
                ExpectedValuesBuilder.get(Integer::byteValue).setNullValuesFrequency(SOME));
    }

    @Test
    public void testShort()
            throws Exception
    {
        runTest("test_flat_map/flat_map_short.dwrf",
                SmallintType.SMALLINT,
                ExpectedValuesBuilder.get(Integer::shortValue));
    }

    @Test
    public void testInteger()
            throws Exception
    {
        runTest("test_flat_map/flat_map_int.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()));
    }

    @Test
    public void testIntegerWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_int_with_null.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setNullValuesFrequency(SOME));
    }

    @Test
    public void testIntegerWithSharedDictionary()
            throws Exception
    {
        runTest("test_flat_map/flat_map_dict_share_simple.dwrf",
                INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setNumRows(2048));
    }

    @Test
    public void testLong()
            throws Exception
    {
        runTest("test_flat_map/flat_map_long.dwrf",
                BigintType.BIGINT,
                ExpectedValuesBuilder.get(Integer::longValue));
    }

    @Test
    public void testString()
            throws Exception
    {
        runTest("test_flat_map/flat_map_string.dwrf",
                VarcharType.VARCHAR,
                ExpectedValuesBuilder.get(i -> Integer.toString(i)));
    }

    @Test
    public void testStringWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_string_with_null.dwrf",
                VarcharType.VARCHAR,
                ExpectedValuesBuilder.get(i -> Integer.toString(i)).setNullValuesFrequency(SOME));
    }

    @Test
    public void testBinary()
            throws Exception
    {
        runTest("test_flat_map/flat_map_binary.dwrf",
                VarbinaryType.VARBINARY,
                ExpectedValuesBuilder.get(i -> new SqlVarbinary(Integer.toString(i).getBytes(StandardCharsets.UTF_8))));
    }

    @Test
    public void testBoolean()
            throws Exception
    {
        runTest("test_flat_map/flat_map_boolean.dwrf",
                IntegerType.INTEGER,
                BooleanType.BOOLEAN,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToBoolean));
    }

    @Test
    public void testBooleanWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_boolean_with_null.dwrf",
                IntegerType.INTEGER,
                BooleanType.BOOLEAN,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToBoolean).setNullValuesFrequency(SOME));
    }

    @Test
    public void testFloat()
            throws Exception
    {
        runTest("test_flat_map/flat_map_float.dwrf",
                IntegerType.INTEGER,
                RealType.REAL,
                ExpectedValuesBuilder.get(Function.identity(), Float::valueOf));
    }

    @Test
    public void testFloatWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_float_with_null.dwrf",
                IntegerType.INTEGER,
                RealType.REAL,
                ExpectedValuesBuilder.get(Function.identity(), Float::valueOf).setNullValuesFrequency(SOME));
    }

    @Test
    public void testDouble()
            throws Exception
    {
        runTest("test_flat_map/flat_map_double.dwrf",
                IntegerType.INTEGER,
                DoubleType.DOUBLE,
                ExpectedValuesBuilder.get(Function.identity(), Double::valueOf));
    }

    @Test
    public void testDoubleWithNull()
            throws Exception
    {
        runTest("test_flat_map/flat_map_double_with_null.dwrf",
                IntegerType.INTEGER,
                DoubleType.DOUBLE,
                ExpectedValuesBuilder.get(Function.identity(), Double::valueOf).setNullValuesFrequency(SOME));
    }

    @Test
    public void testList()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_list.dwrf",
                IntegerType.INTEGER,
                LIST_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToList));
    }

    @Test
    public void testListWithNull()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_list_with_null.dwrf",
                IntegerType.INTEGER,
                LIST_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToList).setNullValuesFrequency(SOME));
    }

    @Test
    public void testMap()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_map.dwrf",
                IntegerType.INTEGER,
                MAP_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToMap));
    }

    @Test
    public void testMapWithNull()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_map_with_null.dwrf",
                IntegerType.INTEGER,
                MAP_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToMap).setNullValuesFrequency(SOME));
    }

    @Test
    public void testMapWithSharedDictionary()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_dict_share_nested.dwrf",
                BIGINT,
                mapType(VARCHAR, INTEGER),
                ExpectedValuesBuilder.get(x -> (long) x, TestMapFlatBatchStreamReader::intToIntMap).setNumRows(2048));
    }

    @Test
    public void testStruct()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_struct.dwrf",
                IntegerType.INTEGER,
                STRUCT_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToList));
    }

    @Test
    public void testStructWithNull()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_struct_with_null.dwrf",
                IntegerType.INTEGER,
                STRUCT_TYPE,
                ExpectedValuesBuilder.get(Function.identity(), TestMapFlatBatchStreamReader::intToList).setNullValuesFrequency(SOME));
    }

    @Test
    public void testWithNulls()
            throws Exception
    {
        // A test case where some of the flat maps are null
        runTest(
                "test_flat_map/flat_map_some_null_maps.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setNullRowsFrequency(SOME));
    }

    @Test
    public void testWithAllNulls()
            throws Exception
    {
        // A test case where every flat map is null
        runTest(
                "test_flat_map/flat_map_all_null_maps.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setNullRowsFrequency(ALL));
    }

    @Test
    public void testWithAllNullsExceptFirst()
            throws Exception
    {
        // A test case where every flat map is null except the first one
        runTest(
                "test_flat_map/flat_map_all_null_maps_except_first.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setNullRowsFrequency(ALL_EXCEPT_FIRST));
    }

    @Test
    public void testWithEmptyMaps()
            throws Exception
    {
        // A test case where some of the flat maps are empty
        runTest(
                "test_flat_map/flat_map_some_empty_maps.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setEmptyMapsFrequency(SOME));
    }

    @Test
    public void testWithAllMaps()
            throws Exception
    {
        // A test case where all of the flat maps are empty
        runTest(
                "test_flat_map/flat_map_all_empty_maps.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setEmptyMapsFrequency(ALL));
    }

    // All maps are empty and encoding is not present
    @Test
    public void testWithAllEmptyMapsWithNoEncoding()
            throws Exception
    {
        runTest(
                "test_flat_map/flat_map_all_empty_maps_no_encoding.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setEmptyMapsFrequency(ALL));
    }

    @Test
    public void testMixedEncodings()
            throws Exception
    {
        // A test case where the values associated with one key are direct encoded, and all other keys are
        // dictionary encoded. The dictionary encoded values have occasional values that only appear once
        // to ensure the IN_DICTIONARY stream is present, which means the checkpoints for dictionary encoded
        // values will have a different number of positions compared to direct encoded values.
        runTest("test_flat_map/flat_map_mixed_encodings.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setMixedEncodings());
    }

    @Test
    public void testIntegerWithMissingSequences()
            throws Exception
    {
        // A test case where the additional sequences IDs for a flat map aren't a consecutive range [1,N], the odd
        // sequence IDs have been removed. This is to simulate the case where a file has been modified to delete
        // certain keys from the map by dropping the ColumnEncodings and the associated data.
        runTest("test_flat_map/flat_map_int_missing_sequences.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()).setMissingSequences());
    }

    @Test
    public void testIntegerWithMissingSequence0()
            throws Exception
    {
        // A test case where the (dummy) encoding for sequence 0 of the value node doesn't exist
        runTest("test_flat_map/flat_map_int_missing_sequence_0.dwrf",
                IntegerType.INTEGER,
                ExpectedValuesBuilder.get(Function.identity()));
    }

    /**
     * Runs {@link #runTest(String, Type, Type, ExpectedValuesBuilder)} for a map whose key and value types
     * are both {@code type}.
     */
    private <K, V> void runTest(String testOrcFileName, Type type, ExpectedValuesBuilder<K, V> expectedValuesBuilder)
            throws Exception
    {
        runTest(testOrcFileName, type, type, expectedValuesBuilder);
    }

    /**
     * Builds the expected rows once and verifies the file three ways: reading every batch, skipping the
     * first batch, and skipping every batch that starts before {@link #FIRST_STRIPE_ROW_COUNT}.
     */
    private <K, V> void runTest(String testOrcFileName, Type keyType, Type valueType, ExpectedValuesBuilder<K, V> expectedValuesBuilder)
            throws Exception
    {
        List<Map<K, V>> expectedValues = expectedValuesBuilder.build();
        runTest(testOrcFileName, keyType, valueType, expectedValues, false, false);
        runTest(testOrcFileName, keyType, valueType, expectedValues, true, false);
        runTest(testOrcFileName, keyType, valueType, expectedValues, false, true);
    }

    /**
     * Reads column 0 of the DWRF resource {@code testOrcFileName} as {@code map(keyType, valueType)} with the
     * batch record reader and compares each decoded row with {@code expectedValues}, in order.
     *
     * @param skipFirstBatch if true, the first batch is skipped without calling {@code readBlock}
     * @param skipFirstStripe if true, every batch starting before row {@link #FIRST_STRIPE_ROW_COUNT} is skipped
     * without calling {@code readBlock}
     */
    private <K, V> void runTest(String testOrcFileName, Type keyType, Type valueType, List<Map<K, V>> expectedValues, boolean skipFirstBatch, boolean skipFirstStripe)
            throws Exception
    {
        Type mapType = FUNCTION_AND_TYPE_MANAGER.getParameterizedType(
                StandardTypes.MAP,
                ImmutableList.of(
                        TypeSignatureParameter.of(keyType.getTypeSignature()),
                        TypeSignatureParameter.of(valueType.getTypeSignature())));

        // Manage the data source explicitly so it is released even when reader construction or verification fails.
        try (OrcDataSource orcDataSource = new FileOrcDataSource(
                OrcReaderTestingUtils.getResourceFile(testOrcFileName),
                new DataSize(1, MEGABYTE),
                new DataSize(1, MEGABYTE),
                new DataSize(1, MEGABYTE),
                true)) {
            OrcReader orcReader = new OrcReader(
                    orcDataSource,
                    OrcEncoding.DWRF,
                    new StorageOrcFileTailSource(),
                    new StorageStripeMetadataSource(),
                    NOOP_ORC_AGGREGATED_MEMORY_CONTEXT,
                    OrcReaderTestingUtils.createDefaultTestConfig(),
                    false,
                    NO_ENCRYPTION,
                    DwrfKeyProvider.EMPTY,
                    new RuntimeStats());

            try (OrcBatchRecordReader recordReader = orcReader.createBatchRecordReader(
                    ImmutableMap.of(0, mapType),
                    createOrcPredicate(0, mapType, expectedValues, OrcTester.Format.DWRF, true),
                    HIVE_STORAGE_TIME_ZONE,
                    new TestingHiveOrcAggregatedMemoryContext(),
                    1024)) {
                Iterator<?> expectedValuesIterator = expectedValues.iterator();
                boolean isFirst = true;
                int rowsProcessed = 0;
                for (int batchSize = toIntExact(recordReader.nextBatch()); batchSize >= 0; batchSize = toIntExact(recordReader.nextBatch())) {
                    if (skipFirstStripe && rowsProcessed < FIRST_STRIPE_ROW_COUNT) {
                        // Leave the batch undecoded; only the expected values are advanced past it.
                        assertEquals(advance(expectedValuesIterator, batchSize), batchSize);
                    }
                    else if (skipFirstBatch && isFirst) {
                        assertEquals(advance(expectedValuesIterator, batchSize), batchSize);
                        isFirst = false;
                    }
                    else {
                        Block block = recordReader.readBlock(0);
                        // A block of the wrong size would silently shift every later comparison against the iterator.
                        assertEquals(block.getPositionCount(), batchSize, String.format("block position count mismatch at processed rows %d", rowsProcessed));
                        for (int position = 0; position < batchSize; position++) {
                            assertEquals(mapType.getObjectValue(SESSION.getSqlFunctionProperties(), block, position), expectedValuesIterator.next(), String.format("row mismatch at processed rows %d, position %d", rowsProcessed, position));
                        }
                    }
                    // Before rowsProcessed is advanced, both positions must equal the first row of the current batch.
                    assertEquals(recordReader.getReaderPosition(), rowsProcessed);
                    assertEquals(recordReader.getFilePosition(), rowsProcessed);
                    rowsProcessed += batchSize;
                }
                // Every expected row must have been consumed, and the positions must point past the last row.
                assertFalse(expectedValuesIterator.hasNext());
                assertEquals(recordReader.getReaderPosition(), rowsProcessed);
                assertEquals(recordReader.getFilePosition(), rowsProcessed);
            }
        }
    }

    /** Maps an int to {@code true} when it is even. */
    private static boolean intToBoolean(int i)
    {
        return i % 2 == 0;
    }

    /** Maps {@code i} to the three consecutive ints {@code [3i, 3i+1, 3i+2]}. */
    private static List<Integer> intToList(int i)
    {
        return ImmutableList.of(i * 3, i * 3 + 1, i * 3 + 2);
    }

    /** Maps {@code i} to {@code {"3i": 3i, "3i+1": 3i+1, "3i+2": 3i+2}} with float values. */
    private static Map<String, Float> intToMap(int i)
    {
        return ImmutableMap.of(Integer.toString(i * 3), (float) (i * 3), Integer.toString(i * 3 + 1), (float) (i * 3 + 1), Integer.toString(i * 3 + 2), (float) (i * 3 + 2));
    }

    /** Maps {@code i} to {@code {"3i": 3i, "3i+1": 3i+1, "3i+2": 3i+2}} with int values. */
    private static Map<String, Integer> intToIntMap(int i)
    {
        return ImmutableMap.of(Integer.toString(i * 3), i * 3, Integer.toString(i * 3 + 1), i * 3 + 1, Integer.toString(i * 3 + 2), i * 3 + 2);
    }

    /**
     * Generates the expected rows of a flat-map test file.
     *
     * <p>Row {@code i} is either null, an empty map, or a map with up to three entries. For {@code j} in
     * {@code 0..2}, the entry key is {@code keyConverter((3i + j) % 32)}. The value is normally
     * {@code valueConverter((3i + j) % 32)}, with the exceptions configured through the setters.
     */
    static class ExpectedValuesBuilder<K, V>
    {
        /** How often a property applies, as a function of the row index (see {@code passesFrequencyCheck}). */
        enum Frequency
        {
            NONE,
            SOME,
            ALL,
            ALL_EXCEPT_FIRST,
        }

        private final Function<Integer, K> keyConverter;
        private final Function<Integer, V> valueConverter;
        private Frequency nullValuesFrequency = NONE;
        private Frequency nullRowsFrequency = NONE;
        private Frequency emptyMapsFrequency = NONE;
        private boolean mixedEncodings;
        private boolean missingSequences;
        private int numRows = NUM_ROWS;

        private ExpectedValuesBuilder(Function<Integer, K> keyConverter, Function<Integer, V> valueConverter)
        {
            this.keyConverter = keyConverter;
            this.valueConverter = valueConverter;
        }

        /** Creates a builder that uses the same converter for keys and values. */
        public static <T> ExpectedValuesBuilder<T, T> get(Function<Integer, T> converter)
        {
            return new ExpectedValuesBuilder<>(converter, converter);
        }

        /** Creates a builder with separate key and value converters. */
        public static <K, V> ExpectedValuesBuilder<K, V> get(Function<Integer, K> keyConverter, Function<Integer, V> valueConverter)
        {
            return new ExpectedValuesBuilder<>(keyConverter, valueConverter);
        }

        /** Sets which rows have a null value for their first ({@code j == 0}) entry. */
        public ExpectedValuesBuilder<K, V> setNullValuesFrequency(Frequency frequency)
        {
            this.nullValuesFrequency = frequency;
            return this;
        }

        /** Sets which rows are null maps; this takes precedence over empty maps. */
        public ExpectedValuesBuilder<K, V> setNullRowsFrequency(Frequency frequency)
        {
            this.nullRowsFrequency = frequency;
            return this;
        }

        /** Sets which non-null rows are empty maps. */
        public ExpectedValuesBuilder<K, V> setEmptyMapsFrequency(Frequency frequency)
        {
            this.emptyMapsFrequency = frequency;
            return this;
        }

        /** Expects the value layout used by the mixed direct/dictionary encoding test file. */
        public ExpectedValuesBuilder<K, V> setMixedEncodings()
        {
            this.mixedEncodings = true;
            return this;
        }

        /** Omits every entry whose key is odd, matching files whose odd sequences were removed. */
        public ExpectedValuesBuilder<K, V> setMissingSequences()
        {
            this.missingSequences = true;
            return this;
        }

        /** Overrides the number of generated rows (defaults to {@code NUM_ROWS}). */
        public ExpectedValuesBuilder<K, V> setNumRows(int numRows)
        {
            this.numRows = numRows;
            return this;
        }

        /** Generates {@code numRows} expected rows; null rows are represented by {@code null} list elements. */
        public List<Map<K, V>> build()
        {
            List<Map<K, V>> result = new ArrayList<>(numRows);
            for (int i = 0; i < numRows; ++i) {
                if (passesFrequencyCheck(nullRowsFrequency, i)) {
                    result.add(null);
                }
                else if (passesFrequencyCheck(emptyMapsFrequency, i)) {
                    result.add(Collections.emptyMap());
                }
                else {
                    // HashMap (not ImmutableMap) because some values are intentionally null.
                    Map<K, V> row = new HashMap<>();
                    for (int j = 0; j < 3; j++) {
                        V value;
                        int key = (i * 3 + j) % 32;
                        if (missingSequences && key % 2 == 1) {
                            continue;
                        }
                        if (j == 0 && passesFrequencyCheck(nullValuesFrequency, i)) {
                            value = null;
                        }
                        else if (mixedEncodings && (key == 1 || j == 2)) {
                            // NOTE(review): presumably key 1 is the direct-encoded key, so its values are left
                            // unbounded (3i + j rather than mod 32). The j == 2 entries likewise produce values
                            // that appear only once, which forces the IN_DICTIONARY stream for the
                            // dictionary-encoded keys (see testMixedEncodings). Confirm against the generator
                            // of flat_map_mixed_encodings.dwrf.
                            value = valueConverter.apply(i * 3 + j);
                        }
                        else {
                            value = valueConverter.apply((i * 3 + j) % 32);
                        }
                        row.put(keyConverter.apply(key), value);
                    }
                    result.add(row);
                }
            }
            return result;
        }

        /** Returns whether row {@code i} is selected by {@code frequency}; SOME selects every fifth row. */
        private boolean passesFrequencyCheck(Frequency frequency, int i)
        {
            switch (frequency) {
                case NONE:
                    return false;
                case ALL:
                    return true;
                case SOME:
                    return i % 5 == 0;
                case ALL_EXCEPT_FIRST:
                    return i != 0;
                default:
                    throw new IllegalArgumentException("Got unexpected Frequency: " + frequency);
            }
        }
    }
}