BenchmarkHashAndSegmentedAggregationOperators.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.RowPagesBuilder;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spiller.SpillerFactory;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.testing.TestingTaskContext;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import org.testng.annotations.Test;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock;
import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators.Context.ROWS_PER_PAGE;
import static com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators.Context.TOTAL_PAGES;
import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.airlift.units.DataSize.succinctBytes;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.openjdk.jmh.annotations.Mode.AverageTime;
import static org.openjdk.jmh.annotations.Scope.Thread;
import static org.testng.Assert.assertEquals;

/**
 * JMH benchmark that pushes the same pre-built input pages through an operator created by
 * {@link HashAggregationOperatorFactory} in two configurations, selected by {@link Context#operatorType}:
 * {@code "hash"} and {@code "segmented"}.
 * <p>
 * The {@code verify*} TestNG methods run a single pass for every parameter combination and check the
 * number of produced rows, so a broken benchmark setup is caught by the regular test run.
 */
@State(Thread)
@OutputTimeUnit(MILLISECONDS)
@BenchmarkMode(AverageTime)
@Fork(3)
@Warmup(iterations = 5)
@Measurement(iterations = 10, time = 2, timeUnit = SECONDS)
public class BenchmarkHashAndSegmentedAggregationOperators
{
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = METADATA.getFunctionAndTypeManager();

    private static final JavaAggregationFunctionImplementation LONG_SUM = FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(
            FUNCTION_AND_TYPE_MANAGER.lookupFunction("sum", fromTypes(BIGINT)));
    private static final JavaAggregationFunctionImplementation COUNT = FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(
            FUNCTION_AND_TYPE_MANAGER.lookupFunction("count", ImmutableList.of()));

    /**
     * Benchmark state: the input pages, the operator factory under test and the executors used to
     * create task contexts. Populated by {@link #setup()} and released by {@link #cleanup()}.
     */
    @State(Thread)
    public static class Context
    {
        public static final int TOTAL_PAGES = 100;
        public static final int ROWS_PER_PAGE = 1000;

        /** Number of consecutive input rows that share the same value in channel 0. */
        @Param({"1", "10", "800", "100000"})
        public int rowsPerSegment;

        /** {@code "segmented"} (case-insensitive) selects segmented aggregation; any other value selects hash. */
        @Param({"segmented", "hash"})
        public String operatorType;

        private ExecutorService executor;
        private ScheduledExecutorService scheduledExecutor;
        private OperatorFactory operatorFactory;
        private List<Page> pages;
        // Row count that verify() expects the operator to produce for the generated input.
        private int outputRows;

        /**
         * Creates the executors, generates {@link #TOTAL_PAGES} pages of {@link #ROWS_PER_PAGE} rows each
         * and builds the operator factory for the configured {@link #operatorType}.
         * <p>
         * Channel 0 holds the segment id of the row (global row index divided by {@link #rowsPerSegment})
         * as a varchar, channel 1 is built from the page index and channel 2 from a sequence starting at 0.
         */
        @Setup
        public void setup()
        {
            executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
            scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
            outputRows = 0;

            boolean segmentedAggregation = operatorType.equalsIgnoreCase("segmented");
            RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(true, ImmutableList.of(0, 1), VARCHAR, BIGINT, BIGINT);
            for (int i = 0; i < TOTAL_PAGES; i++) {
                int firstRow = i * ROWS_PER_PAGE;
                BlockBuilder sortedBlockBuilder = VARCHAR.createBlockBuilder(null, ROWS_PER_PAGE);
                for (int j = 0; j < ROWS_PER_PAGE; j++) {
                    int currentSegment = (firstRow + j) / rowsPerSegment;
                    VARCHAR.writeString(sortedBlockBuilder, String.valueOf(currentSegment));
                }

                // Channel 1 is constant within a page, so the page contributes one (channel 0, channel 1)
                // combination per distinct segment it touches. Count the segments exactly rather than
                // assuming the page starts on a segment boundary: a page may straddle an extra segment
                // when rowsPerSegment does not divide ROWS_PER_PAGE evenly.
                int firstSegment = firstRow / rowsPerSegment;
                int lastSegment = (firstRow + ROWS_PER_PAGE - 1) / rowsPerSegment;
                outputRows += lastSegment - firstSegment + 1;

                pagesBuilder.addBlocksPage(sortedBlockBuilder, createLongRepeatBlock(i, ROWS_PER_PAGE), createLongSequenceBlock(0, ROWS_PER_PAGE));
            }

            pages = pagesBuilder.build();
            operatorFactory = createHashAggregationOperatorFactory(pagesBuilder.getHashChannel(), segmentedAggregation);
        }

        /**
         * Shuts down the executors created by {@link #setup()} so that repeated trials and test runs
         * do not accumulate idle thread pools. Safe to call when {@link #setup()} did not run.
         */
        @TearDown
        public void cleanup()
        {
            if (executor != null) {
                executor.shutdownNow();
                executor = null;
            }
            if (scheduledExecutor != null) {
                scheduledExecutor.shutdownNow();
                scheduledExecutor = null;
            }
        }

