CommonSubExpressionBenchmark.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.SequencePageBuilder;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.DictionaryBlock;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.index.PageRecordSet;
import com.facebook.presto.operator.project.CursorProcessor;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.RowExpressionOptimizer;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.testing.TestingSession;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.operator.scalar.FunctionAssertions.createExpression;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static java.util.Collections.emptyMap;
import static java.util.Locale.ENGLISH;
import static java.util.stream.Collectors.toList;

@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(2)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@BenchmarkMode(Mode.AverageTime)
public class CommonSubExpressionBenchmark
{
    private static final Map<String, Type> TYPE_MAP = ImmutableMap.of("bigint", BIGINT, "varchar", VARCHAR, "json", VARCHAR);
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final Metadata METADATA = createTestMetadataManager();
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private static final int POSITIONS = 1024;

    private PageProcessor pageProcessor;
    private CursorProcessor cursorProcessor;
    private Page inputPage;
    private Map<String, Type> symbolTypes;
    private Map<VariableReferenceExpression, Integer> sourceLayout;
    private List<Type> projectionTypes;

    @Param({"json", "bigint", "varchar"})
    String functionType;

    @Param({"true", "false"})
    boolean optimizeCommonSubExpression;

    @Param({"true", "false"})
    boolean dictionaryBlocks;

    @Setup
    public void setup()
    {
        Type type = TYPE_MAP.get(this.functionType);

        VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), type.getDisplayName().toLowerCase(ENGLISH) + "0", type);
        symbolTypes = ImmutableMap.of(variable.getName(), type);
        sourceLayout = ImmutableMap.of(variable, 0);
        inputPage = createPage(functionType, dictionaryBlocks);

        List<RowExpression> projections = getProjections(this.functionType);

        projectionTypes = projections.stream().map(RowExpression::getType).collect(toList());

        MetadataManager metadata = createTestMetadataManager();
        PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(metadata, 0);
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(metadata, pageFunctionCompiler);
        pageProcessor = expressionCompiler.compilePageProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.of(getFilter(functionType)), projections, optimizeCommonSubExpression, Optional.empty()).get();
        cursorProcessor = expressionCompiler.compileCursorProcessor(TEST_SESSION.getSqlFunctionProperties(), Optional.of(getFilter(functionType)), projections, "key", optimizeCommonSubExpression).get();
    }

    @Benchmark
    public List<Optional<Page>> computePage()
    {
        return ImmutableList.copyOf(
                pageProcessor.process(
                        null,
                        new DriverYieldSignal(),
                        newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
                        inputPage));
    }

    @Benchmark
    public Optional<Page> ComputeRecordSet()
    {
        List<Type> types = ImmutableList.of(TYPE_MAP.get(this.functionType));
        PageBuilder pageBuilder = new PageBuilder(projectionTypes);
        RecordSet recordSet = new PageRecordSet(types, inputPage);

        cursorProcessor.process(
                null,
                new DriverYieldSignal(),
                recordSet.cursor(),
                pageBuilder);

        return Optional.of(pageBuilder.build());
    }

    private RowExpression getFilter(String functionType)
    {
        if (functionType.equals("varchar")) {
            return rowExpression("cast(varchar0 as bigint) % 2 = 0");
        }
        if (functionType.equals("bigint")) {
            return rowExpression("bigint0 % 2 = 0");
        }
        if (functionType.equals("json")) {
            return rowExpression("rand() < 0.5");
        }
        throw new IllegalArgumentException("filter not supported for type : " + functionType);
    }

    private List<RowExpression> getProjections(String functionType)
    {
        ImmutableList.Builder<RowExpression> builder = ImmutableList.builder();
        if (functionType.equals("bigint")) {
            return ImmutableList.of(rowExpression("bigint0 + bigint0"), rowExpression("bigint0 + bigint0 + 5"));
        }
        else if (functionType.equals("varchar")) {
            return ImmutableList.of(rowExpression("concat(varchar0, varchar0)"), rowExpression("concat(concat(varchar0, varchar0), 'foo')"));
        }
        else if (functionType.equals("json")) {
            return ImmutableList.of(rowExpression("json_extract(json_parse(varchar0), '$.a')"), rowExpression("json_extract(json_parse(varchar0), '$.b')"));
        }
        throw new IllegalArgumentException();
    }

    private RowExpression rowExpression(String value)
    {
        Expression expression = createExpression(value, METADATA, TypeProvider.copyOf(symbolTypes));

        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, TypeProvider.copyOf(symbolTypes), expression, emptyMap(), WarningCollector.NOOP);
        RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, expressionTypes, sourceLayout, METADATA.getFunctionAndTypeManager(), TEST_SESSION);
        RowExpressionOptimizer optimizer = new RowExpressionOptimizer(METADATA);
        return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession());
    }

    private static Page createPage(String functionType, boolean dictionary)
    {
        List<Type> types = ImmutableList.of(TYPE_MAP.get(functionType));
        switch (functionType) {
            case "bigint":
            case "varchar":
                if (dictionary) {
                    return SequencePageBuilder.createSequencePageWithDictionaryBlocks(types, POSITIONS);
                }
                else {
                    return SequencePageBuilder.createSequencePage(types, POSITIONS);
                }
            case "json":
                if (dictionary) {
                    return createDictionaryStringJsonPage();
                }
                else {
                    return createStringJsonPage();
                }
            default:
                throw new IllegalArgumentException();
        }
    }

    private static Page createStringJsonPage()
    {
        BlockBuilder builder = VARCHAR.createBlockBuilder(null, POSITIONS);

        for (int i = 0; i < POSITIONS; i++) {
            VARCHAR.writeString(builder, "{\"a\": 1, \"b\": 2}");
        }
        return new Page(builder.build());
    }

    private static Page createDictionaryStringJsonPage()
    {
        int dictionarySize = POSITIONS / 5;
        BlockBuilder builder = VARCHAR.createBlockBuilder(null, dictionarySize);
        for (int i = 0; i < dictionarySize; i++) {
            VARCHAR.writeString(builder, "{\"a\": 1, \"b\": 2}");
        }
        int[] ids = new int[POSITIONS];
        for (int i = 0; i < POSITIONS; i++) {
            ids[i] = i % dictionarySize;
        }
        return new Page(new DictionaryBlock(builder.build(), ids));
    }

    public static void main(String[] args)
            throws RunnerException
    {
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(".*" + CommonSubExpressionBenchmark.class.getSimpleName() + ".*")
                .build();

        new Runner(options).run();
    }
}