BenchmarkJsonExtract.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.PageFunctionCompiler;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.RowExpressionOptimizer;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.testing.TestingConnectorSession;
import com.facebook.presto.testing.TestingSession;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.SliceOutput;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import org.openjdk.jmh.runner.options.WarmupMode;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import static com.facebook.presto.common.type.JsonType.JSON;
import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.operator.scalar.FunctionAssertions.createExpression;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.emptyMap;
import static java.util.Locale.ENGLISH;
/**
 * JMH benchmark measuring a page projection of {@code json_extract(varchar0, '$.key1')} and
 * {@code json_extract(varchar0, '$.key2')} over a page of randomly generated JSON objects,
 * with the session's canonicalized-JSON-extract property switched on and off.
 */
@SuppressWarnings("MethodMayBeStatic")
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(10)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkJsonExtract
{
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final Metadata METADATA = createTestMetadataManager();
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    public static final ConnectorSession SESSION = new TestingConnectorSession(ImmutableList.of());
    // Number of rows (JSON documents) in the input page.
    private static final int POSITION_COUNT = 100_000;
    // Number of top-level members in each generated JSON object; must be >= 2 so "key1" and "key2" both fit.
    private static final int ARRAY_SIZE = 20;
    private static final String CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

    private PageProcessor pageProcessor;
    private Page inputPage;
    private Map<String, Type> symbolTypes;
    private Map<VariableReferenceExpression, Integer> sourceLayout;
    // Built once per trial in setup() so the measured method only times page processing.
    private SqlFunctionProperties sqlFunctionProperties;

    @Param({"true", "false"})
    boolean isCanonicalizedJsonExtract;

    /**
     * Generates the input page, compiles the two {@code json_extract} projections into a
     * {@link PageProcessor}, and builds the function properties used when processing.
     */
    @Setup
    public void setup()
    {
        VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), VARCHAR.getDisplayName().toLowerCase(ENGLISH) + "0", VARCHAR);
        symbolTypes = ImmutableMap.of(variable.getName(), VARCHAR);
        sourceLayout = ImmutableMap.of(variable, 0);
        inputPage = new Page(createChannel());

        List<RowExpression> projections = ImmutableList.of(
                rowExpression("json_extract(varchar0, '$.key1')"),
                rowExpression("json_extract(varchar0, '$.key2')"));
        MetadataManager metadata = createTestMetadataManager();
        PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata, 0);
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(metadata, pageFunctionCompiler);
        pageProcessor = expressionCompiler.compilePageProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.empty(), projections).get();

        sqlFunctionProperties = SqlFunctionProperties.builder()
                .setTimeZoneKey(UTC_KEY)
                .setLegacyTimestamp(true)
                .setSessionStartTime(0)
                .setSessionLocale(ENGLISH)
                .setSessionUser("user")
                .setCanonicalizedJsonExtract(isCanonicalizedJsonExtract)
                .build();
    }

    /**
     * Runs the compiled projections over the whole input page.
     *
     * @return every output produced by the page processor, materialized so JMH consumes the work
     */
    @Benchmark
    public List<Optional<Page>> computePage()
    {
        return ImmutableList.copyOf(
                pageProcessor.process(
                        sqlFunctionProperties,
                        new DriverYieldSignal(),
                        newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
                        inputPage));
    }

    /**
     * Parses, analyzes, translates and optimizes a SQL scalar expression over {@code symbolTypes}
     * into a {@link RowExpression} bound to {@code sourceLayout}.
     */
    private RowExpression rowExpression(String value)
    {
        TypeProvider types = TypeProvider.copyOf(symbolTypes);
        Expression expression = createExpression(TEST_SESSION, value, METADATA, types);
        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, types, expression, emptyMap(), WarningCollector.NOOP);
        RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager(), TEST_SESSION);
        RowExpressionOptimizer optimizer = new RowExpressionOptimizer(METADATA);
        return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession());
    }

    /**
     * Builds a block of {@code POSITION_COUNT} randomly generated JSON objects.
     */
    private static Block createChannel()
    {
        BlockBuilder blockBuilder = JSON.createBlockBuilder(null, POSITION_COUNT);
        for (int position = 0; position < POSITION_COUNT; position++) {
            appendJsonObject(blockBuilder);
        }
        return blockBuilder.build();
    }

    /**
     * Appends one JSON object with {@code ARRAY_SIZE} members to {@code blockBuilder}. The keys
     * "key1" and "key2" (the paths extracted by the benchmark) occur exactly once each, at distinct
     * random positions, and always map to a nested object. Every other member has a random 1-5
     * character key that is never "key1"/"key2", mapping to a nested object or a string with equal
     * probability.
     */
    private static void appendJsonObject(BlockBuilder blockBuilder)
    {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        // Pick two distinct member positions for the extracted keys (requires ARRAY_SIZE >= 2).
        int k1Index = random.nextInt(ARRAY_SIZE);
        int k2Index = random.nextInt(ARRAY_SIZE);
        while (k2Index == k1Index) {
            k2Index = random.nextInt(ARRAY_SIZE);
        }

        // DynamicSliceOutput is purely in-memory: nothing to close and no I/O failure to handle,
        // so no exception can silently drop a row from the page.
        SliceOutput jsonSlice = new DynamicSliceOutput(20 * ARRAY_SIZE);
        jsonSlice.appendByte('{');
        for (int i = 0; i < ARRAY_SIZE; i++) {
            String key;
            String value;
            if (i == k1Index) {
                key = "key1";
                value = generateNestedJsonValue();
            }
            else if (i == k2Index) {
                key = "key2";
                value = generateNestedJsonValue();
            }
            else {
                key = generateFillerKey();
                value = random.nextBoolean() ? generateNestedJsonValue() : generateRandomJsonValue();
            }
            jsonSlice.appendByte('"');
            jsonSlice.appendBytes(key.getBytes(UTF_8));
            jsonSlice.appendByte('"');
            jsonSlice.appendByte(':');
            jsonSlice.appendBytes(value.getBytes(UTF_8));
            if (i < ARRAY_SIZE - 1) {
                jsonSlice.appendByte(','); // separator between object members
            }
        }
        jsonSlice.appendByte('}');
        JSON.writeSlice(blockBuilder, jsonSlice.slice());
    }

    /**
     * Returns a random 1-5 character key that differs from the extracted keys, so each generated
     * object contains "key1" and "key2" exactly once.
     */
    private static String generateFillerKey()
    {
        String key = generateRandomKey(ThreadLocalRandom.current().nextInt(5) + 1);
        while (key.equals("key1") || key.equals("key2")) {
            key = generateRandomKey(ThreadLocalRandom.current().nextInt(5) + 1);
        }
        return key;
    }

    /**
     * Returns a quoted JSON string literal of 1-10 random characters drawn from {@code CHARACTERS}.
     */
    private static String generateRandomJsonValue()
    {
        int length = ThreadLocalRandom.current().nextInt(10) + 1;
        StringBuilder builder = new StringBuilder(length + 2);
        builder.append('"');
        for (int i = 0; i < length; i++) {
            char c = CHARACTERS.charAt(ThreadLocalRandom.current().nextInt(CHARACTERS.length()));
            // CHARACTERS currently contains neither character, but escape both so the output stays
            // valid JSON if the alphabet is ever extended.
            if (c == '"' || c == '\\') {
                builder.append('\\');
            }
            builder.append(c);
        }
        builder.append('"');
        return builder.toString();
    }

    /**
     * Returns a flat JSON object with 1-5 members, each a random 2-6 character key mapped to a
     * random string value.
     */
    private static String generateNestedJsonValue()
    {
        int size = ThreadLocalRandom.current().nextInt(5) + 1;
        StringBuilder builder = new StringBuilder(size * 10);
        builder.append('{');
        for (int i = 0; i < size; i++) {
            String key = generateRandomKey(ThreadLocalRandom.current().nextInt(5) + 2);
            builder.append("\"").append(key).append("\":");
            builder.append(generateRandomJsonValue());
            if (i < size - 1) {
                builder.append(",");
            }
        }
        builder.append('}');
        return builder.toString();
    }

    /**
     * Returns a string of {@code len} characters drawn uniformly from {@code CHARACTERS}.
     */
    private static String generateRandomKey(int len)
    {
        StringBuilder builder = new StringBuilder(len);
        for (int i = 0; i < len; i++) {
            builder.append(CHARACTERS.charAt(ThreadLocalRandom.current().nextInt(CHARACTERS.length())));
        }
        return builder.toString();
    }

    @Test
    public void verify()
    {
        verifyBenchmark();
    }

    /**
     * Runs this benchmark once for every {@code isCanonicalizedJsonExtract} value and checks that
     * every input row was projected.
     *
     * @throws IllegalStateException if the processor's output does not cover the whole input page
     */
    private static void verifyBenchmark()
    {
        for (boolean canonicalized : new boolean[] {true, false}) {
            BenchmarkJsonExtract benchmark = new BenchmarkJsonExtract();
            benchmark.isCanonicalizedJsonExtract = canonicalized;
            benchmark.setup();
            int outputPositions = 0;
            for (Optional<Page> page : benchmark.computePage()) {
                outputPositions += page.map(Page::getPositionCount).orElse(0);
            }
            if (outputPositions != POSITION_COUNT) {
                throw new IllegalStateException("Expected " + POSITION_COUNT + " output positions but got " + outputPositions
                        + " (isCanonicalizedJsonExtract=" + canonicalized + ")");
            }
        }
    }

    public static void main(String[] args)
            throws Throwable
    {
        // assure the benchmarks are valid before running
        verifyBenchmark();

        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(".*" + BenchmarkJsonExtract.class.getSimpleName() + ".*")
                .warmupMode(WarmupMode.BULK_INDI)
                .build();
        new Runner(options).run();
    }
}