ArraySortFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.NotSupportedException;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.google.common.primitives.Ints;

import java.lang.invoke.MethodHandle;
import java.util.AbstractList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.lang.Float.floatToIntBits;
import static java.lang.Float.intBitsToFloat;

/**
 * Implements the {@code array_sort} scalar function: returns the elements of an array in ascending
 * order, with null elements placed after all non-null elements.
 * <p>
 * A generic implementation handles any element type via {@link Type#compareTo}; specialized
 * overloads for {@code bigint}, {@code double} and {@code real} sort primitive arrays directly.
 */
@ScalarFunction("array_sort")
@Description("Sorts the given array in ascending order according to the natural ordering of its elements.")
public final class ArraySortFunction
{
    private ArraySortFunction() {}

    /**
     * Sorts an array of any element type in ascending order, placing nulls last.
     *
     * @param lessThanFunction the element type's LESS_THAN operator; currently unused
     * @param type the element type, used for comparison and for copying elements
     * @param block the array to sort
     * @return {@code block} itself when it has fewer than two elements or is already in sorted
     *     order; otherwise a newly built block holding the elements in sorted order
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} when comparing elements fails
     *     with {@code NotSupportedException} or a {@code NOT_SUPPORTED} error
     */
    @TypeParameter("E")
    @SqlType("array(E)")
    public static Block sort(
            @OperatorDependency(operator = LESS_THAN, argumentTypes = {"E", "E"}) MethodHandle lessThanFunction,
            @TypeParameter("E") Type type,
            @SqlType("array(E)") Block block)
    {
        int arrayLength = block.getPositionCount();
        if (arrayLength < 2) {
            return block;
        }

        // Only pay for per-comparison null checks when the block can actually contain nulls.
        Comparator<Integer> comparator = block.mayHaveNull()
                ? (left, right) -> compareNullsLast(type, block, left, right)
                : (left, right) -> compareElements(type, block, left, right);

        ListOfPositions listOfPositions = new ListOfPositions(arrayLength);
        listOfPositions.sort(comparator);

        List<Integer> sortedListOfPositions = listOfPositions.getSortedListOfPositions();
        if (sortedListOfPositions == listOfPositions) {
            // No position moved during the (stable) sort: the original array is already sorted.
            return block;
        }

        BlockBuilder blockBuilder = type.createBlockBuilder(null, arrayLength);
        for (int i = 0; i < arrayLength; i++) {
            type.appendTo(block, sortedListOfPositions.get(i), blockBuilder);
        }
        return blockBuilder.build();
    }

    /**
     * Compares two positions of {@code block}, ordering null elements after all non-null elements
     * and treating two nulls as equal.
     */
    private static int compareNullsLast(Type type, Block block, int left, int right)
    {
        boolean leftIsNull = block.isNull(left);
        boolean rightIsNull = block.isNull(right);
        if (leftIsNull || rightIsNull) {
            // true > false, so a null on the left sorts after a non-null on the right.
            return Boolean.compare(leftIsNull, rightIsNull);
        }
        return compareElements(type, block, left, right);
    }

    /**
     * Compares two non-null positions of {@code block} using the element type's ordering,
     * translating "comparison not supported" failures into a user-facing argument error.
     */
    private static int compareElements(Type type, Block block, int left, int right)
    {
        try {
            // TODO: Type#compareTo may be slow; the injected LESS_THAN operator (lessThanFunction) is not used yet.
            return type.compareTo(block, left, block, right);
        }
        catch (NotSupportedException e) {
            throw unsupportedComparison(e);
        }
        catch (PrestoException e) {
            // Compare error codes by value rather than by reference.
            if (NOT_SUPPORTED.toErrorCode().equals(e.getErrorCode())) {
                throw unsupportedComparison(e);
            }
            throw e;
        }
    }

    /** Builds the error reported when the array's elements cannot be compared. */
    private static PrestoException unsupportedComparison(Throwable cause)
    {
        return new PrestoException(INVALID_FUNCTION_ARGUMENT, "Array contains elements not supported for comparison", cause);
    }

