UnnestOperator.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.unnest;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.LongArrayBlock;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.Operator;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import org.openjdk.jol.info.ClassLayout;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.common.array.Arrays.ExpansionFactor.SMALL;
import static com.facebook.presto.common.array.Arrays.ExpansionOption.INITIALIZE;
import static com.facebook.presto.common.array.Arrays.ensureCapacity;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.lang.Integer.max;
import static java.util.Objects.requireNonNull;

public class UnnestOperator
        implements Operator
{
    public static class UnnestOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final List<Integer> replicateChannels;
        private final List<Type> replicateTypes;
        private final List<Integer> unnestChannels;
        private final List<Type> unnestTypes;
        private final boolean withOrdinality;
        private boolean closed;

        public UnnestOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                List<Integer> replicateChannels,
                List<Type> replicateTypes,
                List<Integer> unnestChannels,
                List<Type> unnestTypes,
                boolean withOrdinality)
        {
            this.operatorId = operatorId;
            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
            this.replicateChannels = ImmutableList.copyOf(requireNonNull(replicateChannels, "replicateChannels is null"));
            this.replicateTypes = ImmutableList.copyOf(requireNonNull(replicateTypes, "replicateTypes is null"));
            checkArgument(replicateChannels.size() == replicateTypes.size(), "replicateChannels and replicateTypes do not match");
            this.unnestChannels = ImmutableList.copyOf(requireNonNull(unnestChannels, "unnestChannels is null"));
            this.unnestTypes = ImmutableList.copyOf(requireNonNull(unnestTypes, "unnestTypes is null"));
            checkArgument(unnestChannels.size() == unnestTypes.size(), "unnestChannels and unnestTypes do not match");
            this.withOrdinality = withOrdinality;
        }

        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");
            return createOperator(driverContext, SystemSessionProperties.isLegacyUnnest(driverContext.getSession()));
        }

        @VisibleForTesting
        public Operator createOperator(DriverContext driverContext, boolean legacyUnnest)
        {
            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, UnnestOperator.class.getSimpleName());
            return new UnnestOperator(operatorContext, replicateChannels, replicateTypes, unnestChannels, unnestTypes, withOrdinality, legacyUnnest);
        }

        @Override
        public void noMoreOperators()
        {
            closed = true;
        }

        @Override
        public OperatorFactory duplicate()
        {
            return new UnnestOperator.UnnestOperatorFactory(operatorId, planNodeId, replicateChannels, replicateTypes, unnestChannels, unnestTypes, withOrdinality);
        }
    }

    private static final int INSTANCE_SIZE = ClassLayout.parseClass(UnnestOperator.class).instanceSize();
    private static final int MAX_ROWS_PER_BLOCK = 1000;

    private final OperatorContext operatorContext;
    private final LocalMemoryContext systemMemoryContext;
    private final List<Integer> replicateChannels;
    private final List<Type> replicateTypes;
    private final List<ReplicatedBlockBuilder> replicatedBlockBuilders;
    private final List<Integer> unnestChannels;
    private final List<Type> unnestTypes;
    private final List<Unnester> unnesters;
    private final boolean withOrdinality;
    private final int outputChannelCount;

    private boolean finishing;
    private Page currentPage;
    private int currentPosition;
    private int[] maxLengths = new int[0];
    private int currentBatchTotalLength;

    public UnnestOperator(OperatorContext operatorContext, List<Integer> replicateChannels, List<Type> replicateTypes, List<Integer> unnestChannels, List<Type> unnestTypes, boolean withOrdinality, boolean isLegacyUnnest)
    {
        this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
        this.systemMemoryContext = operatorContext.localSystemMemoryContext();

        this.replicateChannels = ImmutableList.copyOf(requireNonNull(replicateChannels, "replicateChannels is null"));
        this.replicateTypes = ImmutableList.copyOf(requireNonNull(replicateTypes, "replicateTypes is null"));
        checkArgument(replicateChannels.size() == replicateTypes.size(), "replicate channels or types has wrong size");
        this.replicatedBlockBuilders = replicateTypes.stream()
                .map(type -> new ReplicatedBlockBuilder())
                .collect(toImmutableList());

        this.unnestChannels = ImmutableList.copyOf(requireNonNull(unnestChannels, "unnestChannels is null"));
        this.unnestTypes = ImmutableList.copyOf(requireNonNull(unnestTypes, "unnestTypes is null"));
        checkArgument(unnestChannels.size() == unnestTypes.size(), "unnest channels or types has wrong size");
        this.unnesters = unnestTypes.stream()
                .map(nestedType -> createUnnester(nestedType, isLegacyUnnest))
                .collect(toImmutableList());

        this.withOrdinality = withOrdinality;
        this.outputChannelCount = unnesters.stream().mapToInt(Unnester::getChannelCount).sum() + replicateTypes.size() + (withOrdinality ? 1 : 0);
    }

    @Override
    public OperatorContext getOperatorContext()
    {
        return operatorContext;
    }

    @Override
    public void finish()
    {
        finishing = true;
    }

    @Override
    public boolean isFinished()
    {
        return finishing && currentPage == null;
    }

    @Override
    public boolean needsInput()
    {
        return !finishing && currentPage == null;
    }

    @Override
    public void addInput(Page page)
    {
        checkState(!finishing, "Operator is already finishing");
        requireNonNull(page, "page is null");
        checkState(currentPage == null, "currentPage is not null");

        currentPage = page;
        currentPosition = 0;

        resetBlockBuilders();

        systemMemoryContext.setBytes(getRetainedSizeInBytes());
    }

    @Override
    public Page getOutput()
    {
        if (currentPage == null) {
            return null;
        }

        int positionCount = currentPage.getPositionCount();
        int batchSize = calculateNextBatchSize();
        Block[] outputBlocks = buildOutputBlocks(batchSize);

        if (currentPosition == positionCount) {
            currentPage = null;
            currentPosition = 0;
        }

        return new Page(outputBlocks);
    }

    private static Unnester createUnnester(Type nestedType, boolean isLegacyUnnest)
    {
        if (nestedType instanceof ArrayType) {
            Type elementType = ((ArrayType) nestedType).getElementType();

            if (!isLegacyUnnest && elementType instanceof RowType) {
                return new ArrayOfRowsUnnester(elementType.getTypeParameters().size());
            }
            else {
                return new ArrayUnnester();
            }
        }
        else if (nestedType instanceof MapType) {
            return new MapUnnester();
        }
        else {
            throw new IllegalArgumentException("Cannot unnest type: " + nestedType);
        }
    }

    private void resetBlockBuilders()
    {
        for (int i = 0; i < replicateTypes.size(); i++) {
            Block newInputBlock = currentPage.getBlock(replicateChannels.get(i));
            replicatedBlockBuilders.get(i).resetInputBlock(newInputBlock);
        }

        int positionCount = currentPage.getPositionCount();
        maxLengths = ensureCapacity(maxLengths, positionCount, SMALL, INITIALIZE);

        for (int i = 0; i < unnestTypes.size(); i++) {
            int inputChannel = unnestChannels.get(i);
            Block unnestChannelInputBlock = currentPage.getBlock(inputChannel);
            Unnester unnester = unnesters.get(i);

            unnester.resetInput(unnestChannelInputBlock);

            int[] lengths = unnester.getLengths();
            for (int j = 0; j < positionCount; j++) {
                maxLengths[j] = max(maxLengths[j], lengths[j]);
            }
        }
    }

    private int calculateNextBatchSize()
    {
        int positionCount = currentPage.getPositionCount();

        int totalLengths = 0;
        int position = currentPosition;
        while (position < positionCount) {
            int length = maxLengths[position];
            if (totalLengths + length >= MAX_ROWS_PER_BLOCK) {
                break;
            }
            totalLengths += length;
            position++;
        }

        // grab at least a single position
        if (position == currentPosition) {
            currentBatchTotalLength = maxLengths[currentPosition];
            return 1;
        }

        currentBatchTotalLength = totalLengths;
        return position - currentPosition;
    }

    private Block[] buildOutputBlocks(int batchSize)
    {
        Block[] outputBlocks = new Block[outputChannelCount];
        int channel = 0;

        for (int replicateIndex = 0; replicateIndex < replicateTypes.size(); replicateIndex++) {
            outputBlocks[channel++] = replicatedBlockBuilders.get(replicateIndex).buildOutputBlock(maxLengths, currentPosition, batchSize, currentBatchTotalLength);
        }

        for (int unnestIndex = 0; unnestIndex < unnesters.size(); unnestIndex++) {
            Unnester unnester = unnesters.get(unnestIndex);
            Block[] block = unnester.buildOutputBlocks(maxLengths, currentPosition, batchSize, currentBatchTotalLength);
            for (int j = 0; j < unnester.getChannelCount(); j++) {
                outputBlocks[channel++] = block[j];
            }
        }

        if (withOrdinality) {
            outputBlocks[channel] = buildOrdinalityOutputBlock(maxLengths, currentPosition, batchSize, currentBatchTotalLength);
        }

        currentPosition += batchSize;

        return outputBlocks;
    }

    private static Block buildOrdinalityOutputBlock(int[] maxEntries, int offset, int length, int totalEntriesForBatch)
    {
        long[] values = new long[totalEntriesForBatch];
        int index = 0;
        for (int i = 0; i < length; i++) {
            int curEntries = maxEntries[offset + i];
            for (int j = 1; j <= curEntries; j++) {
                values[index++] = j;
            }
        }

        return new LongArrayBlock(totalEntriesForBatch, Optional.empty(), values);
    }

    private long getRetainedSizeInBytes()
    {
        long size = INSTANCE_SIZE + sizeOf(maxLengths) + currentPage.getRetainedSizeInBytes();
        for (Unnester unnester : unnesters) {
            size += unnester.getRetainedSizeInBytes();
        }
        return size;
    }
}