BenchmarkZstdJniDecompression.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.orc.zstd.ZstdJniCompressor;
import com.google.common.collect.ImmutableList;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.OptionalInt;
import java.util.Random;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.testng.Assert.assertEquals;
/**
 * JMH benchmark that decompresses a fixed workload through {@link OrcZstdDecompressor},
 * once with the JNI flag enabled and once with it disabled, so the two code paths can be
 * compared.
 * <p>
 * The workload is generated once per JVM from a seeded random source and compressed with
 * {@link ZstdJniCompressor}, so every fork and every run decompresses exactly the same bytes.
 */
@SuppressWarnings("MethodMayBeStatic")
@State(Scope.Thread)
@OutputTimeUnit(MILLISECONDS)
@Fork(3)
@Warmup(iterations = 20, time = 500, timeUnit = MILLISECONDS)
@Measurement(iterations = 20, time = 500, timeUnit = MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkZstdJniDecompression
{
    /** Uncompressed size, in bytes, of every workload unit. */
    private static final int SOURCE_LENGTH = 256 * 1024;

    /** Number of compressed units decompressed per benchmark invocation. */
    private static final int UNIT_COUNT = 10;

    /** Arbitrary fixed seed; keeps the generated workload identical across JVMs and forks. */
    private static final long WORKLOAD_SEED = 42L;

    /** Characters the source text is drawn from; all are single-byte in US-ASCII. */
    private static final String ALPHABET = "USINDIA";

    // Declared before WORKLOAD on purpose: generateWorkload() uses it during static initialization.
    private static final ZstdJniCompressor COMPRESSOR = new ZstdJniCompressor(OptionalInt.empty());

    private static final List<Unit> WORKLOAD = generateWorkload();

    // Instance field rather than a static: with Scope.Thread state each benchmark thread works on
    // its own instance, so concurrent benchmark threads no longer write into one shared array.
    private final byte[] decompressedBytes = new byte[SOURCE_LENGTH];

    // Created once per instance instead of once per decompressed unit; it is loop-invariant.
    private final OrcDecompressor.OutputBuffer outputBuffer = new OrcDecompressor.OutputBuffer()
    {
        @Override
        public byte[] initialize(int size)
        {
            // Always hands out the same preallocated array, regardless of the requested size.
            return decompressedBytes;
        }

        @Override
        public byte[] grow(int size)
        {
            throw new UnsupportedOperationException(
                    "benchmark output buffer is fixed at " + SOURCE_LENGTH + " bytes; cannot grow to " + size);
        }
    };

    /**
     * Decompresses the whole workload with JNI decompression enabled.
     *
     * @throws OrcCorruptionException if the decompressor rejects the input
     */
    @Benchmark
    public void decompressJni()
            throws OrcCorruptionException
    {
        // The decompressor is constructed inside the measured method, as in the original benchmark.
        decompressList(createOrcDecompressor(true));
    }

    /**
     * Decompresses the whole workload with JNI decompression disabled.
     *
     * @throws OrcCorruptionException if the decompressor rejects the input
     */
    @Benchmark
    public void decompressJava()
            throws OrcCorruptionException
    {
        decompressList(createOrcDecompressor(false));
    }

    /**
     * Runs every workload unit through {@code decompressor} and asserts that the reported output
     * size equals the unit's original length.
     */
    private void decompressList(OrcDecompressor decompressor)
            throws OrcCorruptionException
    {
        for (Unit unit : WORKLOAD) {
            int outputSize = decompressor.decompress(unit.compressedBytes, 0, unit.compressedLength, outputBuffer);
            assertEquals(outputSize, unit.sourceLength);
        }
    }

    /**
     * Builds {@link #UNIT_COUNT} units of random text over {@link #ALPHABET}, each compressed
     * with {@link #COMPRESSOR}.
     */
    private static List<Unit> generateWorkload()
    {
        Random random = new Random(WORKLOAD_SEED);
        // One generously sized scratch buffer (same size the original allocated per unit) is reused;
        // each unit keeps only a copy trimmed to the bytes the compressor actually produced.
        byte[] scratch = new byte[SOURCE_LENGTH * 32];

        ImmutableList.Builder<Unit> units = ImmutableList.builder();
        for (int i = 0; i < UNIT_COUNT; i++) {
            // Explicit charset: the result must not depend on the platform default encoding.
            byte[] sourceBytes = getAlphaNumericString(random, SOURCE_LENGTH).getBytes(StandardCharsets.US_ASCII);
            int compressedLength = COMPRESSOR.compress(sourceBytes, 0, sourceBytes.length, scratch, 0, scratch.length);
            units.add(new Unit(sourceBytes, sourceBytes.length, Arrays.copyOf(scratch, compressedLength), compressedLength));
        }
        return units.build();
    }

    /**
     * Creates the decompressor under test.
     *
     * @param zstdJniDecompressionEnabled flag forwarded to the {@link OrcZstdDecompressor} constructor
     */
    private OrcDecompressor createOrcDecompressor(boolean zstdJniDecompressionEnabled)
    {
        // NOTE(review): the second argument is presumably the maximum output buffer size - confirm
        // against OrcZstdDecompressor.
        return new OrcZstdDecompressor(new OrcDataSourceId("orc"), SOURCE_LENGTH, zstdJniDecompressionEnabled);
    }

    /**
     * Returns a string of {@code length} characters, each picked uniformly from {@link #ALPHABET}
     * using {@code random}.
     */
    private static String getAlphaNumericString(Random random, int length)
    {
        StringBuilder text = new StringBuilder(length);
        for (int index = 0; index < length; index++) {
            text.append(ALPHABET.charAt(random.nextInt(ALPHABET.length())));
        }
        return text.toString();
    }

    /**
     * One workload item: the original bytes together with their compressed form.
     */
    static class Unit
    {
        /** Uncompressed input that was fed to the compressor. */
        final byte[] sourceBytes;
        /** Expected size of the decompressed output, in bytes. */
        final int sourceLength;
        /** Buffer holding the compressed data starting at offset 0. */
        final byte[] compressedBytes;
        /** Number of valid compressed bytes in {@link #compressedBytes}. */
        final int compressedLength;

        public Unit(byte[] sourceBytes, int sourceLength, byte[] compressedBytes, int compressedLength)
        {
            this.sourceBytes = sourceBytes;
            this.sourceLength = sourceLength;
            this.compressedBytes = compressedBytes;
            this.compressedLength = compressedLength;
        }
    }
}