/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.parquet.reader;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.parquet.BenchmarkParquetReader;
import com.facebook.presto.parquet.Field;
import com.facebook.presto.parquet.FileParquetDataSource;
import com.facebook.presto.parquet.cache.MetadataReader;
import io.airlift.units.DataSize;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.ParquetProperties.WriterVersion;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.example.GroupWriteSupport;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.ColumnIOConverter;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type.Repetition;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.results.format.ResultFormatType;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.parquet.BenchmarkParquetReader.ROWS;
import static com.facebook.presto.parquet.ParquetTypeUtils.getColumnIO;
import static com.facebook.presto.parquet.reader.TestData.longToBytes;
import static com.facebook.presto.parquet.reader.TestData.maxPrecision;
import static com.facebook.presto.parquet.reader.TestData.unscaledRandomShortDecimalSupplier;
import static com.google.common.io.Files.createTempDir;
import static com.google.common.io.MoreFiles.deleteRecursively;
import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.lang.String.format;
import static java.time.format.DateTimeFormatter.ISO_DATE_TIME;
import static java.util.UUID.randomUUID;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0;
import static org.apache.parquet.hadoop.ParquetWriter.DEFAULT_BLOCK_SIZE;
import static org.apache.parquet.hadoop.ParquetWriter.DEFAULT_PAGE_SIZE;
import static org.apache.parquet.schema.MessageTypeParser.parseMessageType;
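
/**
 * Benchmarks reading a single DECIMAL column from a generated Parquet file,
 * comparing the optimized batch reader against the row-by-row reader across
 * physical types, writer versions, and nullability.
 */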
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@Warmup(iterations = 30, time = 500, timeUnit = MILLISECONDS)
@Measurement(iterations = 20, time = 500, timeUnit = MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@OperationsPerInvocation(BenchmarkParquetReader.ROWS)
public class BenchmarkDecimalColumnBatchReader
{
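    // Small dictionary page size: the dictionary fills up quickly, forcing the
    // writer to fall back to the version-specific value encodings under test.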
public static final int DICT_PAGE_SIZE = 512;
public static final String FIELD_NAME = "decimal_test_column";
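
    // Selects ParquetReader's batch decoding path (true) or the original
    // row-at-a-time path (false).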
@Param({
"true", "false",
})
public boolean enableOptimizedReader;
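
    // Static so the nested @State data classes and generateData() can read them;
    // JMH still drives these fields through the @Param values declared above them.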
@Param({
"true", "false",
})
public static boolean nullable = true;
@Param({
"PARQUET_1_0", "PARQUET_2_0",
})
    // PARQUET_1_0 => PLAIN
    // PARQUET_2_0 => DELTA_BINARY_PACKED (INT32/INT64), DELTA_LENGTH_BYTE_ARRAY / DELTA_BYTE_ARRAY (binary types)
public static WriterVersion writerVersion = PARQUET_1_0;
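
    // Runs the whole suite and writes a timestamped JSON result file to java.io.tmpdir.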
public static void main(String[] args)
throws Throwable
{
Options options = new OptionsBuilder()
.verbosity(VerboseMode.NORMAL)
.include(".*" + BenchmarkDecimalColumnBatchReader.class.getSimpleName() + ".*")
.resultFormat(ResultFormatType.JSON)
.result(format("%s/%s-result-%s.json", System.getProperty("java.io.tmpdir"), BenchmarkDecimalColumnBatchReader.class.getSimpleName(), ISO_DATE_TIME.format(LocalDateTime.now())))
.shouldFailOnError(true)
.build();
new Runner(options).run();
}
@Benchmark
public Object readShortDecimalByteArrayLength(ShortDecimalByteArrayLengthBenchmarkData data)
throws Throwable
{
return read(data, enableOptimizedReader);
}
@Benchmark
public Object readShortDecimal(ShortDecimalBenchmarkData data)
throws Throwable
{
return read(data, enableOptimizedReader);
}
@Benchmark
public Object readLongDecimal(LongDecimalBenchmarkData data)
throws Throwable
{
return read(data, enableOptimizedReader);
}
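
    // Drains the file batch by batch; returning the blocks lets JMH consume the
    // result so the decoding work cannot be dead-code eliminated.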
public static Object read(BenchmarkData data, boolean enableOptimizedReader)
throws Exception
{
try (ParquetReader recordReader = data.createRecordReader(enableOptimizedReader)) {
List<Block> blocks = new ArrayList<>();
while (recordReader.nextBatch() > 0) {
Block block = recordReader.readBlock(data.field);
blocks.add(block);
}
return blocks;
}
}
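
    // Short decimals stored as FIXED_LEN_BYTE_ARRAY of 1..8 bytes. maxPrecision
    // maps the width to the largest precision it can hold (2 digits for 1 byte,
    // 18 for 8 bytes), so each @Param value exercises a different decimal width.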
@State(Scope.Thread)
public static class ShortDecimalByteArrayLengthBenchmarkData
extends BenchmarkData
{
@Param({
"1", "2", "3", "4", "5", "6", "7", "8",
})
public int byteArrayLength = 1;
@Override
protected Type getType()
{
return DecimalType.createDecimalType(getPrecision(), getScale());
}
@Override
protected String getPrimitiveTypeName()
{
return "FIXED_LEN_BYTE_ARRAY(" + byteArrayLength + ")";
}
@Override
protected int getPrecision()
{
return maxPrecision(byteArrayLength);
}
@Override
protected int getScale()
{
return 1;
}
@Override
protected MessageType getSchema()
{
DecimalType decimalType = (DecimalType) getType();
String type = format("DECIMAL(%d,%d)", decimalType.getPrecision(), decimalType.getScale());
return parseMessageType(
"message test { "
+ Repetition.REQUIRED + " " + getPrimitiveTypeName() + " " + FIELD_NAME + " (" + type + "); "
+ "} ");
}
@Override
protected List<Object> generateValues()
{
List<Object> values = new ArrayList<>();
int precision = ((DecimalType) getType()).getPrecision();
long[] dataGen = unscaledRandomShortDecimalSupplier(byteArrayLength * Byte.SIZE, precision).apply(ROWS);
for (int i = 0; i < ROWS; ++i) {
values.add(Binary.fromConstantByteArray(longToBytes(dataGen[i], byteArrayLength)));
}
return values;
}
}
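
    // Short decimals across the physical types Parquet allows for DECIMAL:
    // INT32, INT64, BINARY, and FIXED_LEN_BYTE_ARRAY.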
@State(Scope.Thread)
public static class ShortDecimalBenchmarkData
extends BenchmarkData
{
@Param({
"INT32", "INT64", "BINARY", "FIXED_LEN_BYTE_ARRAY(8)",
})
public static String decimalPrimitiveTypeName = "FIXED_LEN_BYTE_ARRAY(8)";
@Override
protected Type getType()
{
return DecimalType.createDecimalType(getPrecision(), getScale());
}
@Override
protected String getPrimitiveTypeName()
{
return decimalPrimitiveTypeName;
}
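
        // The Parquet spec caps DECIMAL precision at 9 digits for INT32 storage and
        // 18 for INT64; the binary cases also use 18 to stay a Presto short decimal.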
@Override
protected int getPrecision()
{
switch (getPrimitiveTypeName()) {
case "INT32":
return 9;
default:
return 18;
}
}
@Override
protected int getScale()
{
switch (getPrimitiveTypeName()) {
case "INT32":
case "INT64":
return 0;
default:
return 12;
}
}
@Override
protected MessageType getSchema()
{
boolean nullability = getNullability();
Repetition repetition = nullability ? Repetition.OPTIONAL : Repetition.REQUIRED;
DecimalType decimalType = (DecimalType) getType();
String type = format("DECIMAL(%d,%d)", decimalType.getPrecision(), decimalType.getScale());
return parseMessageType(
"message test { "
+ repetition + " " + getPrimitiveTypeName() + " " + FIELD_NAME + " (" + type + "); "
+ "} ");
}
@Override
protected List<Object> generateValues()
{
            List<Object> values = new ArrayList<>();
            for (int i = 0; i < ROWS; ++i) {
                // For nullable columns, roughly half the rows are null
                if (getNullability() && !random.nextBoolean()) {
                    values.add(null);
                    continue;
                }
                switch (getPrimitiveTypeName()) {
                    case "INT32":
                        values.add(random.nextInt());
                        break;
                    case "INT64":
                        values.add(random.nextLong());
                        break;
                    default:
                        values.add(Binary.fromConstantByteArray(longToBytes(random.nextLong(), 8)));
                        break;
                }
            }
            return values;
}
protected boolean getNullability()
{
return nullable;
}
}
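
    // Long decimals (precision 38) stored as BINARY or FIXED_LEN_BYTE_ARRAY(16).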
@State(Scope.Thread)
public static class LongDecimalBenchmarkData
extends BenchmarkData
{
@Param({
"BINARY", "FIXED_LEN_BYTE_ARRAY(16)",
})
public static String decimalPrimitiveTypeName = "FIXED_LEN_BYTE_ARRAY(16)";
@Override
protected Type getType()
{
return DecimalType.createDecimalType(getPrecision(), getScale());
}
@Override
protected String getPrimitiveTypeName()
{
return decimalPrimitiveTypeName;
}
@Override
protected int getPrecision()
{
return 38;
}
@Override
protected int getScale()
{
return 2;
}
@Override
protected MessageType getSchema()
{
boolean nullability = getNullability();
Repetition repetition = nullability ? Repetition.OPTIONAL : Repetition.REQUIRED;
DecimalType decimalType = (DecimalType) getType();
String type = format("DECIMAL(%d,%d)", decimalType.getPrecision(), decimalType.getScale());
return parseMessageType(
"message test { "
+ repetition + " " + getPrimitiveTypeName() + " " + FIELD_NAME + " (" + type + "); "
+ "} ");
}
@Override
protected List<Binary> generateValues()
{
List<Binary> values = new ArrayList<>();
            for (int i = 0; i < ROWS; ++i) {
                // For nullable columns, roughly half the rows are null
                if (getNullability() && !random.nextBoolean()) {
                    values.add(null);
                    continue;
                }
                values.add(Binary.fromConstantByteArray(longToBytes(random.nextLong(), 16)));
            }
return values;
}
protected boolean getNullability()
{
return nullable;
}
}
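
    // Shared lifecycle: each trial writes a fresh single-column Parquet file in
    // setup() and deletes it in tearDown(); subclasses supply schema and values.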
public abstract static class BenchmarkData
{
protected File temporaryDirectory;
protected File file;
protected Random random;
private Field field;
@Setup
public void setup()
throws Exception
{
random = new Random(0);
temporaryDirectory = createTempDir();
file = new File(temporaryDirectory, randomUUID().toString());
generateData(new Path(file.getAbsolutePath()), getSchema(), generateValues(), getPrimitiveTypeName());
}
@TearDown
public void tearDown()
throws IOException
{
deleteRecursively(temporaryDirectory.toPath(), ALLOW_INSECURE);
}
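
        // Re-reads the footer and resolves the single test column; the captured
        // Field is what read() passes to ParquetReader.readBlock().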
ParquetReader createRecordReader(boolean enableOptimizedReader)
throws IOException
{
FileParquetDataSource dataSource = new FileParquetDataSource(file);
ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, file.length(), Optional.empty(), false).getParquetMetadata();
MessageType schema = parquetMetadata.getFileMetaData().getSchema();
MessageColumnIO messageColumnIO = getColumnIO(schema, schema);
this.field = ColumnIOConverter.constructField(getType(), messageColumnIO.getChild(0)).get();
return new ParquetReader(
messageColumnIO,
parquetMetadata.getBlocks(),
Optional.empty(),
dataSource,
newSimpleAggregatedMemoryContext(),
new DataSize(16, MEGABYTE),
enableOptimizedReader,
false,
null,
null,
false,
Optional.empty());
}
protected abstract List<?> generateValues();
protected abstract MessageType getSchema();
protected abstract String getPrimitiveTypeName();
protected abstract Type getType();
protected abstract int getPrecision();
protected abstract int getScale();
}
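
    // Writes the generated values into a single-column Parquet file through the
    // parquet-mr example Group API, so the on-disk encodings come from parquet-mr itself.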
public static void generateData(Path outFile, MessageType schema, List<?> dataList, String primitiveTypeName)
throws IOException
{
System.out.println("Generating data @ " + outFile);
Configuration configuration = new Configuration();
GroupWriteSupport.setSchema(schema, configuration);
SimpleGroupFactory f = new SimpleGroupFactory(schema);
        try (ParquetWriter<Group> writer = new ParquetWriter<>(
                outFile,
                new GroupWriteSupport(),
                CompressionCodecName.UNCOMPRESSED,
                DEFAULT_BLOCK_SIZE,
                DEFAULT_PAGE_SIZE,
                DICT_PAGE_SIZE,
                true, // enable dictionary encoding
                false, // no write-side validation
                writerVersion,
                configuration)) {
            for (Object data : dataList) {
                if (data == null) {
                    writer.write(f.newGroup());
                }
                else {
                    switch (primitiveTypeName) {
                        case "INT32":
                            writer.write(f.newGroup().append(FIELD_NAME, (int) data));
                            break;
                        case "INT64":
                            writer.write(f.newGroup().append(FIELD_NAME, (long) data));
                            break;
                        default:
                            writer.write(f.newGroup().append(FIELD_NAME, (Binary) data));
                    }
                }
            }
        }
}
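
    // Exercises every benchmark once at class-load time so data-generation or
    // reader errors surface immediately rather than mid-run.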
    static {
        try {
            BenchmarkDecimalColumnBatchReader benchmark = new BenchmarkDecimalColumnBatchReader();

            ShortDecimalByteArrayLengthBenchmarkData shortDecimalByteArrayLengthBenchmarkData = new ShortDecimalByteArrayLengthBenchmarkData();
            shortDecimalByteArrayLengthBenchmarkData.setup();
            benchmark.readShortDecimalByteArrayLength(shortDecimalByteArrayLengthBenchmarkData);
            shortDecimalByteArrayLengthBenchmarkData.tearDown();

            ShortDecimalBenchmarkData dataShortDecimal = new ShortDecimalBenchmarkData();
            dataShortDecimal.setup();
            benchmark.readShortDecimal(dataShortDecimal);
            dataShortDecimal.tearDown();

            LongDecimalBenchmarkData dataLongDecimal = new LongDecimalBenchmarkData();
            dataLongDecimal.setup();
            benchmark.readLongDecimal(dataLongDecimal);
            dataLongDecimal.tearDown();
        }
        catch (Throwable throwable) {
            throw new RuntimeException(throwable);
        }
    }
}