    /**
     * Sorts a {@code bigint} array in ascending order, placing nulls last.
     *
     * @return {@code array} itself when it has fewer than two elements, otherwise a new block
     */
    @SqlType("array(bigint)")
    public static Block bigintSort(@SqlType("array(bigint)") Block array)
    {
        int arrayLength = array.getPositionCount();
        if (arrayLength < 2) {
            return array;
        }

        // Pack non-null values at the front of the buffer; nulls are only counted implicitly.
        boolean mayHaveNull = array.mayHaveNull();
        long[] values = new long[arrayLength];
        int nonNullCount = 0;
        for (int position = 0; position < arrayLength; position++) {
            if (!mayHaveNull || !array.isNull(position)) {
                values[nonNullCount++] = BIGINT.getLong(array, position);
            }
        }

        Arrays.sort(values, 0, nonNullCount);

        BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, arrayLength);
        for (int i = 0; i < nonNullCount; i++) {
            BIGINT.writeLong(blockBuilder, values[i]);
        }
        // Nulls last
        for (int i = nonNullCount; i < arrayLength; i++) {
            blockBuilder.appendNull();
        }
        return blockBuilder.build();
    }

    /**
     * Sorts a {@code double} array in ascending order, placing nulls last.
     *
     * @return {@code array} itself when it has fewer than two elements, otherwise a new block
     */
    @SqlType("array(double)")
    public static Block doubleSort(@SqlType("array(double)") Block array)
    {
        int arrayLength = array.getPositionCount();
        if (arrayLength < 2) {
            return array;
        }

        // Pack non-null values at the front of the buffer; nulls are only counted implicitly.
        boolean mayHaveNull = array.mayHaveNull();
        double[] values = new double[arrayLength];
        int nonNullCount = 0;
        for (int position = 0; position < arrayLength; position++) {
            if (!mayHaveNull || !array.isNull(position)) {
                values[nonNullCount++] = DOUBLE.getDouble(array, position);
            }
        }

        // Arrays.sort orders -0.0 before 0.0 and places NaN after every other value.
        Arrays.sort(values, 0, nonNullCount);

        BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, arrayLength);
        for (int i = 0; i < nonNullCount; i++) {
            DOUBLE.writeDouble(blockBuilder, values[i]);
        }
        // Nulls last
        for (int i = nonNullCount; i < arrayLength; i++) {
            blockBuilder.appendNull();
        }
        return blockBuilder.build();
    }

    /**
     * Sorts a {@code real} array in ascending order, placing nulls last.
     *
     * @return {@code array} itself when it has fewer than two elements, otherwise a new block
     */
    @SqlType("array(real)")
    public static Block floatSort(@SqlType("array(real)") Block array)
    {
        int arrayLength = array.getPositionCount();
        if (arrayLength < 2) {
            return array;
        }

        // Pack non-null values at the front of the buffer; nulls are only counted implicitly.
        // REAL values are read as their int bit pattern and decoded to float for sorting.
        boolean mayHaveNull = array.mayHaveNull();
        float[] values = new float[arrayLength];
        int nonNullCount = 0;
        for (int position = 0; position < arrayLength; position++) {
            if (!mayHaveNull || !array.isNull(position)) {
                values[nonNullCount++] = intBitsToFloat(array.getInt(position));
            }
        }

        // Arrays.sort orders -0.0f before 0.0f and places NaN after every other value.
        Arrays.sort(values, 0, nonNullCount);

        BlockBuilder blockBuilder = REAL.createBlockBuilder(null, arrayLength);
        for (int i = 0; i < nonNullCount; i++) {
            REAL.writeLong(blockBuilder, floatToIntBits(values[i]));
        }
        // Nulls last
        for (int i = nonNullCount; i < arrayLength; i++) {
            blockBuilder.appendNull();
        }
        return blockBuilder.build();
    }

    /**
     * A list that starts as the identity sequence {@code [0, 1, ..., size - 1]} and serves as the
     * target of {@link List#sort}. The default {@code List.sort} copies the elements out, sorts the
     * copy, and writes every element back through {@link #set}. A backing list is only
     * materialized on the first write that changes a slot. When no slot changes,
     * {@link #getSortedListOfPositions()} returns this list itself, letting callers detect that
     * the input was already sorted.
     */
    private static final class ListOfPositions
            extends AbstractList<Integer>
    {
        private final int size;
        // null until some slot is assigned a position different from its index
        private List<Integer> sortedListOfPositions;

        ListOfPositions(int size)
        {
            this.size = size;
        }

        @Override
        public int size()
        {
            return size;
        }

        @Override
        public Integer get(int index)
        {
            if (index < 0 || index >= size) {
                throw new IndexOutOfBoundsException("index: " + index + ", size: " + size);
            }
            return sortedListOfPositions == null ? index : sortedListOfPositions.get(index);
        }

        /**
         * Assigns {@code position} to slot {@code index} and returns the position previously held
         * there, as required by {@link List#set}.
         */
        @Override
        public Integer set(int index, Integer position)
        {
            Integer previous = get(index);
            if (sortedListOfPositions == null) {
                if (previous.equals(position)) {
                    // Still the identity permutation; nothing to record.
                    return previous;
                }
                // First out-of-order slot: materialize the identity permutation before updating it.
                int[] identity = new int[size];
                for (int i = 0; i < size; i++) {
                    identity[i] = i;
                }
                sortedListOfPositions = Ints.asList(identity);
            }
            sortedListOfPositions.set(index, position);
            return previous;
        }

        /**
         * Returns the sorted positions, or this list itself when no slot was ever reassigned.
         */
        List<Integer> getSortedListOfPositions()
        {
            return sortedListOfPositions == null ? this : sortedListOfPositions;
        }
    }
}