/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.repartition;

import com.facebook.presto.CompressionCodec;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.StateMachine;
import com.facebook.presto.execution.buffer.BufferState;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.buffer.PagesSerdeFactory;
import com.facebook.presto.execution.buffer.PartitionedOutputBuffer;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.memory.context.SimpleLocalMemoryContext;
import com.facebook.presto.operator.BucketPartitionFunction;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.PageAssertions;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.operator.PrecomputedHashGenerator;
import com.facebook.presto.operator.exchange.LocalPartitionGenerator;
import com.facebook.presto.operator.repartition.OptimizedPartitionedOutputOperator.OptimizedPartitionedOutputFactory;
import com.facebook.presto.operator.repartition.PartitionedOutputOperator.PartitionedOutputFactory;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.sql.planner.OutputPartitioning;
import com.facebook.presto.testing.TestingTaskContext;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.presto.block.BlockAssertions.createMapType;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RowType.withDefaultFieldNames;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.execution.buffer.BufferState.OPEN;
import static com.facebook.presto.execution.buffer.BufferState.TERMINAL_BUFFER_STATES;
import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED;
import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.operator.PageAssertions.createDictionaryPageWithRandomData;
import static com.facebook.presto.operator.PageAssertions.createRlePageWithRandomData;
import static com.facebook.presto.operator.PageAssertions.updateBlockTypesWithHashBlockAndNullBlock;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SystemPartitionFunction.HASH;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static io.airlift.units.DataSize.Unit.BYTE;
import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.util.Collections.nCopies;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

@State(Scope.Thread)
@OutputTimeUnit(MILLISECONDS)
@Fork(3)
@Warmup(iterations = 10, time = 500, timeUnit = MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
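// JMH benchmark comparing PartitionedOutputOperator against OptimizedPartitionedOutputOperator.
// Each combination of type, channelCount, hasNull and codec repartitions randomly generated pages
// across 256 partitions and reports the average time per invocation in milliseconds.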
public class BenchmarkPartitionedOutputOperator
{
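    // Pushes pageCount copies of the generated page through a fresh PartitionedOutputOperator and finishes it.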
    @Benchmark
    public void addPage(BenchmarkData data)
    {
        PartitionedOutputOperator operator = data.createPartitionedOutputOperator();
        for (int i = 0; i < data.pageCount; i++) {
            operator.addInput(data.dataPage);
        }
        operator.finish();
    }

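    // Same workload as addPage, but through OptimizedPartitionedOutputOperator.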
    @Benchmark
    public void optimizedAddPage(BenchmarkData data)
    {
        OptimizedPartitionedOutputOperator operator = data.createOptimizedPartitionedOutputOperator();
        for (int i = 0; i < data.pageCount; i++) {
            operator.addInput(data.dataPage);
        }
        operator.finish();
    }

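    // TestNG smoke test: runs a single addPage invocation with the default parameters
    // so wiring problems surface without launching JMH.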
    @Test
    public void verifyAddPage()
    {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        new BenchmarkPartitionedOutputOperator().addPage(data);
    }

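    // Same smoke test for the optimized operator.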
    @Test
    public void verifyOptimizedAddPage()
    {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        new BenchmarkPartitionedOutputOperator().optimizedAddPage(data);
    }

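    // JMH state: holds the benchmark parameters, generates the input page, and builds the operators,
    // output buffers, and driver context that the benchmark methods consume.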
    @State(Scope.Thread)
    public static class BenchmarkData
    {
        private static final int PARTITION_COUNT = 256;
        private static final int POSITION_COUNT = 8192;
        private static final DataSize MAX_MEMORY = new DataSize(4, GIGABYTE);
        private static final DataSize MAX_PARTITION_BUFFER_SIZE = new DataSize(256, MEGABYTE);
        private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-EXECUTOR-%s"));
        private static final ScheduledExecutorService SCHEDULER = newScheduledThreadPool(1, daemonThreadsNamed("test-%s"));

        @SuppressWarnings("unused")
        @Param({"NONE", "LZ4"})
        private String codec = "NONE";

        @Param({"1", "2"})
        private int channelCount = 1;

        @Param({
                "BIGINT",
                "DICTIONARY(BIGINT)",
                "RLE(BIGINT)",
                "LONG_DECIMAL",
                "INTEGER",
                "SMALLINT",
                "BOOLEAN",
                "VARCHAR",
                "ARRAY(BIGINT)",
                "ARRAY(VARCHAR)",
                "ARRAY(ARRAY(BIGINT))",
                "MAP(BIGINT,BIGINT)",
                "MAP(BIGINT,MAP(BIGINT,BIGINT))",
                "ROW(BIGINT,BIGINT)",
                "ROW(ARRAY(BIGINT),ARRAY(BIGINT))"
        })
        private String type = "BIGINT";

        @SuppressWarnings("unused")
        @Param({"true", "false"})
        private boolean hasNull;

        private List<Type> types;
        private int pageCount;
        private Page dataPage;

        @Setup
        public void setup()
        {
            createPages(type);
        }

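        // Builds a POSITION_COUNT-row page of random data for the requested type (with roughly 20% nulls
        // when hasNull is set) and picks a per-type pageCount: more pages for flat types, fewer for the
        // more expensive nested types.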
        private void createPages(String inputType)
        {
            float primitiveNullRate = 0.0f;
            float nestedNullRate = 0.0f;

            if (hasNull) {
                primitiveNullRate = 0.2f;
                nestedNullRate = 0.2f;
            }
            switch (inputType) {
                case "BIGINT":
                    types = nCopies(channelCount, BIGINT);
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "DICTIONARY(BIGINT)":
                    types = nCopies(channelCount, BIGINT);
                    dataPage = createDictionaryPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 3000;
                    break;
                case "RLE(BIGINT)":
                    types = nCopies(channelCount, BIGINT);
                    dataPage = createRlePageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 3000;
                    break;
                case "LONG_DECIMAL":
                    types = nCopies(channelCount, createDecimalType(MAX_SHORT_PRECISION + 1));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "INTEGER":
                    types = nCopies(channelCount, INTEGER);
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "SMALLINT":
                    types = nCopies(channelCount, SMALLINT);
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "BOOLEAN":
                    types = nCopies(channelCount, BOOLEAN);
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "VARCHAR":
                    types = nCopies(channelCount, VARCHAR);
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 5000;
                    break;
                case "ARRAY(BIGINT)":
                    types = nCopies(channelCount, new ArrayType(BIGINT));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "ARRAY(VARCHAR)":
                    types = nCopies(channelCount, new ArrayType(VARCHAR));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "ARRAY(ARRAY(BIGINT))":
                    types = nCopies(channelCount, new ArrayType(new ArrayType(BIGINT)));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "MAP(BIGINT,BIGINT)":
                    types = nCopies(channelCount, createMapType(BIGINT, BIGINT));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "MAP(BIGINT,MAP(BIGINT,BIGINT))":
                    types = nCopies(channelCount, createMapType(BIGINT, createMapType(BIGINT, BIGINT)));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "ROW(BIGINT,BIGINT)":
                    types = nCopies(channelCount, withDefaultFieldNames(ImmutableList.of(BIGINT, BIGINT)));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;
                case "ROW(ARRAY(BIGINT),ARRAY(BIGINT))":
                    types = nCopies(channelCount, withDefaultFieldNames(ImmutableList.of(new ArrayType(BIGINT), new ArrayType(BIGINT))));
                    dataPage = PageAssertions.createPageWithRandomData(types, POSITION_COUNT, primitiveNullRate, nestedNullRate);
                    pageCount = 1000;
                    break;

                default:
                    throw new UnsupportedOperationException("Unsupported dataType");
            }
            // The dataPage was built with a pre-computed hash block added at channel 0, so the types list needs to be updated to match
            types = updateBlockTypesWithHashBlockAndNullBlock(types, true, false);
        }

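        // Maps the codec parameter to a CompressionCodec; more codecs are handled here than the @Param list exercises.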
        private CompressionCodec getCompressionCodec(String compressionCodec)
        {
            switch (compressionCodec) {
                case "NONE":
                    return CompressionCodec.NONE;
                case "SNAPPY":
                    return CompressionCodec.SNAPPY;
                case "LZ4":
                    return CompressionCodec.LZ4;
                case "ZLIB":
                    return CompressionCodec.ZLIB;
                case "GZIP":
                    return CompressionCodec.GZIP;
                case "ZSTD":
                    return CompressionCodec.ZSTD;
                default:
                    throw new UnsupportedOperationException("Unsupported compression codec: " + compressionCodec);
            }
        }

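        // Creates a PARTITIONED output buffer with one buffer per partition and an effectively unlimited
        // size, so the operator under benchmark never blocks on a full output buffer.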
        private PartitionedOutputBuffer createPartitionedOutputBuffer()
        {
            OutputBuffers buffers = createInitialEmptyOutputBuffers(PARTITIONED);
            for (int partition = 0; partition < PARTITION_COUNT; partition++) {
                buffers = buffers.withBuffer(new OutputBuffers.OutputBufferId(partition), partition);
            }
            PartitionedOutputBuffer buffer = createPartitionedBuffer(
                    buffers.withNoMoreBufferIds(),
                    new DataSize(Long.MAX_VALUE, BYTE)); // don't let output buffer block
            buffer.registerLifespanCompletionCallback(ignore -> {});

            return buffer;
        }

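        // Builds an OptimizedPartitionedOutputOperator that hash-partitions on the pre-computed hash block
        // at channel 0 across PARTITION_COUNT partitions.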
        private OptimizedPartitionedOutputOperator createOptimizedPartitionedOutputOperator()
        {
            PartitionFunction partitionFunction = new BucketPartitionFunction(
                    HASH.createBucketFunction(ImmutableList.of(BIGINT), true, PARTITION_COUNT),
                    IntStream.range(0, PARTITION_COUNT).toArray());
            OutputPartitioning outputPartitioning = createOutputPartitioning(partitionFunction);

            PagesSerdeFactory serdeFactory = new PagesSerdeFactory(new BlockEncodingManager(), getCompressionCodec(codec));
            PartitionedOutputBuffer buffer = createPartitionedOutputBuffer();

            OptimizedPartitionedOutputFactory operatorFactory = new OptimizedPartitionedOutputFactory(buffer, MAX_PARTITION_BUFFER_SIZE);

            return (OptimizedPartitionedOutputOperator) operatorFactory
                    .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), Optional.of(outputPartitioning), serdeFactory)
                    .createOperator(createDriverContext());
        }

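        // Builds the baseline PartitionedOutputOperator, partitioning on the pre-computed hash at channel 0
        // via LocalPartitionGenerator.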
        private PartitionedOutputOperator createPartitionedOutputOperator()
        {
            PartitionFunction partitionFunction = new LocalPartitionGenerator(new PrecomputedHashGenerator(0), PARTITION_COUNT);
            OutputPartitioning outputPartitioning = createOutputPartitioning(partitionFunction);

            PagesSerdeFactory serdeFactory = new PagesSerdeFactory(new BlockEncodingManager(), getCompressionCodec(codec));
            PartitionedOutputBuffer buffer = createPartitionedOutputBuffer();

            PartitionedOutputFactory operatorFactory = new PartitionedOutputFactory(buffer, MAX_PARTITION_BUFFER_SIZE);

            return (PartitionedOutputOperator) operatorFactory
                    .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), Optional.of(outputPartitioning), serdeFactory)
                    .createOperator(createDriverContext());
        }

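        // Partitions on channel 0, which holds the pre-computed hash block, with no constant partition values.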
        private OutputPartitioning createOutputPartitioning(PartitionFunction partitionFunction)
        {
            return new OutputPartitioning(
                    partitionFunction,
                    ImmutableList.of(0),
                    ImmutableList.of(Optional.empty(), Optional.empty()),
                    false,
                    OptionalInt.empty());
        }

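        // Driver context backed by a TestingTaskContext with a 4GB memory pool, so memory limits do not
        // interfere with the benchmark.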
        private DriverContext createDriverContext()
        {
            Session testSession = testSessionBuilder()
                    .setCatalog("tpch")
                    .setSchema(TINY_SCHEMA_NAME)
                    .build();

            return TestingTaskContext.builder(EXECUTOR, SCHEDULER, testSession)
                    .setMemoryPoolSize(MAX_MEMORY)
                    .setQueryMaxTotalMemory(MAX_MEMORY)
                    .build()
                    .addPipelineContext(0, true, true, false)
                    .addDriverContext();
        }

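        // Wraps the buffer configuration in a TestingPartitionedOutputBuffer that starts in the OPEN state.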
        private TestingPartitionedOutputBuffer createPartitionedBuffer(OutputBuffers buffers, DataSize dataSize)
        {
            return new TestingPartitionedOutputBuffer(
                    "task-instance-id",
                    new StateMachine<>("bufferState", SCHEDULER, OPEN, TERMINAL_BUFFER_STATES),
                    buffers,
                    dataSize,
                    () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"),
                    SCHEDULER);
        }

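        // PartitionedOutputBuffer test double whose enqueue discards serialized pages, so the benchmark
        // measures partitioning and serialization without retaining the output.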
        private static class TestingPartitionedOutputBuffer
                extends PartitionedOutputBuffer
        {
            public TestingPartitionedOutputBuffer(
                    String taskInstanceId,
                    StateMachine<BufferState> state,
                    OutputBuffers outputBuffers,
                    DataSize maxBufferSize,
                    Supplier<LocalMemoryContext> systemMemoryContextSupplier,
                    Executor notificationExecutor)
            {
                super(taskInstanceId, state, outputBuffers, maxBufferSize.toBytes(), systemMemoryContextSupplier, notificationExecutor);
            }

            // No-op enqueue: drop the serialized pages instead of buffering them, so the benchmark does not run out of memory
            @Override
            public void enqueue(Lifespan lifespan, int partitionNumber, List<SerializedPage> pages)
            {
            }
        }
    }

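    // Entry point for running the benchmark directly; forks JVMs with a 10GB heap. To run only a subset of the
    // parameter space, the parameters can be narrowed on the OptionsBuilder (for example, param("type", "BIGINT")).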
    public static void main(String[] args)
            throws RunnerException
    {
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .jvmArgs("-Xmx10g")
                .include(".*" + BenchmarkPartitionedOutputOperator.class.getSimpleName() + ".*")
                .build();
        new Runner(options).run();
    }
}