GenericAccumulatorFactory.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.array.IntBigArray;
import com.facebook.presto.common.array.ObjectBigArray;
import com.facebook.presto.common.block.ArrayBlock;
import com.facebook.presto.common.block.ArrayBlockBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.ColumnarArray;
import com.facebook.presto.common.block.ColumnarRow;
import com.facebook.presto.common.block.LongArrayBlock;
import com.facebook.presto.common.block.RowBlock;
import com.facebook.presto.common.block.RowBlockBuilder;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.block.SingleRowBlock;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.MarkDistinctHash;
import com.facebook.presto.operator.PagesIndex;
import com.facebook.presto.operator.UpdateMemory;
import com.facebook.presto.operator.Work;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.spi.function.WindowIndex;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import com.facebook.presto.spi.function.aggregation.GroupByIdBlock;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.facebook.presto.spi.function.aggregation.LambdaProvider;
import com.facebook.presto.spi.storage.SerializedStorageHandle;
import com.facebook.presto.spiller.StandaloneSpiller;
import com.facebook.presto.spiller.StandaloneSpillerFactory;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Booleans;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import org.openjdk.jol.info.ClassLayout;

import javax.annotation.Nullable;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.facebook.presto.common.Page.wrapBlocksWithoutCopy;
import static com.facebook.presto.common.block.ColumnarArray.toColumnarArray;
import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sessionpropertyproviders.JavaWorkerSessionPropertyProvider.getDistinctAggregationLargeBlockSizeThreshold;
import static com.facebook.presto.sessionpropertyproviders.JavaWorkerSessionPropertyProvider.isDedupBasedDistinctAggregationSpillEnabled;
import static com.facebook.presto.sessionpropertyproviders.JavaWorkerSessionPropertyProvider.isDistinctAggregationLargeBlockSpillEnabled;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterators.singletonIterator;
import static io.airlift.slice.Slices.EMPTY_SLICE;
import static io.airlift.slice.Slices.wrappedBuffer;
import static java.lang.Long.max;
import static java.util.Objects.isNull;
import static java.util.Objects.requireNonNull;

