PartitionedOutputOperator.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.repartition;

import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.buffer.OutputBuffer;
import com.facebook.presto.execution.buffer.PagesSerdeFactory;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.Operator;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.OutputFactory;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.sql.planner.OutputPartitioning;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;

import javax.annotation.Nullable;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.facebook.presto.common.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
import static com.facebook.presto.execution.buffer.PageSplitterUtil.splitPage;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.util.Objects.requireNonNull;

public class PartitionedOutputOperator
        implements Operator
{
    /**
     * {@link OutputFactory} that creates {@link PartitionedOutputOperatorFactory} instances, all of which
     * receive the same output buffer and the same memory budget.
     */
    public static class PartitionedOutputFactory
            implements OutputFactory
    {
        private final OutputBuffer outputBuffer;
        private final DataSize maxMemory;

        /**
         * @param outputBuffer buffer passed on to every operator factory created here; must not be null
         * @param maxMemory memory budget passed on to every operator factory created here; must not be null
         */
        public PartitionedOutputFactory(OutputBuffer outputBuffer, DataSize maxMemory)
        {
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.maxMemory = requireNonNull(maxMemory, "maxMemory is null");
        }

        /**
         * Builds a partitioned output operator factory from the supplied partitioning description.
         *
         * @throws IllegalArgumentException if {@code outputPartitioning} is empty
         */
        @Override
        public OperatorFactory createOutputOperator(
                int operatorId,
                PlanNodeId planNodeId,
                List<Type> types,
                Function<Page, Page> pagePreprocessor,
                Optional<OutputPartitioning> outputPartitioning,
                PagesSerdeFactory serdeFactory)
        {
            checkArgument(outputPartitioning.isPresent(), "outputPartitioning is not present");

            // Unwrap once; every partitioning-related argument below is read from this value
            OutputPartitioning partitioning = outputPartitioning.get();
            return new PartitionedOutputOperatorFactory(
                    operatorId,
                    planNodeId,
                    types,
                    pagePreprocessor,
                    partitioning.getPartitionFunction(),
                    partitioning.getPartitionChannels(),
                    partitioning.getPartitionConstants(),
                    partitioning.isReplicateNullsAndAny(),
                    partitioning.getNullChannel(),
                    outputBuffer,
                    serdeFactory,
                    maxMemory);
        }
    }

    /**
     * {@link OperatorFactory} that holds the full configuration of a {@link PartitionedOutputOperator}
     * and creates one operator per driver context.
     */
    public static class PartitionedOutputOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final List<Type> sourceTypes;
        private final Function<Page, Page> pagePreprocessor;
        private final PartitionFunction partitionFunction;
        private final List<Integer> partitionChannels;
        private final List<Optional<ConstantExpression>> partitionConstants;
        private final boolean replicatesAnyRow;
        private final OptionalInt nullChannel;
        private final OutputBuffer outputBuffer;
        private final PagesSerdeFactory serdeFactory;
        private final DataSize maxMemory;

        /**
         * Stores the operator configuration. Every reference argument must be non-null; the values are
         * kept as given (no copies are made).
         */
        public PartitionedOutputOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                List<Type> sourceTypes,
                Function<Page, Page> pagePreprocessor,
                PartitionFunction partitionFunction,
                List<Integer> partitionChannels,
                List<Optional<ConstantExpression>> partitionConstants,
                boolean replicatesAnyRow,
                OptionalInt nullChannel,
                OutputBuffer outputBuffer,
                PagesSerdeFactory serdeFactory,
                DataSize maxMemory)
        {
            this.operatorId = operatorId;
            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
            this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null");
            this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null");
            this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null");
            this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null");
            this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null");
            this.replicatesAnyRow = replicatesAnyRow;
            this.nullChannel = requireNonNull(nullChannel, "nullChannel is null");
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null");
            this.maxMemory = requireNonNull(maxMemory, "maxMemory is null");
        }

        /**
         * Registers an operator context on {@code driverContext} under this operator's simple class name
         * and returns a new operator bound to it.
         */
        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            String operatorType = PartitionedOutputOperator.class.getSimpleName();
            OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, operatorType);
            return new PartitionedOutputOperator(
                    context,
                    sourceTypes,
                    pagePreprocessor,
                    partitionFunction,
                    partitionChannels,
                    partitionConstants,
                    replicatesAnyRow,
                    nullChannel,
                    outputBuffer,
                    serdeFactory,
                    maxMemory);
        }

        /**
         * No-op: this factory has nothing to do when operator creation ends.
         */
        @Override
        public void noMoreOperators()
        {
        }

        /**
         * Returns a new factory carrying exactly the same configuration values as this one.
         */
        @Override
        public OperatorFactory duplicate()
        {
            return new PartitionedOutputOperatorFactory(
                    operatorId,
                    planNodeId,
                    sourceTypes,
                    pagePreprocessor,
                    partitionFunction,
                    partitionChannels,
                    partitionConstants,
                    replicatesAnyRow,
                    nullChannel,
                    outputBuffer,
                    serdeFactory,
                    maxMemory);
        }
    }

    private final OperatorContext operatorContext;
    // Applied to each non-empty input page before it is partitioned
    private final Function<Page, Page> pagePreprocessor;
    // Despite the field name, this is the PagePartitioner that buffers rows per partition and flushes them
    private final PagePartitioner partitionFunction;
    // Last blocked future handed out; refreshed from the partitioner once it completes
    private ListenableFuture<?> isBlocked = NOT_BLOCKED;
    private boolean finished;

    /**
     * Creates the operator and its internal {@link PagePartitioner}, then installs the partitioner's
     * info supplier on the operator context.
     *
     * @throws NullPointerException if {@code operatorContext} or {@code pagePreprocessor} is null
     */
    public PartitionedOutputOperator(
            OperatorContext operatorContext,
            List<Type> sourceTypes,
            Function<Page, Page> pagePreprocessor,
            PartitionFunction partitionFunction,
            List<Integer> partitionChannels,
            List<Optional<ConstantExpression>> partitionConstants,
            boolean replicatesAnyRow,
            OptionalInt nullChannel,
            OutputBuffer outputBuffer,
            PagesSerdeFactory serdeFactory,
            DataSize maxMemory)
    {
        this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
        this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null");

        PagePartitioner partitioner = new PagePartitioner(
                partitionFunction,
                partitionChannels,
                partitionConstants,
                replicatesAnyRow,
                nullChannel,
                outputBuffer,
                serdeFactory,
                sourceTypes,
                maxMemory,
                operatorContext);
        this.partitionFunction = partitioner;

        operatorContext.setInfoSupplier(partitioner.getPartitionedOutputInfoSupplier());
    }

    /**
     * Returns the operator context this operator was constructed with.
     */
    @Override
    public OperatorContext getOperatorContext()
    {
        return this.operatorContext;
    }

    /**
     * Marks the operator as finished and then force-flushes the partitioner, so every non-empty
     * page builder is emitted regardless of whether it is full.
     */
    @Override
    public void finish()
    {
        // The flag is set first, so it stays set even if the flush below throws
        this.finished = true;
        this.partitionFunction.flush(true);
    }

    /**
     * Returns true once {@link #finish()} has been called and the operator is no longer blocked.
     */
    @Override
    public boolean isFinished()
    {
        if (!finished) {
            return false;
        }
        return isBlocked().isDone();
    }

    /**
     * Returns a future that is done when the operator is not blocked.
     *
     * While the previously returned future is still pending it is returned as-is, which avoids
     * re-synchronizing on the output buffer. Once it has completed, the partitioner is asked again;
     * an already-completed answer is normalized to {@code NOT_BLOCKED}.
     */
    @Override
    public ListenableFuture<?> isBlocked()
    {
        if (!isBlocked.isDone()) {
            return isBlocked;
        }

        ListenableFuture<?> bufferFull = partitionFunction.isFull();
        if (bufferFull.isDone()) {
            isBlocked = NOT_BLOCKED;
        }
        else {
            isBlocked = bufferFull;
        }
        return isBlocked;
    }

    /**
     * Returns true while the operator has not been finished and is not currently blocked.
     */
    @Override
    public boolean needsInput()
    {
        if (finished) {
            return false;
        }
        return isBlocked().isDone();
    }

    /**
     * Runs the page preprocessor on {@code page} and hands the result to the partitioner.
     * Pages with zero positions are ignored without being preprocessed.
     *
     * @throws NullPointerException if {@code page} is null
     */
    @Override
    public void addInput(Page page)
    {
        requireNonNull(page, "page is null");

        if (page.getPositionCount() != 0) {
            Page preprocessed = pagePreprocessor.apply(page);
            partitionFunction.partitionPage(preprocessed);
        }
    }

    /**
     * Always returns {@code null}: this operator hands its rows to the output buffer via the
     * partitioner instead of returning pages to the caller.
     */
    @Override
    public Page getOutput()
    {
        return null;
    }

    /**
     * Resets the partitioner's system memory reservation to zero.
     */
    @Override
    public void close()
    {
        this.partitionFunction.zeroMemoryContext();
    }

    private static class PagePartitioner
    {
        private final OutputBuffer outputBuffer;
        private final Type[] sourceTypes;
        private final PartitionFunction partitionFunction;
        private final int[] partitionChannels;
        @Nullable
        private final Block[] partitionConstantBlocks; // when null, no constants are present. Only non-null elements are constants
        private final PagesSerde serde;
        private final PageBuilder[] pageBuilders; // one builder per partition, indexed by partition number
        private final boolean replicatesAnyRow;
        private final int nullChannel; // when >= 0, send the position to every partition if this channel is null
        private final AtomicLong rowsAdded = new AtomicLong();
        private final AtomicLong pagesAdded = new AtomicLong();
        private boolean hasAnyRowBeenReplicated;
        private final OperatorContext operatorContext;
        private final LocalMemoryContext systemMemoryContext;

        /**
         * Creates a partitioner with one {@link PageBuilder} per partition of {@code partitionFunction}.
         *
         * Each builder's maximum page size is {@code maxMemory / partitionCount}, capped at
         * {@code DEFAULT_MAX_PAGE_SIZE_IN_BYTES} and never less than 1 byte.
         *
         * @param partitionChannels input channel per partition function argument; a negative entry means
         *        the argument comes from the constant at the same index of {@code partitionConstants}
         * @param nullChannel when present, rows whose value in this channel is null go to every partition
         * @param maxMemory total memory budget divided across the partition page builders
         * @throws NullPointerException if any reference argument is null
         * @throws IllegalArgumentException if a negative partition channel has no matching constant, or
         *         if the partition function reports a non-positive partition count
         */
        public PagePartitioner(
                PartitionFunction partitionFunction,
                List<Integer> partitionChannels,
                List<Optional<ConstantExpression>> partitionConstants,
                boolean replicatesAnyRow,
                OptionalInt nullChannel,
                OutputBuffer outputBuffer,
                PagesSerdeFactory serdeFactory,
                List<Type> sourceTypes,
                DataSize maxMemory,
                OperatorContext operatorContext)
        {
            this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null");
            this.partitionChannels = Ints.toArray(requireNonNull(partitionChannels, "partitionChannels is null"));
            Block[] partitionConstantBlocks = requireNonNull(partitionConstants, "partitionConstants is null").stream()
                    .map(constant -> constant.map(ConstantExpression::getValueBlock).orElse(null))
                    .toArray(Block[]::new);
            // Keep the array only when at least one constant exists, so the common no-constant case
            // can be detected with a single null check
            if (Arrays.stream(partitionConstantBlocks).anyMatch(Objects::nonNull)) {
                this.partitionConstantBlocks = partitionConstantBlocks;
            }
            else {
                this.partitionConstantBlocks = null;
            }
            this.replicatesAnyRow = replicatesAnyRow;
            this.nullChannel = requireNonNull(nullChannel, "nullChannel is null").orElse(-1);
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]);
            this.serde = requireNonNull(serdeFactory, "serdeFactory is null").createPagesSerde();
            this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
            this.systemMemoryContext = operatorContext.localSystemMemoryContext();
            // NOTE: pageBuilders has not been assigned yet at this point, so this initial reservation
            // only covers the serde; the builders are accounted for on the first partitionPage call
            this.systemMemoryContext.setBytes(getRetainedSizeInBytes());

            // Ensure partition channels align with constant arguments provided
            for (int i = 0; i < this.partitionChannels.length; i++) {
                if (this.partitionChannels[i] < 0) {
                    checkArgument(this.partitionConstantBlocks != null && this.partitionConstantBlocks[i] != null,
                            "Expected constant for partitioning channel %s, but none was found", i);
                }
            }

            int partitionCount = partitionFunction.getPartitionCount();
            checkArgument(partitionCount > 0, "partitionCount must be positive, but was %s", partitionCount);

            // Divide in long arithmetic and clamp BEFORE narrowing to int. Casting the byte count to int
            // first wraps around for budgets of 2 GiB or more (e.g. 4 GiB becomes 0), which would
            // silently shrink every page to the 1-byte minimum.
            long bytesPerPartition = requireNonNull(maxMemory, "maxMemory is null").toBytes() / partitionCount;
            int pageSize = (int) max(1L, min((long) DEFAULT_MAX_PAGE_SIZE_IN_BYTES, bytesPerPartition));

            this.pageBuilders = new PageBuilder[partitionCount];
            for (int i = 0; i < partitionCount; i++) {
                pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes);
            }
        }

        /**
         * Sets the system memory reservation tracked by this partitioner to zero bytes.
         */
        public void zeroMemoryContext()
        {
            this.systemMemoryContext.setBytes(0);
        }

        /**
         * Returns the output buffer's own "is full" future, unchanged.
         */
        public ListenableFuture<?> isFull()
        {
            return this.outputBuffer.isFull();
        }

        /**
         * Returns the size in bytes of the serde plus that of every partition page builder.
         * Before the builders have been created, only the serde size is returned.
         */
        public long getSizeInBytes()
        {
            // A plain loop is used here rather than a stream because it performs much better.
            long total = serde.getSizeInBytes();
            if (pageBuilders == null) {
                return total;
            }
            for (int i = 0; i < pageBuilders.length; i++) {
                total += pageBuilders[i].getSizeInBytes();
            }
            return total;
        }

        /**
         * Returns the retained size in bytes of the serde plus that of every partition page builder.
         * Before the builders have been created, only the serde's retained size is returned.
         *
         * This method can be expensive for complex types.
         */
        public long getRetainedSizeInBytes()
        {
            long retained = serde.getRetainedSizeInBytes();
            if (pageBuilders == null) {
                return retained;
            }
            for (int i = 0; i < pageBuilders.length; i++) {
                retained += pageBuilders[i].getRetainedSizeInBytes();
            }
            return retained;
        }

        /**
         * Returns a supplier of {@link PartitionedOutputInfo} built from this partitioner's row counter,
         * page counter and output buffer.
         */
        public Supplier<PartitionedOutputInfo> getPartitionedOutputInfoSupplier()
        {
            // Delegating to a static factory keeps the supplier from capturing a reference to "this"
            Supplier<PartitionedOutputInfo> infoSupplier =
                    PartitionedOutputInfo.createPartitionedOutputInfoSupplier(rowsAdded, pagesAdded, outputBuffer);
            return infoSupplier;
        }

        /**
         * Appends every row of {@code page} to the page builder(s) of its target partition(s), updates
         * the memory reservation, and flushes the builders that have become full.
         *
         * Routing rules, in order:
         * - when "any row" replication is enabled and has not happened yet, row 0 of this page is copied
         *   to every partition and is not partitioned again;
         * - when a null channel is configured and the row is null in that channel, the row is copied to
         *   every partition;
         * - otherwise the row goes to the partition chosen by the partition function.
         *
         * @throws NullPointerException if {@code page} is null
         */
        public void partitionPage(Page page)
        {
            requireNonNull(page, "page is null");
            int positionCount = page.getPositionCount();
            if (positionCount == 0) {
                return;
            }

            // "Any row" replication is handled up front so the per-row loops below stay simple
            int firstPosition = 0;
            if (replicatesAnyRow && !hasAnyRowBeenReplicated) {
                for (PageBuilder builder : pageBuilders) {
                    appendRow(builder, page, 0);
                }
                hasAnyRowBeenReplicated = true;
                firstPosition = 1;
            }

            Page arguments = getPartitionFunctionArguments(page);

            // The per-row null check is only needed when the null channel may actually contain nulls
            boolean checkNulls = nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull();
            if (!checkNulls) {
                for (int row = firstPosition; row < positionCount; row++) {
                    int target = partitionFunction.getPartition(arguments, row);
                    appendRow(pageBuilders[target], page, row);
                }
            }
            else {
                Block nullsBlock = page.getBlock(nullChannel);
                for (int row = firstPosition; row < positionCount; row++) {
                    if (!nullsBlock.isNull(row)) {
                        int target = partitionFunction.getPartition(arguments, row);
                        appendRow(pageBuilders[target], page, row);
                        continue;
                    }
                    // Null in the null channel: the row is sent to every partition
                    for (PageBuilder builder : pageBuilders) {
                        appendRow(builder, page, row);
                    }
                }
            }

            // Memory is recorded before flushing to avoid under counting when the page size is large.
            systemMemoryContext.setBytes(getRetainedSizeInBytes());

            flush(false);
        }

        /**
         * Builds the page of arguments passed to the partition function for {@code page}.
         *
         * Non-negative partition channels select the matching block of {@code page}; negative channels
         * are replaced by the constant at the same index, repeated for every position of the page.
         */
        private Page getPartitionFunctionArguments(Page page)
        {
            if (partitionConstantBlocks == null) {
                // No constants at all: the arguments are just a projection of the input channels
                return page.extractChannels(partitionChannels);
            }

            int positionCount = page.getPositionCount();
            Block[] arguments = new Block[partitionChannels.length];
            for (int i = 0; i < arguments.length; i++) {
                int channel = partitionChannels[i];
                arguments[i] = channel >= 0
                        ? page.getBlock(channel)
                        : new RunLengthEncodedBlock(partitionConstantBlocks[i], positionCount);
            }
            return new Page(positionCount, arguments);
        }

        /**
         * Copies the row at {@code position} of {@code page} into {@code pageBuilder}, one value per
         * source type, after declaring the new position on the builder.
         */
        private void appendRow(PageBuilder pageBuilder, Page page, int position)
        {
            pageBuilder.declarePosition();

            int channelCount = sourceTypes.length;
            for (int i = 0; i < channelCount; i++) {
                sourceTypes[i].appendTo(page.getBlock(i), position, pageBuilder.getBlockBuilder(i));
            }
        }

        /**
         * Builds, serializes and enqueues pages from the partition page builders.
         *
         * @param force when true every non-empty builder is emitted; when false only non-empty builders
         *        that report themselves full are emitted
         */
        public void flush(boolean force)
        {
            for (int partition = 0; partition < pageBuilders.length; partition++) {
                PageBuilder builder = pageBuilders[partition];
                if (builder.isEmpty()) {
                    continue;
                }
                if (!force && !builder.isFull()) {
                    continue;
                }

                Page partitionPage = builder.build();
                builder.reset();

                int rowCount = partitionPage.getPositionCount();
                operatorContext.recordOutput(partitionPage.getSizeInBytes(), rowCount);

                outputBuffer.enqueue(operatorContext.getDriverContext().getLifespan(), partition, splitAndSerializePage(partitionPage));
                pagesAdded.incrementAndGet();
                rowsAdded.addAndGet(rowCount);
            }
        }

        /**
         * Splits {@code pagePartition} into pages of at most {@code DEFAULT_MAX_PAGE_SIZE_IN_BYTES} and
         * serializes each piece, preserving the order produced by the split.
         */
        private List<SerializedPage> splitAndSerializePage(Page pagePartition)
        {
            List<Page> pieces = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES);
            ImmutableList.Builder<SerializedPage> serialized = ImmutableList.builderWithExpectedSize(pieces.size());
            pieces.forEach(piece -> serialized.add(serde.serialize(piece)));
            return serialized.build();
        }
    }
}