ArraySqlFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar.sql;

import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.SqlInvokedScalarFunction;
import com.facebook.presto.spi.function.SqlParameter;
import com.facebook.presto.spi.function.SqlParameters;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;

public class ArraySqlFunctions
{
    private ArraySqlFunctions() {}

    @SqlInvokedScalarFunction(value = "array_average", deterministic = true, calledOnNullInput = false)
    @Description("Returns the average of all array elements, or null if the array is empty. Ignores null elements.")
    @SqlParameter(name = "input", type = "array<double>")
    @SqlType("double")
    public static String arrayAverage()
    {
        return "RETURN reduce(" +
                "input, " +
                "(double '0.0', 0), " +
                "(s, x) -> IF(x IS NOT NULL, (s[1] + x, s[2] + 1), s), " +
                "s -> if(s[2] = 0, cast(null as double), s[1] / cast(s[2] as double)))";
    }

    @SqlInvokedScalarFunction(value = "array_split_into_chunks", deterministic = true, calledOnNullInput = false)
    @Description("Returns an array of arrays splitting input array into chunks of given length. " +
            "If array is not evenly divisible it will split into as many possible chunks and " +
            "return the left over elements for the last array. Returns null for null inputs, but not elements.")
    @TypeParameter("T")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "sz", type = "int")})
    @SqlType("array(array(T))")
    public static String arraySplitIntoChunks()
    {
        return "RETURN IF(sz <= 0, " +
                "fail('Invalid slice size: ' || cast(sz as varchar) || '. Size must be greater than zero.'), " +
                "IF(cardinality(input) / sz > 10000, " +
                "fail('Cannot split array of size: ' || cast(cardinality(input) as varchar) || ' into more than 10000 parts.'), " +
                "transform(" +
                "sequence(1, cardinality(input), sz), " +
                "x -> slice(input, x, sz))))";
    }

    @SqlInvokedScalarFunction(value = "array_frequency", deterministic = true, calledOnNullInput = false)
    @Description("Returns the frequency of all array elements as a map.")
    @TypeParameter("T")
    @SqlParameter(name = "input", type = "array(T)")
    @SqlType("map(T, int)")
    public static String arrayFrequency()
    {
        return "RETURN reduce(" +
                "input," +
                "MAP()," +
                "(m, x) -> IF (x IS NOT NULL, MAP_CONCAT(m,MAP_FROM_ENTRIES(ARRAY[ROW(x, COALESCE(ELEMENT_AT(m,x) + 1, 1))])), m)," +
                "m -> m)";
    }

    @SqlInvokedScalarFunction(value = "array_duplicates", deterministic = true, calledOnNullInput = false)
    @Description("Returns set of elements that have duplicates")
    @SqlParameter(name = "input", type = "array(T)")
    @TypeParameter("T")
    @SqlType("array(T)")
    public static String arrayDuplicates()
    {
        return "RETURN CONCAT(" +
                "IF (cardinality(filter(input, x -> x is NULL)) > 1, array[element_at(input, find_first_index(input, x -> x IS NULL))], array[])," +
                "map_keys(map_filter(array_frequency(input), (k, v) -> v > 1)))";
    }

    @SqlInvokedScalarFunction(value = "array_has_duplicates", deterministic = true, calledOnNullInput = false)
    @Description("Returns whether array has any duplicate element")
    @TypeParameter("T")
    @SqlParameter(name = "input", type = "array(T)")
    @SqlType("boolean")
    public static String arrayHasDuplicatesVarchar()
    {
        return "RETURN cardinality(array_duplicates(input)) > 0";
    }

    @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
    @Description("Determines the least frequent element in the array. If there are multiple elements, the function returns the smallest element")
    @TypeParameter("T")
    @SqlParameter(name = "input", type = "array(T)")
    @SqlType("array<T>")
    public static String arrayLeastFrequent()
    {
        return "RETURN IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, 1), x -> x[2]))";
    }

    @SqlInvokedScalarFunction(value = "array_least_frequent", deterministic = true, calledOnNullInput = true)
    @Description("Determines the n least frequent element in the array in the ascending order of the elements.")
    @TypeParameter("T")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "bigint")})
    @SqlType("array<T>")
    public static String arrayNLeastFrequent()
    {
        return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), IF(COALESCE(CARDINALITY(REMOVE_NULLS(input)), 0) = 0, NULL, TRANSFORM(SLICE(ARRAY_SORT(TRANSFORM(MAP_ENTRIES(ARRAY_FREQUENCY(REMOVE_NULLS(input))), x -> ROW(x[2], x[1]))), 1, n), x -> x[2])))";
    }

    @SqlInvokedScalarFunction(value = "array_max_by", deterministic = true, calledOnNullInput = true)
    @Description("Get the maximum value of array, by using a specific transformation function")
    @TypeParameter("T")
    @TypeParameter("U")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")})
    @SqlType("T")
    public static String arrayMaxBy()
    {
        return "RETURN input[" +
                "array_max(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" +
                "]";
    }

    @SqlInvokedScalarFunction(value = "array_min_by", deterministic = true, calledOnNullInput = true)
    @Description("Get the minimum value of array, by using a specific transformation function")
    @TypeParameter("T")
    @TypeParameter("U")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "f", type = "function(T, U)")})
    @SqlType("T")
    public static String arrayMinBy()
    {
        return "RETURN input[" +
                "array_min(zip_with(transform(input, f), sequence(1, cardinality(input)), (x, y)->IF(x IS NULL, NULL, (x, y))))[2]" +
                "]";
    }

    @SqlInvokedScalarFunction(value = "array_sort_desc", deterministic = true, calledOnNullInput = true)
    @Description("Sorts the given array in descending order according to the natural ordering of its elements.")
    @TypeParameter("T")
    @SqlParameter(name = "input", type = "array(T)")
    @SqlType("array<T>")
    public static String arraySortDesc()
    {
        return "RETURN reverse(array_sort(remove_nulls(input))) || filter(input, x -> x is null)";
    }

    @SqlInvokedScalarFunction(value = "remove_nulls", deterministic = true, calledOnNullInput = true)
    @Description("Removes null values from an array.")
    @TypeParameter("T")
    @SqlParameter(name = "input", type = "array(T)")
    @SqlType("array<T>")
    public static String removeNulls()
    {
        return "RETURN IF(none_match(input, x -> x is null), input, filter(input, x -> x is not null))";
    }

    @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true)
    @Description("Returns top N elements of a given array, using natural descending order.")
    @TypeParameter("T")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int")})
    @SqlType("array<T>")
    public static String arrayTopN()
    { return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(ARRAY_SORT_DESC(input), 1, n))"; }

    @SqlInvokedScalarFunction(value = "array_top_n", deterministic = true, calledOnNullInput = true)
    @Description("Returns the top N values of the given map sorted using the provided lambda comparator.")
    @TypeParameter("T")
    @SqlParameters({@SqlParameter(name = "input", type = "array(T)"), @SqlParameter(name = "n", type = "int"), @SqlParameter(name = "f", type = "function(T, T, int)")})
    @SqlType("array<T>")
    public static String arrayTopNComparator()
    {
        return "RETURN IF(n < 0, fail('Parameter n: ' || cast(n as varchar) || ' to ARRAY_TOP_N is negative'), SLICE(REVERSE(ARRAY_SORT(input, f)), 1, n))";
    }
}