public class GenericAccumulatorFactory
        implements AccumulatorFactory
{
    // Aggregation state descriptors; passed unchanged to every reflective accumulator constructor call
    private final List<AccumulatorStateDescriptor> stateDescriptors;
    // Reflective constructors used to create the non-grouped and grouped accumulator implementations
    private final Constructor<? extends Accumulator> accumulatorConstructor;
    private final Constructor<? extends GroupedAccumulator> groupedAccumulatorConstructor;
    private final List<LambdaProvider> lambdaProviders;
    // Channel of the mask column, if any
    private final Optional<Integer> maskChannel;
    // Channels of the aggregation arguments; also used as indexes into sourceTypes
    private final List<Integer> inputChannels;
    private final List<Type> sourceTypes;
    // ORDER BY channels and their sort orders; an empty orderByChannels list means there is no ordering
    private final List<Integer> orderByChannels;
    private final List<SortOrder> orderings;

    // Only guaranteed to be non-null when distinct is true (validated in the constructor)
    @Nullable
    private final JoinCompiler joinCompiler;

    // Only guaranteed to be non-null when distinct is true (validated in the constructor)
    @Nullable
    private final Session session;
    private final boolean distinct;
    private final boolean spillEnabled;
    // May be null when orderByChannels is empty (validated in the constructor)
    private final PagesIndex.Factory pagesIndexFactory;
    // Only guaranteed to be non-null when distinct is true (validated in the constructor)
    private final StandaloneSpillerFactory standaloneSpillerFactory;

    public GenericAccumulatorFactory(
            List<AccumulatorStateDescriptor> stateDescriptors,
            Constructor<? extends Accumulator> accumulatorConstructor,
            Constructor<? extends GroupedAccumulator> groupedAccumulatorConstructor,
            List<LambdaProvider> lambdaProviders,
            List<Integer> inputChannels,
            Optional<Integer> maskChannel,
            List<Type> sourceTypes,
            List<Integer> orderByChannels,
            List<SortOrder> orderings,
            PagesIndex.Factory pagesIndexFactory,
            JoinCompiler joinCompiler,
            Session session,
            boolean distinct,
            boolean spillEnabled,
            StandaloneSpillerFactory standaloneSpillerFactory)
    {
        this.stateDescriptors = requireNonNull(stateDescriptors, "stateDescriptors is null");
        this.accumulatorConstructor = requireNonNull(accumulatorConstructor, "accumulatorConstructor is null");
        this.groupedAccumulatorConstructor = requireNonNull(groupedAccumulatorConstructor, "groupedAccumulatorConstructor is null");
        this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null"));
        this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
        this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null"));
        this.sourceTypes = ImmutableList.copyOf(requireNonNull(sourceTypes, "sourceTypes is null"));
        this.orderByChannels = ImmutableList.copyOf(requireNonNull(orderByChannels, "orderByChannels is null"));
        this.orderings = ImmutableList.copyOf(requireNonNull(orderings, "orderings is null"));
        checkArgument(orderByChannels.isEmpty() || !isNull(pagesIndexFactory), "No pagesIndexFactory to process ordering");
        this.pagesIndexFactory = pagesIndexFactory;

        checkArgument(!distinct || !isNull(session) && !isNull(joinCompiler) && !isNull(standaloneSpillerFactory), "joinCompiler, session and standaloneSpillerFactory needed when distinct is true");
        this.joinCompiler = joinCompiler;
        this.session = session;
        this.distinct = distinct;
        this.spillEnabled = spillEnabled;
        this.standaloneSpillerFactory = standaloneSpillerFactory;
    }

    /**
     * Returns the aggregation argument channels this factory was created with
     * (the immutable copy taken in the constructor).
     */
    @Override
    public List<Integer> getInputChannels()
    {
        return inputChannels;
    }

    /**
     * Creates the accumulator for a non-grouped aggregation.
     * <p>
     * With DISTINCT, the generated accumulator is wrapped in a {@link DistinctingAccumulator};
     * with ORDER BY channels, the result is additionally wrapped in an {@link OrderingAccumulator}.
     *
     * @param updateMemory memory callback handed to the distinct wrapper; unused otherwise
     */
    @Override
    public Accumulator createAccumulator(UpdateMemory updateMemory)
    {
        Accumulator result;

        if (!hasDistinct()) {
            result = instantiateAccumulator(inputChannels, maskChannel);
        }
        else {
            // the distinct wrapper prepends its mask as channel 0, so every argument channel shifts right by one
            List<Integer> shiftedChannels = new ArrayList<>();
            for (int channel : inputChannels) {
                shiftedChannels.add(channel + 1);
            }
            Accumulator maskedAccumulator = instantiateAccumulator(shiftedChannels, Optional.of(0));

            List<Type> argumentTypes = new ArrayList<>();
            for (int channel : inputChannels) {
                argumentTypes.add(sourceTypes.get(channel));
            }

            result = new DistinctingAccumulator(maskedAccumulator, argumentTypes, inputChannels, maskChannel, session, joinCompiler, updateMemory);
        }

        return orderByChannels.isEmpty()
                ? result
                : new OrderingAccumulator(result, sourceTypes, orderByChannels, orderings, pagesIndexFactory);
    }

    /**
     * Creates a non-grouped accumulator that is constructed with no input channels and no mask channel.
     *
     * @throws RuntimeException wrapping the reflective failure if the accumulator cannot be instantiated
     */
    @Override
    public Accumulator createIntermediateAccumulator()
    {
        // same reflective call as any other accumulator, just without inputs or a mask
        return instantiateAccumulator(ImmutableList.of(), Optional.empty());
    }

    /**
     * Creates the accumulator for a grouped aggregation.
     * <p>
     * When spilling is enabled and the aggregation has DISTINCT or ORDER BY, the accumulator is wrapped
     * in a spillable accumulator. The dedup-based variant is used when the session enables it and the
     * aggregation is DISTINCT without ORDER BY.
     *
     * @throws IllegalStateException if a spill wrapper is required but the session is null
     */
    @Override
    public GroupedAccumulator createGroupedAccumulator(UpdateMemory updateMemory)
    {
        GroupedAccumulator accumulator = createGenericGroupedAccumulator(updateMemory);

        boolean needsSpillWrapper = spillEnabled && (hasDistinct() || hasOrderBy());
        if (!needsSpillWrapper) {
            return accumulator;
        }

        checkState(accumulator instanceof FinalOnlyGroupedAccumulator);

        // every channel the aggregation reads: arguments, mask (if any) and ORDER BY keys, without duplicates
        ImmutableSet.Builder<Integer> channelSet = ImmutableSet.builder();
        channelSet.addAll(inputChannels);
        if (maskChannel.isPresent()) {
            channelSet.add(maskChannel.get());
        }
        channelSet.addAll(orderByChannels);

        checkState(session != null, "Session is null");
        boolean useDedupSpill = isDedupBasedDistinctAggregationSpillEnabled(session) && hasDistinct() && !hasOrderBy();
        ImmutableList<Integer> aggregateInputChannels = channelSet.build().asList();

        if (!useDedupSpill) {
            return new SpillableFinalOnlyGroupedAccumulator(
                    sourceTypes,
                    aggregateInputChannels,
                    (FinalOnlyGroupedAccumulator) accumulator,
                    standaloneSpillerFactory,
                    session);
        }

        // DISTINCT without ORDER BY: the generic accumulator is the distincting one itself
        return new DedupBasedSpillableDistinctGroupedAccumulator(
                sourceTypes,
                aggregateInputChannels,
                (DistinctingGroupedAccumulator) accumulator,
                maskChannel,
                standaloneSpillerFactory,
                session);
    }

    /**
     * Creates the grouped accumulator for intermediate aggregation.
     * <p>
     * With ORDER BY or DISTINCT this delegates to {@link #createGroupedAccumulator(UpdateMemory)};
     * otherwise the grouped accumulator is constructed with no input channels and no mask channel.
     */
    @Override
    public GroupedAccumulator createGroupedIntermediateAccumulator(UpdateMemory updateMemory)
    {
        if (hasOrderBy() || hasDistinct()) {
            return createGroupedAccumulator(updateMemory);
        }
        return instantiateGroupedAccumulator(ImmutableList.of(), Optional.empty());
    }

    /**
     * Returns true when at least one ORDER BY channel was supplied to this factory.
     */
    @Override
    public boolean hasOrderBy()
    {
        return !orderByChannels.isEmpty();
    }

    /**
     * Returns the {@code distinct} flag this factory was created with.
     */
    @Override
    public boolean hasDistinct()
    {
        return distinct;
    }

    /**
     * Builds the grouped accumulator without any spill wrapper.
     * <p>
     * With DISTINCT, the generated accumulator is wrapped in a {@link DistinctingGroupedAccumulator};
     * with ORDER BY channels, the result is additionally wrapped in an {@link OrderingGroupedAccumulator}.
     */
    private GroupedAccumulator createGenericGroupedAccumulator(UpdateMemory updateMemory)
    {
        GroupedAccumulator result;

        if (!hasDistinct()) {
            result = instantiateGroupedAccumulator(inputChannels, maskChannel);
        }
        else {
            // the distinct wrapper prepends its mask as channel 0, so every argument channel shifts right by one
            List<Integer> shiftedChannels = new ArrayList<>();
            for (int channel : inputChannels) {
                shiftedChannels.add(channel + 1);
            }
            GroupedAccumulator maskedAccumulator = instantiateGroupedAccumulator(shiftedChannels, Optional.of(0));

            List<Type> argumentTypes = inputChannels.stream()
                    .map(sourceTypes::get)
                    .collect(Collectors.toList());

            result = new DistinctingGroupedAccumulator(maskedAccumulator, argumentTypes, inputChannels, maskChannel, session, joinCompiler, updateMemory);
        }

        return orderByChannels.isEmpty()
                ? result
                : new OrderingGroupedAccumulator(result, sourceTypes, orderByChannels, orderings, pagesIndexFactory);
    }

    /**
     * Reflectively instantiates the non-grouped accumulator.
     *
     * @param inputs channels passed to the accumulator constructor as its input channels
     * @param mask channel passed to the accumulator constructor as its mask channel
     * @throws RuntimeException wrapping the reflective failure if instantiation fails
     */
    private Accumulator instantiateAccumulator(List<Integer> inputs, Optional<Integer> mask)
    {
        try {
            return accumulatorConstructor.newInstance(stateDescriptors, inputs, mask, lambdaProviders);
        }
        catch (ReflectiveOperationException reflectionFailure) {
            // common supertype of the three checked exceptions newInstance declares
            throw new RuntimeException(reflectionFailure);
        }
    }

    /**
     * Reflectively instantiates the grouped accumulator.
     *
     * @param inputs channels passed to the accumulator constructor as its input channels
     * @param mask channel passed to the accumulator constructor as its mask channel
     * @throws RuntimeException wrapping the reflective failure if instantiation fails
     */
    private GroupedAccumulator instantiateGroupedAccumulator(List<Integer> inputs, Optional<Integer> mask)
    {
        try {
            return groupedAccumulatorConstructor.newInstance(stateDescriptors, inputs, mask, lambdaProviders);
        }
        catch (ReflectiveOperationException reflectionFailure) {
            // common supertype of the three checked exceptions newInstance declares
            throw new RuntimeException(reflectionFailure);
        }
    }

    /**
     * Creates a {@link GenericAccumulatorFactory} for the given aggregation implementation.
     * <p>
     * Both the accumulator class and the grouped accumulator class of {@code functionImplementation}
     * must declare a public constructor taking
     * {@code (List stateDescriptors, List inputChannels, Optional maskChannel, List lambdaProviders)}.
     *
     * @throws RuntimeException wrapping {@link NoSuchMethodException} if either constructor is missing
     */
    public static AccumulatorFactory generateAccumulatorFactory(
            JavaAggregationFunctionImplementation functionImplementation,
            List<Integer> argumentChannels,
            Optional<Integer> maskChannel,
            List<Type> sourceTypes,
            List<Integer> orderByChannels,
            List<SortOrder> orderings,
            PagesIndex.Factory pagesIndexFactory,
            boolean distinct,
            JoinCompiler joinCompiler,
            List<LambdaProvider> lambdaProviders,
            boolean spillEnabled,
            Session session,
            StandaloneSpillerFactory standaloneSpillerFactory)
    {
        // shared constructor signature of both accumulator flavours:
        // List<AccumulatorStateDescriptor> stateDescriptors, List<Integer> inputChannel,
        // Optional<Integer> maskChannel, List<LambdaProvider> lambdaProviders
        Class<?>[] constructorSignature = {List.class, List.class, Optional.class, List.class};

        Constructor<? extends Accumulator> accumulatorConstructor;
        Constructor<? extends GroupedAccumulator> groupedAccumulatorConstructor;
        try {
            accumulatorConstructor = functionImplementation.getAccumulatorClass().getConstructor(constructorSignature);
            groupedAccumulatorConstructor = functionImplementation.getGroupedAccumulatorClass().getConstructor(constructorSignature);
        }
        catch (NoSuchMethodException missingConstructor) {
            throw new RuntimeException(missingConstructor);
        }

        return new GenericAccumulatorFactory(
                functionImplementation.getAggregationMetadata().getAccumulatorStateDescriptors(),
                accumulatorConstructor,
                groupedAccumulatorConstructor,
                lambdaProviders,
                argumentChannels,
                maskChannel,
                sourceTypes,
                orderByChannels,
                orderings,
                pagesIndexFactory,
                joinCompiler,
                session,
                distinct,
                spillEnabled,
                standaloneSpillerFactory);
    }

    /**
     * Convenience overload for an aggregation with no ORDER BY, no DISTINCT, no lambdas and no spilling.
     * Source types are passed as an empty list and all optional collaborators as null.
     */
    public static AccumulatorFactory generateAccumulatorFactory(
            JavaAggregationFunctionImplementation javaAggregationFunctionImplementation,
            List<Integer> inputChannels, Optional<Integer> maskChannel)
    {
        List<Type> noSourceTypes = ImmutableList.of();
        List<Integer> noOrderByChannels = ImmutableList.of();
        List<SortOrder> noOrderings = ImmutableList.of();
        List<LambdaProvider> noLambdaProviders = ImmutableList.of();

        return generateAccumulatorFactory(
                javaAggregationFunctionImplementation,
                inputChannels,
                maskChannel,
                noSourceTypes,
                noOrderByChannels,
                noOrderings,
                null,   // pagesIndexFactory
                false,  // distinct
                null,   // joinCompiler
                noLambdaProviders,
                false,  // spillEnabled
                null,   // session
                null);  // standaloneSpillerFactory
    }

    /**
     * Non-grouped accumulator wrapper implementing DISTINCT.
     * <p>
     * Each input page is first reduced by the optional mask channel, then a distinct mask is computed
     * with a {@link MarkDistinctHash} and prepended as channel 0 before the page is handed to the
     * wrapped accumulator. Only final evaluation is supported; window input and all
     * intermediate-state operations throw {@link UnsupportedOperationException}.
     */
    private static class DistinctingAccumulator
            implements Accumulator
    {
        private final Accumulator accumulator;
        private final MarkDistinctHash hash;
        // -1 when no mask channel was supplied
        private final int maskChannel;

        private DistinctingAccumulator(
                Accumulator accumulator,
                List<Type> inputTypes,
                List<Integer> inputs,
                Optional<Integer> maskChannel,
                Session session,
                JoinCompiler joinCompiler,
                UpdateMemory updateMemory)
        {
            this.accumulator = requireNonNull(accumulator, "accumulator is null");
            this.maskChannel = requireNonNull(maskChannel, "maskChannel is null").orElse(-1);

            this.hash = new MarkDistinctHash(
                    session,
                    inputTypes,
                    Ints.toArray(inputs),
                    Optional.empty(),
                    joinCompiler,
                    () -> {
                        // report memory so task limits are enforced promptly ...
                        updateMemory.update();
                        // ... but always claim success: addInput has no way to yield
                        return true;
                    });
        }

        @Override
        public long getEstimatedSize()
        {
            long hashSize = hash.getEstimatedSize();
            return hashSize + accumulator.getEstimatedSize();
        }

        @Override
        public Type getFinalType()
        {
            return accumulator.getFinalType();
        }

        @Override
        public Type getIntermediateType()
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void addInput(Page page)
        {
            // step 1: keep only the rows selected by the mask channel (a negative channel means "no mask")
            Page candidates = (maskChannel < 0) ? page : filter(page, page.getBlock(maskChannel));
            if (candidates.getPositionCount() == 0) {
                return;
            }

            // step 2: mark the rows not seen before
            Work<Block> markDistinct = hash.markDistinctRows(candidates);
            checkState(markDistinct.process());
            Block distinctMask = markDistinct.getResult();

            // step 3: the wrapped accumulator reads the distinct mask from channel 0
            Page withDistinctMask = candidates.prependColumn(distinctMask);
            accumulator.addInput(withDistinctMask);
        }

        @Override
        public void addInput(WindowIndex index, List<Integer> channels, int startPosition, int endPosition)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void addIntermediate(Block block)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateIntermediate(BlockBuilder blockBuilder)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateFinal(BlockBuilder blockBuilder)
        {
            accumulator.evaluateFinal(blockBuilder);
        }
    }

    /**
     * Returns the rows of {@code page} whose {@code mask} value is non-null and true.
     * <p>
     * The original page instance is returned when every row is kept. A non-empty
     * {@link RunLengthEncodedBlock} mask is decided from its first position alone.
     *
     * @param page page to filter
     * @param mask boolean block whose positions correspond to the page rows
     */
    private static Page filter(Page page, Block mask)
    {
        int positionCount = mask.getPositionCount();

        if (mask instanceof RunLengthEncodedBlock && positionCount > 0) {
            // position 0 can only be inspected when the block is non-empty
            boolean firstIsNull = mask.mayHaveNull() && mask.isNull(0);
            boolean keepEverything = !firstIsNull && BOOLEAN.getBoolean(mask, 0);
            return keepEverything ? page : page.getPositions(new int[0], 0, 0);
        }

        boolean nullable = mask.mayHaveNull();
        int[] selected = new int[positionCount];
        int selectedCount = 0;
        for (int position = 0; position < positionCount; position++) {
            if (nullable && mask.isNull(position)) {
                continue;
            }
            if (BOOLEAN.getBoolean(mask, position)) {
                selected[selectedCount] = position;
                selectedCount++;
            }
        }

        // nothing was dropped, so the input page can be reused
        return (selectedCount == positionCount) ? page : page.getPositions(selected, 0, selectedCount);
    }

    private static class DistinctingGroupedAccumulator
            extends FinalOnlyGroupedAccumulator
    {
        private final GroupedAccumulator accumulator;
        private final List<Type> inputTypes;
        private final List<Integer> inputChannels;
        private final int maskChannel;
        private final Session session;
        private final JoinCompiler joinCompiler;
        private final UpdateMemory updateMemory;

        private MarkDistinctHash hash;

        private DistinctingGroupedAccumulator(
                GroupedAccumulator accumulator,
                List<Type> inputTypes,
                List<Integer> inputChannels,
                Optional<Integer> maskChannel,
                Session session,
                JoinCompiler joinCompiler,
                UpdateMemory updateMemory)
        {
            this.accumulator = requireNonNull(accumulator, "accumulator is null");
            this.inputTypes = requireNonNull(inputTypes, "inputTypes is null");
            this.inputChannels = requireNonNull(inputChannels, "inputChannels is null");
            this.maskChannel = requireNonNull(maskChannel, "maskChannel is null").orElse(-1);
            this.session = requireNonNull(session, "session is null");
            this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
            this.updateMemory = requireNonNull(updateMemory, "updateMemory is null");
            this.hash = createMarkDistinctHash();
        }

        private MarkDistinctHash createMarkDistinctHash()
        {
            List<Type> types = ImmutableList.<Type>builder()
                    .add(BIGINT) // group id column
                    .addAll(inputTypes)
                    .build();

            int[] inputs = new int[inputChannels.size() + 1];
            inputs[0] = 0; // we'll use the first channel for group id column
            for (int i = 0; i < inputChannels.size(); i++) {
                inputs[i + 1] = inputChannels.get(i) + 1;
            }

            return new MarkDistinctHash(
                    session,
                    types,
                    inputs,
                    Optional.empty(),
                    joinCompiler,
                    () -> {
                        // enforce task memory limits for fast throw
                        updateMemory.update();
                        // never block, as addInput doesn't support yield semantics
                        return true;
                    });
        }

        @Override
        public long getEstimatedSize()
        {
            return hash.getEstimatedSize() + accumulator.getEstimatedSize();
        }

        @Override
        public Type getFinalType()
        {
            return accumulator.getFinalType();
        }

        @Override
        public void addInput(GroupByIdBlock groupIdsBlock, Page page)
        {
            Page withGroup = page.prependColumn(groupIdsBlock);

            // 1. filter out positions based on mask, if present
            Page filtered = applyMaskChannelFilter(withGroup, maskChannel);

            // 2. compute a mask for the distinct rows (including the group id)
            Block distinctMask = computeDistinctMask(filtered, hash);

            // 3. feed a Page with a new mask to the underlying aggregation
            GroupByIdBlock groupIds = new GroupByIdBlock(groupIdsBlock.getGroupCount(), filtered.getBlock(0));

            // drop the group id column and prepend the distinct mask column
            Block[] columns = new Block[filtered.getChannelCount()];
            columns[0] = distinctMask;
            for (int i = 1; i < filtered.getChannelCount(); i++) {
                columns[i] = filtered.getBlock(i);
            }

            accumulator.addInput(groupIds, new Page(filtered.getPositionCount(), columns));
        }

        @Override
        public void evaluateFinal(int groupId, BlockBuilder output)
        {
            accumulator.evaluateFinal(groupId, output);
        }

        @Override
        public void prepareFinal()
        {
        }

        public GroupByIdBlock preprocessInput(GroupByIdBlock groupIdsBlock, Page page)
        {
            // Prepend the groupId block to the input page
            Page withGroup = page.prependColumn(groupIdsBlock);

            // Filter out positions based on mask, if present
            Page filtered = applyMaskChannelFilter(withGroup, maskChannel);

            // Compute a mask for the distinct rows (including the group id)
            // Distinct rows will be stored inside `hash`
            Block distinctMask = computeDistinctMask(filtered, hash);

            // Filter out duplicate rows and return the page with distinct rows
            Page dedupPage = filter(filtered, distinctMask);

            // Get updated GroupByIdBlock from distinctPage
            return new GroupByIdBlock(groupIdsBlock.getGroupCount(), dedupPage.getBlock(0));
        }

        private Page applyMaskChannelFilter(Page page, int maskChannel)
        {
            if (maskChannel >= 0) {
                return filter(page, page.getBlock(maskChannel + 1)); // offset by one because of group id in column 0
            }
            return page;
        }

        private Block computeDistinctMask(Page page, MarkDistinctHash hash)
        {
            Work<Block> work = hash.markDistinctRows(page);
            checkState(work.process());
            return work.getResult();
        }

        private List<Page> getDistinctPages()
        {
            return hash.getDistinctPages();
        }

        private void reset()
        {
            hash = createMarkDistinctHash();
        }
    }

    /**
     * Non-grouped accumulator wrapper implementing ORDER BY inside an aggregation.
     * <p>
     * Input pages are buffered in a {@link PagesIndex}. In {@link #evaluateFinal(BlockBuilder)} the
     * index is sorted by the ORDER BY channels and the sorted pages are fed to the wrapped
     * accumulator before it is evaluated. Only final evaluation is supported; window input and all
     * intermediate-state operations throw {@link UnsupportedOperationException}.
     */
    private static class OrderingAccumulator
            implements Accumulator
    {
        private final Accumulator accumulator;
        private final List<Integer> orderByChannels;
        private final List<SortOrder> orderings;
        private final PagesIndex pagesIndex;

        /**
         * @param accumulator accumulator that receives the sorted pages
         * @param aggregationSourceTypes types of the buffered page channels
         * @param orderByChannels channels to sort by
         * @param orderings sort order for each ORDER BY channel
         * @param pagesIndexFactory factory for the buffering index
         * @throws NullPointerException if any argument is null
         */
        private OrderingAccumulator(
                Accumulator accumulator,
                List<Type> aggregationSourceTypes,
                List<Integer> orderByChannels,
                List<SortOrder> orderings,
                PagesIndex.Factory pagesIndexFactory)
        {
            this.accumulator = requireNonNull(accumulator, "accumulator is null");
            // validated up front, like in OrderingGroupedAccumulator, so a bad argument is named in the failure
            requireNonNull(aggregationSourceTypes, "aggregationSourceTypes is null");
            this.orderByChannels = ImmutableList.copyOf(requireNonNull(orderByChannels, "orderByChannels is null"));
            this.orderings = ImmutableList.copyOf(requireNonNull(orderings, "orderings is null"));
            requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
            this.pagesIndex = pagesIndexFactory.newPagesIndex(aggregationSourceTypes, 10_000);
        }

        @Override
        public long getEstimatedSize()
        {
            return pagesIndex.getEstimatedSize().toBytes() + accumulator.getEstimatedSize();
        }

        @Override
        public Type getFinalType()
        {
            return accumulator.getFinalType();
        }

        @Override
        public Type getIntermediateType()
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void addInput(Page page)
        {
            // only buffer here; the wrapped accumulator sees the rows once they are sorted
            pagesIndex.addPage(page);
        }

        @Override
        public void addInput(WindowIndex index, List<Integer> channels, int startPosition, int endPosition)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void addIntermediate(Block block)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateIntermediate(BlockBuilder blockBuilder)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateFinal(BlockBuilder blockBuilder)
        {
            // sort everything buffered so far, replay it in order, then produce the result
            pagesIndex.sort(orderByChannels, orderings);
            Iterator<Page> pagesIterator = pagesIndex.getSortedPages();
            pagesIterator.forEachRemaining(accumulator::addInput);
            accumulator.evaluateFinal(blockBuilder);
        }
    }

    /**
     * Grouped accumulator wrapper implementing ORDER BY inside an aggregation.
     * <p>
     * Input pages are buffered in a {@link PagesIndex} with the group id block appended as an extra
     * last channel. {@link #prepareFinal()} sorts the index by the ORDER BY channels and feeds the
     * sorted pages, together with their group ids, to the wrapped accumulator.
     */
    private static class OrderingGroupedAccumulator
            extends FinalOnlyGroupedAccumulator
    {
        private final GroupedAccumulator accumulator;
        private final List<Integer> orderByChannels;
        private final List<SortOrder> orderings;
        private final PagesIndex pagesIndex;
        // largest group count reported by any input so far
        private long groupCount;

        /**
         * @param accumulator accumulator that receives the sorted pages
         * @param aggregationSourceTypes types of the input page channels (the group id type is added internally)
         * @param orderByChannels channels to sort by
         * @param orderings sort order for each ORDER BY channel
         * @param pagesIndexFactory factory for the buffering index
         * @throws NullPointerException if any argument is null
         */
        private OrderingGroupedAccumulator(
                GroupedAccumulator accumulator,
                List<Type> aggregationSourceTypes,
                List<Integer> orderByChannels,
                List<SortOrder> orderings,
                PagesIndex.Factory pagesIndexFactory)
        {
            this.accumulator = requireNonNull(accumulator, "accumulator is null");
            requireNonNull(aggregationSourceTypes, "aggregationSourceTypes is null");
            this.orderByChannels = ImmutableList.copyOf(requireNonNull(orderByChannels, "orderByChannels is null"));
            this.orderings = ImmutableList.copyOf(requireNonNull(orderings, "orderings is null"));
            // fail with a named argument instead of a message-less NPE on the dereference below
            requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
            List<Type> pageIndexTypes = new ArrayList<>(aggregationSourceTypes);
            // Add group id column
            pageIndexTypes.add(BIGINT);
            this.pagesIndex = pagesIndexFactory.newPagesIndex(pageIndexTypes, 10_000);
            this.groupCount = 0;
        }

        @Override
        public long getEstimatedSize()
        {
            return pagesIndex.getEstimatedSize().toBytes() + accumulator.getEstimatedSize();
        }

        @Override
        public Type getFinalType()
        {
            return accumulator.getFinalType();
        }

        @Override
        public void addInput(GroupByIdBlock groupIdsBlock, Page page)
        {
            groupCount = max(groupCount, groupIdsBlock.getGroupCount());
            // Buffer the page with the group id block appended as the last channel
            pagesIndex.addPage(page.appendColumn(groupIdsBlock));
        }

        @Override
        public void evaluateFinal(int groupId, BlockBuilder output)
        {
            accumulator.evaluateFinal(groupId, output);
        }

        @Override
        public void prepareFinal()
        {
            pagesIndex.sort(orderByChannels, orderings);
            Iterator<Page> pagesIterator = pagesIndex.getSortedPages();
            pagesIterator.forEachRemaining(page -> {
                // The last channel of the page is the group id
                GroupByIdBlock groupIds = new GroupByIdBlock(groupCount, page.getBlock(page.getChannelCount() - 1));
                // We pass group id together with the other input channels to accumulator. Accumulator knows which input channels
                // to use. Since we did not change the order of original input channels, passing the group id is safe.
                accumulator.addInput(groupIds, page);
            });
        }
    }

    /**
     * {@link SpillableFinalOnlyGroupedAccumulator} enables spilling for {@link FinalOnlyGroupedAccumulator}
     */
    private static class SpillableFinalOnlyGroupedAccumulator
            implements GroupedAccumulator
    {
        // Shallow size of this object; the base term of getEstimatedSize()
        private static final int INSTANCE_SIZE = ClassLayout.parseClass(SpillableFinalOnlyGroupedAccumulator.class).instanceSize();

        // Accumulator that receives the buffered input in prepareFinal() and produces the final values
        private final FinalOnlyGroupedAccumulator delegate;
        // Types of all source channels; used in prepareFinal() to rebuild full-width pages for the delegate
        private final List<Type> sourceTypes;
        // Types of the channels listed in aggregateInputChannels, in the same order
        private final List<Type> spillingTypes;
        // Source channels that are kept when input is buffered
        private final List<Integer> aggregateInputChannels;

        // Buffered (group ids, page) pairs; set to null once they are regrouped by evaluateIntermediate()
        // or handed to the delegate by prepareFinal()
        private ObjectBigArray<GroupIdPage> rawInputs = new ObjectBigArray<>();
        // Number of rows counted per group id by updateGroupIdCount(); set to null by evaluateIntermediate()
        private IntBigArray groupIdCount = new IntBigArray();
        // One builder per group id holding that group's rows; created on the first evaluateIntermediate() call
        private ObjectBigArray<RowBlockBuilder> blockBuilders;
        // Sum of the retained sizes of the entries in rawInputs
        private long rawInputsSizeInBytes;
        // Retained size accumulated while appending rows to the builders in blockBuilders
        private long blockBuildersSizeInBytes;
        // Number of entries stored in rawInputs
        private long rawInputsLength;
        // Spills blocks above the size threshold and reads them back in addIntermediate()
        private final StandaloneSpiller standaloneSpiller;
        // When true the intermediate type is a row of (file handle, squashed rows) instead of just squashed rows
        private final boolean isDistinctAggregationLargeBlockSpillEnabled;
        // Row blocks larger than this are written through standaloneSpiller in evaluateIntermediate()
        private final DataSize distinctAggregationLargeBlockSizeThreshold;

        public SpillableFinalOnlyGroupedAccumulator(
                List<Type> sourceTypes,
                List<Integer> aggregateInputChannels,
                FinalOnlyGroupedAccumulator delegate,
                StandaloneSpillerFactory standaloneSpillerFactory,
                Session session)
        {
            this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null");
            this.aggregateInputChannels = requireNonNull(aggregateInputChannels, "aggregateInputChannels is null");
            this.delegate = requireNonNull(delegate, "delegate is null");
            requireNonNull(standaloneSpillerFactory, "standaloneSpillerFactory is null");
            requireNonNull(session, "session is null");
            this.standaloneSpiller = standaloneSpillerFactory.create(session);
            this.isDistinctAggregationLargeBlockSpillEnabled = isDistinctAggregationLargeBlockSpillEnabled(session);
            this.distinctAggregationLargeBlockSizeThreshold = getDistinctAggregationLargeBlockSizeThreshold(session);
            this.spillingTypes = aggregateInputChannels.stream()
                    .map(sourceTypes::get)
                    .collect(toImmutableList());
        }

        /**
         * Estimates the memory held by this accumulator: the instance itself, the delegate, and whichever
         * of the raw input buffer, group id counters and per-group builders currently exist.
         */
        @Override
        public long getEstimatedSize()
        {
            long estimatedSize = INSTANCE_SIZE + delegate.getEstimatedSize();
            if (rawInputs != null) {
                estimatedSize += rawInputsSizeInBytes + rawInputs.sizeOf();
            }
            if (groupIdCount != null) {
                estimatedSize += groupIdCount.sizeOf();
            }
            if (blockBuilders != null) {
                estimatedSize += blockBuildersSizeInBytes + blockBuilders.sizeOf();
            }
            return estimatedSize;
        }

        /**
         * Returns the final type reported by the delegate.
         */
        @Override
        public Type getFinalType()
        {
            return delegate.getFinalType();
        }

        /**
         * Returns the type of the intermediate state: an array of rows, or, when large block spilling is
         * enabled, a row holding a file handle followed by that array of rows.
         */
        @Override
        public Type getIntermediateType()
        {
            if (!isDistinctAggregationLargeBlockSpillEnabled) {
                return getIntermediateRowsType();
            }

            // First field (VARCHAR): the file handle, if the rows were spilled
            // Second field (array): the squashed rows themselves, if they were not spilled
            List<Type> fieldTypes = ImmutableList.of(getIntermediateFileHandleType(), getIntermediateRowsType());
            return RowType.anonymous(fieldTypes);
        }

        // Type of the file handle field of the intermediate row
        private Type getIntermediateFileHandleType()
        {
            return VARCHAR;
        }

        // Type of the squashed rows: an array whose elements are anonymous rows of the spilling types
        private Type getIntermediateRowsType()
        {
            Type squashedRowType = RowType.anonymous(spillingTypes);
            return new ArrayType(squashedRowType);
        }

        /**
         * Buffers the channels of {@code page} that the aggregate consumes, together with their group ids,
         * and updates the per-group row counts.
         */
        @Override
        public void addInput(GroupByIdBlock groupIdsBlock, Page page)
        {
            checkState(rawInputs != null && blockBuilders == null);

            // Keep only the channels the aggregate reads; the blocks themselves are not copied
            int channelCount = aggregateInputChannels.size();
            Block[] aggregateBlocks = new Block[channelCount];
            for (int index = 0; index < channelCount; index++) {
                int sourceChannel = aggregateInputChannels.get(index);
                aggregateBlocks[index] = page.getBlock(sourceChannel);
            }

            addRawInput(groupIdsBlock, wrapBlocksWithoutCopy(page.getPositionCount(), aggregateBlocks));
            updateGroupIdCount(groupIdsBlock);
        }

        /**
         * Expands intermediate state, in which all rows of a group are squashed into a single entry, back
         * into one row per position and buffers the result as raw input. The group id of each entry is
         * repeated once for every row restored from it.
         *
         * @param groupIdsBlock group ids, one per position of {@code block}
         * @param block a {@link RowBlock} of (file handle, squashed rows) when large block spilling is
         * enabled, otherwise an {@link ArrayBlock} of squashed rows
         */
        @Override
        public void addIntermediate(GroupByIdBlock groupIdsBlock, Block block)
        {
            checkState(rawInputs != null && blockBuilders == null);
            List<Long> newGroupIdsList = new ArrayList<>();
            List<Boolean> nullsList = new ArrayList<>();
            int newPositionCount = 0;

            if (isDistinctAggregationLargeBlockSpillEnabled) {
                checkState(block instanceof RowBlock);
                RowBlock rowBlock = (RowBlock) block;
                PageBuilder pageBuilder = new PageBuilder(spillingTypes);

                for (int groupIdPosition = 0; groupIdPosition < groupIdsBlock.getPositionCount(); groupIdPosition++) {
                    SingleRowBlock singleRowBlock = (SingleRowBlock) rowBlock.getBlock(groupIdPosition);
                    long groupId = groupIdsBlock.getGroupId(groupIdPosition);
                    boolean groupIdIsNull = groupIdsBlock.isNull(groupIdPosition);

                    // Get serialized file handle which is stored in the first block
                    Block fileHandleBlock = singleRowBlock.getSingleValueBlock(0);
                    Slice fileHandleSlice = fileHandleBlock.getSlice(0, 0, fileHandleBlock.getSliceLength(0));

                    // A non-empty file handle means the rows were spilled. The check is done on the length
                    // rather than on the identity of the slice, so it does not depend on every empty slice
                    // being the same shared instance.
                    if (fileHandleSlice.length() > 0) {
                        // getBytes() yields exactly the bytes covered by the slice
                        SerializedStorageHandle serializedStorageHandle = new SerializedStorageHandle(fileHandleSlice.getBytes());
                        ImmutableList<Page> pages = ImmutableList.copyOf(standaloneSpiller.getSpilledPages(serializedStorageHandle));
                        standaloneSpiller.remove(serializedStorageHandle);
                        newPositionCount += pages.stream().mapToInt(Page::getPositionCount).sum();

                        for (Page page : pages) {
                            ColumnarRow columnarRow = toColumnarRow(page.getBlock(0));
                            for (int rowPosition = 0; rowPosition < columnarRow.getPositionCount(); rowPosition++) {
                                if (columnarRow.isNull(rowPosition)) {
                                    break;
                                }
                                newGroupIdsList.add(groupId);
                                nullsList.add(groupIdIsNull);
                            }
                            addToPageBuilder(pageBuilder, columnarRow);
                        }
                    }
                    else {
                        // File handle is empty. Read squashed rows directly which is stored in second block
                        Block arrayBlock = singleRowBlock.getSingleValueBlock(1);
                        ColumnarArray columnarArray = toColumnarArray(arrayBlock);
                        Block elementBlock = columnarArray.getElementsBlock();
                        ColumnarRow columnarRow = toColumnarRow(elementBlock);

                        newPositionCount += columnarRow.getNonNullPositionCount();
                        for (int rowPosition = 0; rowPosition < elementBlock.getPositionCount(); rowPosition++) {
                            if (elementBlock.isNull(rowPosition)) {
                                break;
                            }
                            newGroupIdsList.add(groupId);
                            nullsList.add(groupIdIsNull);
                        }

                        addToPageBuilder(pageBuilder, columnarRow);
                    }
                }

                GroupByIdBlock squashedGroupIds = new GroupByIdBlock(
                        groupIdsBlock.getGroupCount(),
                        new LongArrayBlock(newPositionCount, Optional.of(Booleans.toArray(nullsList)), Longs.toArray(newGroupIdsList)));
                addRawInput(squashedGroupIds, pageBuilder.build());
            }
            else {
                checkState(block instanceof ArrayBlock);
                // expand array block back into page
                ArrayBlock arrayBlock = (ArrayBlock) block;
                ColumnarArray columnarArray = toColumnarArray(block); // flattens the squashed arrays; so there is no need to flatten block again.
                ColumnarRow columnarRow = toColumnarRow(columnarArray.getElementsBlock()); // contains the flattened array
                newPositionCount = columnarRow.getNonNullPositionCount(); // number of positions in expanded array (since columnarRow is already flattened)
                for (int groupIdPosition = 0; groupIdPosition < groupIdsBlock.getPositionCount(); groupIdPosition++) {
                    // Fetch the squashed rows of this group once instead of once per row
                    Block squashedRows = arrayBlock.getBlock(groupIdPosition);
                    long groupId = groupIdsBlock.getGroupId(groupIdPosition);
                    boolean groupIdIsNull = groupIdsBlock.isNull(groupIdPosition);
                    for (int rowPosition = 0; rowPosition < squashedRows.getPositionCount(); rowPosition++) {
                        // every squashed row of the entry belongs to the same group id
                        if (squashedRows.isNull(rowPosition)) {
                            break;
                        }
                        newGroupIdsList.add(groupId);
                        nullsList.add(groupIdIsNull);
                    }
                }

                Block[] blocks = new Block[spillingTypes.size()];
                for (int channel = 0; channel < spillingTypes.size(); channel++) {
                    blocks[channel] = columnarRow.getField(channel);
                }

                Page page = new Page(blocks);
                GroupByIdBlock squashedGroupIds = new GroupByIdBlock(
                        groupIdsBlock.getGroupCount(),
                        new LongArrayBlock(newPositionCount, Optional.of(Booleans.toArray(nullsList)), Longs.toArray(newGroupIdsList)));
                addRawInput(squashedGroupIds, page);
            }
        }

        // Appends every position of columnarRow to pageBuilder, one spilling channel at a time
        private void addToPageBuilder(PageBuilder pageBuilder, ColumnarRow columnarRow)
        {
            int rowCount = columnarRow.getPositionCount();
            pageBuilder.declarePositions(rowCount);

            for (int channel = 0; channel < spillingTypes.size(); channel++) {
                Type channelType = spillingTypes.get(channel);
                Block fieldBlock = columnarRow.getField(channel);
                BlockBuilder channelBuilder = pageBuilder.getBlockBuilder(channel);
                for (int row = 0; row < rowCount; row++) {
                    channelType.appendTo(fieldBlock, row, channelBuilder);
                }
            }
        }

        /**
         * Writes the intermediate state of {@code groupId} to {@code output} as a single entry.
         * <p>
         * The first call regroups all buffered raw input into one row block builder per group id. When
         * large block spilling is enabled, row blocks above the configured threshold are written through
         * the standalone spiller and only their file handle is emitted; otherwise the rows are emitted
         * inline. A group without rows is emitted as null.
         *
         * @param groupId group to evaluate
         * @param output an {@link ArrayBlockBuilder} or {@link RowBlockBuilder} receiving the entry
         */
        @Override
        public void evaluateIntermediate(int groupId, BlockBuilder output)
        {
            checkState(output instanceof ArrayBlockBuilder || output instanceof RowBlockBuilder);
            if (blockBuilders == null) {
                buildBlockBuildersFromRawInputs();
            }

            BlockBuilder singleArrayBlockWriter = output.beginBlockEntry();
            checkState(rawInputs == null && blockBuilders != null);

            RowBlockBuilder groupBlockBuilder = groupId < blockBuilders.getCapacity() ? blockBuilders.get(groupId) : null;
            if (groupBlockBuilder == null) {
                // No rows for this groupId exist
                writeIntermediateRow(singleArrayBlockWriter, null, null);
            }
            else {
                // Captured before build() so it matches what was added to blockBuildersSizeInBytes
                long releasedSizeInBytes = groupBlockBuilder.getRetainedSizeInBytes();

                // We need to squash the entire page into one array block since we can't spill multiple values for a single group ID during evaluateIntermediate.
                RowBlock rowBlock = (RowBlock) groupBlockBuilder.build();
                if (isDistinctAggregationLargeBlockSpillEnabled) {
                    if (rowBlock.getSizeInBytes() > distinctAggregationLargeBlockSizeThreshold.toBytes()) {
                        Page page = new Page(rowBlock);
                        SerializedStorageHandle storageHandle = standaloneSpiller.spill(singletonIterator(page));
                        writeIntermediateRow(singleArrayBlockWriter, wrappedBuffer(storageHandle.getSerializedStorageHandle()), null);
                    }
                    else {
                        BlockBuilder intermediateRowBlockBuilder = getIntermediateRowsType().createBlockBuilder(null, rowBlock.getPositionCount());
                        BlockBuilder intermediateRowSingleArrayBlockWriter = intermediateRowBlockBuilder.beginBlockEntry();
                        for (int i = 0; i < rowBlock.getPositionCount(); i++) {
                            intermediateRowSingleArrayBlockWriter.appendStructure(rowBlock.getBlock(i));
                        }
                        intermediateRowBlockBuilder.closeEntry();
                        writeIntermediateRow(singleArrayBlockWriter, null, intermediateRowBlockBuilder.build().getBlock(0));
                    }
                }
                else {
                    for (int i = 0; i < rowBlock.getPositionCount(); i++) {
                        singleArrayBlockWriter.appendStructure(rowBlock.getBlock(i));
                    }
                }

                // We only call evaluateIntermediate when it is time to spill. We never call evaluate intermediate twice for the same groupId.
                // This means we can null our reference to the groupId's corresponding blockBuilder to reduce memory usage.
                // The size accounting is reduced as well so that getEstimatedSize() stops reporting the released builder.
                blockBuilders.set(groupId, null);
                blockBuildersSizeInBytes -= releasedSizeInBytes;
            }
            output.closeEntry();
        }

        // Moves every buffered raw input row into the RowBlockBuilder of its group id, then releases the
        // raw input buffer and the per-group row counts.
        private void buildBlockBuildersFromRawInputs()
        {
            checkState(rawInputs != null);
            blockBuilders = new ObjectBigArray<>();
            for (int i = 0; i < rawInputsLength; i++) {
                GroupIdPage groupIdPage = rawInputs.get(i);
                Page page = groupIdPage.getPage();
                GroupByIdBlock groupIdsBlock = groupIdPage.getGroupByIdBlock();
                for (int position = 0; position < page.getPositionCount(); position++) {
                    long currentGroupId = groupIdsBlock.getGroupId(position);
                    blockBuilders.ensureCapacity(currentGroupId);

                    RowBlockBuilder rowBlockBuilder = blockBuilders.get(currentGroupId);
                    long sizeBeforeAppendInBytes = 0;
                    if (rowBlockBuilder == null) {
                        // Rows buffered through addIntermediate are not counted by updateGroupIdCount, so
                        // make sure the lookup of the expected entry count stays within the array
                        groupIdCount.ensureCapacity(currentGroupId);
                        rowBlockBuilder = new RowBlockBuilder(spillingTypes, null, groupIdCount.get(currentGroupId));
                        blockBuilders.set(currentGroupId, rowBlockBuilder);
                    }
                    else {
                        sizeBeforeAppendInBytes = rowBlockBuilder.getRetainedSizeInBytes();
                    }

                    BlockBuilder currentOutput = rowBlockBuilder.beginBlockEntry();
                    for (int channel = 0; channel < spillingTypes.size(); channel++) {
                        spillingTypes.get(channel).appendTo(page.getBlock(channel), position, currentOutput);
                    }
                    rowBlockBuilder.closeEntry();

                    blockBuildersSizeInBytes += (rowBlockBuilder.getRetainedSizeInBytes() - sizeBeforeAppendInBytes);
                }
                // Drop the page as soon as its rows have been copied
                rawInputs.set(i, null);
            }
            groupIdCount = null;
            rawInputs = null;
            rawInputsSizeInBytes = 0;
            rawInputsLength = 0;
        }

        // Writes one intermediate entry. When large block spilling is enabled the file handle field is
        // written first (null when absent); the squashed rows are always written (null when absent).
        private void writeIntermediateRow(BlockBuilder singleArrayBlockWriter, Slice fileHandle, Block squashedBlock)
        {
            if (isDistinctAggregationLargeBlockSpillEnabled) {
                if (fileHandle != null) {
                    VARCHAR.writeSlice(singleArrayBlockWriter, fileHandle);
                }
                else {
                    singleArrayBlockWriter.appendNull();
                }
            }

            if (squashedBlock != null) {
                singleArrayBlockWriter.appendStructure(squashedBlock);
            }
            else {
                singleArrayBlockWriter.appendNull();
            }
        }

        /**
         * Writes the final value for {@code groupId} by delegating to the wrapped accumulator.
         * Requires that neither raw inputs nor per-group builders are held any more; {@link #prepareFinal()}
         * releases the raw inputs after handing them to the delegate.
         */
        @Override
        public void evaluateFinal(int groupId, BlockBuilder output)
        {
            checkState(rawInputs == null && blockBuilders == null);
            delegate.evaluateFinal(groupId, output);
        }

        /**
         * Hands all buffered raw input to the delegate, releases the buffer, and lets the delegate
         * prepare its final result.
         */
        @Override
        public void prepareFinal()
        {
            checkState(rawInputs != null && blockBuilders == null);

            // The mapping from source channel to its position in the buffered page is the same for every
            // page, so resolve it once instead of scanning aggregateInputChannels for each channel of each page.
            // A value of -1 marks a channel that was not buffered.
            int sourceChannelCount = sourceTypes.size();
            int[] bufferedIndexBySourceChannel = new int[sourceChannelCount];
            for (int channel = 0; channel < sourceChannelCount; channel++) {
                bufferedIndexBySourceChannel[channel] = aggregateInputChannels.indexOf(channel);
            }

            for (int i = 0; i < rawInputsLength; i++) {
                GroupIdPage groupIdPage = rawInputs.get(i);
                // Before pushing the page to delegate, restore it back to its original structure
                // in terms of number of channels. Channels which are not consumed by the accumulator
                // will be replaced with null block
                Page page = groupIdPage.getPage();
                int positionCount = page.getPositionCount();
                Block[] blocks = new Block[sourceChannelCount];
                for (int channel = 0; channel < sourceChannelCount; channel++) {
                    int bufferedIndex = bufferedIndexBySourceChannel[channel];
                    if (bufferedIndex >= 0) {
                        blocks[channel] = page.getBlock(bufferedIndex);
                    }
                    else {
                        blocks[channel] = RunLengthEncodedBlock.create(sourceTypes.get(channel), null, positionCount);
                    }
                }
                delegate.addInput(groupIdPage.getGroupByIdBlock(), wrapBlocksWithoutCopy(positionCount, blocks));
            }

            rawInputs = null;
            rawInputsSizeInBytes = 0;
            rawInputsLength = 0;
            delegate.prepareFinal();
        }

        // Number of (group ids, page) pairs currently buffered as raw input
        protected long getRawInputsLength()
        {
            return rawInputsLength;
        }

        // Types of the buffered channels, i.e. the source types of aggregateInputChannels in order
        protected List<Type> getSpillingTypes()
        {
            return spillingTypes;
        }

        // Appends a (group ids, page) pair to the raw input buffer and adds its retained size to the accounting
        protected void addRawInput(GroupByIdBlock groupByIdBlock, Page page)
        {
            rawInputs.ensureCapacity(rawInputsLength);

            GroupIdPage bufferedInput = new GroupIdPage(groupByIdBlock, page);
            rawInputsSizeInBytes += bufferedInput.getRetainedSizeInBytes();
            rawInputs.set(rawInputsLength++, bufferedInput);
        }

        // Adds the rows of groupIdsBlock to the running per-group row counts
        protected void updateGroupIdCount(GroupByIdBlock groupIdsBlock)
        {
            // groupIdCount is indexed by group id and holds the number of rows seen for that group. The count
            // is used later as the expected size of the group's RowBlock. For example, after the calls
            // groupIdsBlock = [0, 1, 0] and groupIdsBlock = [2, 1, 0] the counts are [3, 2, 1]:
            // three rows for group 0, two rows for group 1 and one row for group 2.
            int positionCount = groupIdsBlock.getPositionCount();
            for (int position = 0; position < positionCount; position++) {
                long groupId = groupIdsBlock.getGroupId(position);
                groupIdCount.ensureCapacity(groupId);
                groupIdCount.increment(groupId);
            }
        }
    }

    /**
     * SpillableAccumulator to perform deduplication of input data for distinct aggregates only.
     * <p>
     * Input is first preprocessed by the {@link DistinctingGroupedAccumulator} delegate. The delegate's
     * distinct pages are moved into the raw input buffer of the parent class right before intermediate
     * or final evaluation.
     */
    private static class DedupBasedSpillableDistinctGroupedAccumulator
            extends SpillableFinalOnlyGroupedAccumulator
    {
        private final DistinctingGroupedAccumulator delegate;
        // Mask channel, or -1 when there is none
        private final int maskChannel;

        // Largest group count seen by addInput; used to rebuild GroupByIdBlocks for the distinct pages
        private long groupCount;

        public DedupBasedSpillableDistinctGroupedAccumulator(
                List<Type> sourceTypes,
                List<Integer> aggregateInputChannels,
                DistinctingGroupedAccumulator delegate,
                Optional<Integer> maskChannel,
                StandaloneSpillerFactory standaloneSpillerFactory,
                Session session)
        {
            super(sourceTypes, aggregateInputChannels, delegate, standaloneSpillerFactory, session);
            this.delegate = requireNonNull(delegate, "delegate is null");
            this.maskChannel = requireNonNull(maskChannel, "maskChannel is null").orElse(-1);
        }

        /**
         * Passes the input to the delegate for preprocessing and counts the rows of the group id block
         * it returns. Nothing is added to the raw input buffer at this point.
         */
        @Override
        public void addInput(GroupByIdBlock groupIdsBlock, Page page)
        {
            groupCount = max(groupCount, groupIdsBlock.getGroupCount());
            updateGroupIdCount(delegate.preprocessInput(groupIdsBlock, page));
        }

        /**
         * Moves the delegate's distinct pages into the raw input buffer, resets the delegate, and then
         * evaluates the intermediate state through the parent class.
         */
        @Override
        public void evaluateIntermediate(int groupId, BlockBuilder output)
        {
            addRawInputs(delegate.getDistinctPages());
            delegate.reset();
            super.evaluateIntermediate(groupId, output);
        }

        /**
         * Moves the delegate's distinct pages into the raw input buffer, resets the delegate, and then
         * prepares the final result through the parent class.
         */
        @Override
        public void prepareFinal()
        {
            addRawInputs(delegate.getDistinctPages());
            delegate.reset();
            if (getRawInputsLength() == 0) {
                // This means that all rows were filtered out during preprocessing
                // when filtering was applied based on maskChannel. Delegate's accumulator
                // expects to receive some input pages in order to initialize its internal
                // state properly. Due to this, we need to push at least 1 empty page to the
                // underlying delegate's accumulator.
                //
                // The empty page already has exactly the spilling channels and carries no group id
                // column, so it is buffered directly instead of going through addRawInputs, which
                // would treat its first channel as group ids and drop it.
                GroupByIdBlock emptyGroupIds = new GroupByIdBlock(groupCount, new LongArrayBlock(0, Optional.empty(), new long[0]));
                addRawInput(emptyGroupIds, new PageBuilder(getSpillingTypes()).build());
            }
            super.prepareFinal();
        }

        // Converts distinct pages, whose channel 0 holds the group ids, into raw inputs of the parent class
        private void addRawInputs(List<Page> inputPages)
        {
            for (Page inputPage : inputPages) {
                // Channel 0 is groupId column of type BIGINT
                Block groupIdBlock = inputPage.getBlock(0);

                // Drop GroupId column
                Page aggregateInputPage = inputPage.dropColumn(0);

                // If maskChannel is present, appends the corresponding block
                if (maskChannel >= 0) {
                    // Filtering based on masked channel is already applied during preprocessing
                    // So, we can just create a block for maskChannel where all rows will pass
                    // maskChannel filter in delegate's addInput() method. This will make the
                    // delegate maskChannel filter check a no-op
                    aggregateInputPage = aggregateInputPage.appendColumn(RunLengthEncodedBlock.create(BOOLEAN, true, aggregateInputPage.getPositionCount()));
                }
                GroupByIdBlock groupByIdBlock = new GroupByIdBlock(groupCount, groupIdBlock);
                addRawInput(groupByIdBlock, aggregateInputPage);
            }
        }
    }

    /**
     * Immutable pair of a page and the group ids of its rows.
     */
    private static class GroupIdPage
    {
        private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupIdPage.class).instanceSize();

        private final GroupByIdBlock groupByIdBlock;
        private final Page page;

        public GroupIdPage(GroupByIdBlock groupByIdBlock, Page page)
        {
            // Validate both arguments before storing either of them
            requireNonNull(page, "page is null");
            requireNonNull(groupByIdBlock, "groupByIdBlock is null");
            this.groupByIdBlock = groupByIdBlock;
            this.page = page;
        }

        public Page getPage()
        {
            return page;
        }

        public GroupByIdBlock getGroupByIdBlock()
        {
            return groupByIdBlock;
        }

        /**
         * Returns the size of this object plus the retained sizes of the group id block and the page.
         */
        public long getRetainedSizeInBytes()
        {
            long retainedSizeInBytes = INSTANCE_SIZE;
            retainedSizeInBytes += groupByIdBlock.getRetainedSizeInBytes();
            retainedSizeInBytes += page.getRetainedSizeInBytes();
            return retainedSizeInBytes;
        }
    }
}