        /**
         * Builds the factory for the operator under test.
         *
         * @param hashChannel hash channel reported by the pages builder, passed through to the factory
         * @param segmentedAggregation when {@code true}, channel 0 is passed in the factory's fifth argument
         * (presumably the pre-grouped channels - confirm against the factory signature); otherwise that list is empty
         */
        private OperatorFactory createHashAggregationOperatorFactory(Optional<Integer> hashChannel, boolean segmentedAggregation)
        {
            JoinCompiler joinCompiler = new JoinCompiler(METADATA);
            // The benchmark supplies a spiller factory that never creates a spiller.
            SpillerFactory spillerFactory = (types, localSpillContext, aggregatedMemoryContext) -> null;

            return new HashAggregationOperatorFactory(
                    0,
                    new PlanNodeId("test"),
                    ImmutableList.of(VARCHAR, BIGINT),
                    ImmutableList.of(0, 1),
                    segmentedAggregation ? ImmutableList.of(0) : ImmutableList.of(),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    false,
                    ImmutableList.of(generateAccumulatorFactory(COUNT, ImmutableList.of(2), Optional.empty()),
                            generateAccumulatorFactory(LONG_SUM, ImmutableList.of(2), Optional.empty())),
                    hashChannel,
                    Optional.empty(),
                    100_000,
                    Optional.of(new DataSize(16, MEGABYTE)),
                    false,
                    Optional.empty(),
                    succinctBytes(8),
                    succinctBytes(Integer.MAX_VALUE),
                    spillerFactory,
                    joinCompiler,
                    false);
        }

        /**
         * Creates a fresh task context backed by this state's executors.
         */
        public TaskContext createTaskContext()
        {
            return TestingTaskContext.createTaskContext(executor, scheduledExecutor, TEST_SESSION, new DataSize(2, GIGABYTE));
        }

        public OperatorFactory getOperatorFactory()
        {
            return operatorFactory;
        }

        public List<Page> getPages()
        {
            return pages;
        }
    }

    /**
     * Feeds every input page of {@code context} into a newly created operator and collects its output.
     *
     * @param context prepared benchmark state
     * @return the non-null pages returned by the operator, in the order they were produced
     */
    @Benchmark
    public List<Page> benchmark(Context context)
    {
        DriverContext driverContext = context.createTaskContext().addPipelineContext(0, true, true, false).addDriverContext();
        Operator operator = context.getOperatorFactory().createOperator(driverContext);

        Iterator<Page> input = context.getPages().iterator();
        ImmutableList.Builder<Page> outputPages = ImmutableList.builder();

        boolean finishing = false;
        // The loop count is capped so that a misbehaving operator cannot spin forever.
        for (int loops = 0; !operator.isFinished() && loops < 1_000_000; loops++) {
            if (operator.needsInput()) {
                if (input.hasNext()) {
                    Page inputPage = input.next();
                    operator.addInput(inputPage);
                }
                else if (!finishing) {
                    // Input is exhausted: signal the operator exactly once.
                    operator.finish();
                    finishing = true;
                }
            }

            Page outputPage = operator.getOutput();
            if (outputPage != null) {
                outputPages.add(outputPage);
            }
        }

        return outputPages.build();
    }

    @Test
    public void verifyHash()
    {
        verify(1, "hash");
        verify(10, "hash");
        verify(800, "hash");
        verify(100000, "hash");
    }

    @Test
    public void verifySegmented()
    {
        verify(1, "segmented");
        verify(10, "segmented");
        verify(800, "segmented");
        verify(100000, "segmented");
    }

    /**
     * Runs one benchmark pass for the given parameters and checks the shape of the generated input
     * and the total number of output rows.
     *
     * @param rowsPerSegment value assigned to {@link Context#rowsPerSegment}
     * @param operatorType value assigned to {@link Context#operatorType}
     */
    private void verify(int rowsPerSegment, String operatorType)
    {
        Context context = new Context();
        context.operatorType = operatorType;
        context.rowsPerSegment = rowsPerSegment;
        try {
            context.setup();

            // TestNG's assertEquals takes (actual, expected).
            assertEquals(context.getPages().size(), TOTAL_PAGES);
            for (int i = 0; i < TOTAL_PAGES; i++) {
                assertEquals(context.getPages().get(i).getPositionCount(), ROWS_PER_PAGE);
            }

            List<Page> outputPages = benchmark(context);
            assertEquals(outputPages.stream().mapToInt(Page::getPositionCount).sum(), context.outputRows);
        }
        finally {
            // Release the thread pools even when an assertion fails.
            context.cleanup();
        }
    }

    /**
     * Runs every benchmark in this class through the JMH runner.
     */
    public static void main(String[] args)
            throws RunnerException
    {
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(".*" + BenchmarkHashAndSegmentedAggregationOperators.class.getSimpleName() + ".*")
                .build();

        new Runner(options).run();
    }
}