// TDigestFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.tdigest.Centroid;
import com.facebook.presto.tdigest.TDigest;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.EXPERIMENTAL;
import static com.facebook.presto.tdigest.TDigest.createTDigest;
import static com.facebook.presto.util.Failures.checkCondition;
import static java.lang.Math.toIntExact;

/**
 * Scalar SQL functions that operate on serialized {@code tdigest(double)} values.
 *
 * <p>Every function that accepts a digest deserializes it with {@code createTDigest} and then
 * delegates to the corresponding {@link TDigest} method. All functions are registered with
 * {@code EXPERIMENTAL} visibility.
 */
public final class TDigestFunctions
{
    /** Default compression factor constant; it is not referenced by any function in this class. */
    public static final double DEFAULT_COMPRESSION = 100;

    /**
     * Row layout produced by {@code destructure_tdigest}. The field order here must match the order
     * in which {@link #destructureTDigest(Slice)} writes its values.
     */
    @VisibleForTesting
    static final RowType TDIGEST_CENTROIDS_ROW_TYPE = RowType.from(
            ImmutableList.of(
                    RowType.field("centroid_means", new ArrayType(DOUBLE)),
                    RowType.field("centroid_weights", new ArrayType(INTEGER)),
                    RowType.field("compression", DOUBLE),
                    RowType.field("min", DOUBLE),
                    RowType.field("max", DOUBLE),
                    RowType.field("sum", DOUBLE),
                    RowType.field("count", BIGINT)));

    // Static utility holder; never instantiated.
    private TDigestFunctions() {}

    /**
     * Returns the result of {@link TDigest#getQuantile} for a single quantile.
     *
     * @param input serialized t-digest
     * @param quantile quantile to look up; it is passed through unchecked, so any range validation
     *         is whatever {@code TDigest.getQuantile} performs
     * @return the value reported by the digest for {@code quantile}
     */
    @ScalarFunction(value = "value_at_quantile", visibility = EXPERIMENTAL)
    @Description("Given an input q between [0, 1], find the value whose rank in the sorted sequence of the n values represented by the tdigest is qn.")
    @SqlType(StandardTypes.DOUBLE)
    public static double valueAtQuantileDouble(@SqlType("tdigest(double)") Slice input, @SqlType(StandardTypes.DOUBLE) double quantile)
    {
        return createTDigest(input).getQuantile(quantile);
    }

    /**
     * Returns the result of {@link TDigest#getQuantile} for every element of an array of quantiles.
     * The digest is deserialized once and reused for all lookups.
     *
     * @param input serialized t-digest
     * @param percentilesArrayBlock quantiles to look up; a null element fails the call with
     *         {@code INVALID_FUNCTION_ARGUMENT}
     * @return a block of doubles with one result per input position, in input order
     */
    @ScalarFunction(value = "values_at_quantiles", visibility = EXPERIMENTAL)
    @Description("For each input q between [0, 1], find the value whose rank in the sorted sequence of the n values represented by the tdigest is qn.")
    @SqlType("array(double)")
    public static Block valuesAtQuantilesDouble(@SqlType("tdigest(double)") Slice input, @SqlType("array(double)") Block percentilesArrayBlock)
    {
        TDigest tDigest = createTDigest(input);
        int positionCount = percentilesArrayBlock.getPositionCount();
        BlockBuilder output = DOUBLE.createBlockBuilder(null, positionCount);
        for (int i = 0; i < positionCount; i++) {
            checkCondition(!percentilesArrayBlock.isNull(i), INVALID_FUNCTION_ARGUMENT, "All quantiles should be non-null.");
            DOUBLE.writeDouble(output, tDigest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i)));
        }
        return output.build();
    }

    /**
     * Returns the result of {@link TDigest#getCdf} for a single value.
     *
     * @param input serialized t-digest
     * @param value value to look up
     * @return the quantile reported by the digest for {@code value}
     */
    @ScalarFunction(value = "quantile_at_value", visibility = EXPERIMENTAL)
    @Description("Given an input x between min/max values of t-digest, find which quantile is represented by that value")
    @SqlType(StandardTypes.DOUBLE)
    public static double quantileAtValueDouble(@SqlType("tdigest(double)") Slice input, @SqlType(StandardTypes.DOUBLE) double value)
    {
        return createTDigest(input).getCdf(value);
    }

    /**
     * Returns the result of {@link TDigest#getCdf} for every element of an array of values.
     * The digest is deserialized once and reused for all lookups.
     *
     * @param input serialized t-digest
     * @param valuesArrayBlock values to look up; a null element fails the call with
     *         {@code INVALID_FUNCTION_ARGUMENT}
     * @return a block of doubles with one result per input position, in input order
     */
    @ScalarFunction(value = "quantiles_at_values", visibility = EXPERIMENTAL)
    @Description("For each input x between min/max values of t-digest, find which quantile is represented by that value")
    @SqlType("array(double)")
    public static Block quantilesAtValuesDouble(@SqlType("tdigest(double)") Slice input, @SqlType("array(double)") Block valuesArrayBlock)
    {
        TDigest tDigest = createTDigest(input);
        int positionCount = valuesArrayBlock.getPositionCount();
        BlockBuilder output = DOUBLE.createBlockBuilder(null, positionCount);
        for (int i = 0; i < positionCount; i++) {
            checkCondition(!valuesArrayBlock.isNull(i), INVALID_FUNCTION_ARGUMENT, "All values should be non-null.");
            DOUBLE.writeDouble(output, tDigest.getCdf(DOUBLE.getDouble(valuesArrayBlock, i)));
        }
        return output.build();
    }

    /**
     * Applies {@link TDigest#scale} to a digest and returns the re-serialized result.
     *
     * @param input serialized t-digest
     * @param scale scale factor; must be strictly positive (NaN is rejected as well, because the
     *         {@code scale > 0} comparison is false for NaN)
     * @return the serialized, scaled digest
     */
    @ScalarFunction(value = "scale_tdigest", visibility = EXPERIMENTAL)
    @Description("Scale a t-digest according to a new weight")
    @SqlType("tdigest(double)")
    public static Slice scaleTDigestDouble(@SqlType("tdigest(double)") Slice input, @SqlType(StandardTypes.DOUBLE) double scale)
    {
        checkCondition(scale > 0, INVALID_FUNCTION_ARGUMENT, "Scale factor should be positive.");
        TDigest digest = createTDigest(input);
        digest.scale(scale);
        return digest.serialize();
    }

    /**
     * Breaks a digest into a single row shaped like {@link #TDIGEST_CENTROIDS_ROW_TYPE}: centroid
     * means, centroid weights, compression factor, min, max, sum and count.
     *
     * @param input serialized t-digest
     * @return the row block holding the digest's state
     */
    @ScalarFunction(value = "destructure_tdigest", visibility = EXPERIMENTAL)
    @Description("Return the raw TDigest, including arrays of centroid means and weights, as well as min, max, sum, count, and compression factor.")
    @SqlType("row(centroid_means array(double), centroid_weights array(integer), compression double, min double, max double, sum double, count bigint)")
    public static Block destructureTDigest(@SqlType("tdigest(double)") Slice input)
    {
        TDigest tDigest = createTDigest(input);

        BlockBuilder blockBuilder = TDIGEST_CENTROIDS_ROW_TYPE.createBlockBuilder(null, 1);
        BlockBuilder rowBuilder = blockBuilder.beginBlockEntry();

        // Centroid means / weights
        BlockBuilder meansBuilder = DOUBLE.createBlockBuilder(null, tDigest.centroidCount());
        BlockBuilder weightsBuilder = INTEGER.createBlockBuilder(null, tDigest.centroidCount());
        for (Centroid centroid : tDigest.centroids()) {
            // The row type declares weights as array(integer), so the weight is narrowed with an
            // (int) cast. NOTE(review): if the weight is a double, this cast drops any fractional
            // part and saturates at the int range; kept as-is because it is observable output.
            int weight = (int) centroid.getWeight();
            DOUBLE.writeDouble(meansBuilder, centroid.getMean());
            INTEGER.writeLong(weightsBuilder, weight);
        }
        rowBuilder.appendStructure(meansBuilder);
        rowBuilder.appendStructure(weightsBuilder);

        // Compression, min, max, sum, count -- written in the field order of the row type
        DOUBLE.writeDouble(rowBuilder, tDigest.getCompressionFactor());
        DOUBLE.writeDouble(rowBuilder, tDigest.getMin());
        DOUBLE.writeDouble(rowBuilder, tDigest.getMax());
        DOUBLE.writeDouble(rowBuilder, tDigest.getSum());
        BIGINT.writeLong(rowBuilder, (long) tDigest.getSize());

        blockBuilder.closeEntry();
        // Only one row was written; hand back that last (and only) entry.
        return TDIGEST_CENTROIDS_ROW_TYPE.getObject(blockBuilder, blockBuilder.getPositionCount() - 1);
    }

    /**
     * Returns the result of {@link TDigest#trimmedMean} for the given quantile bounds.
     *
     * @param input serialized t-digest
     * @param lowerQuantileBound lower bound; must be in [0, 1]
     * @param upperQuantileBound upper bound; must be in [0, 1]
     * @return the trimmed mean reported by the digest
     */
    @ScalarFunction(value = "trimmed_mean", visibility = EXPERIMENTAL)
    @Description("Returns an estimate of the mean, excluding portions of the distribution outside the provided quantile bounds.")
    @SqlType("double")
    public static double trimmedMeanTDigestDouble(@SqlType("tdigest(double)") Slice input, @SqlType(StandardTypes.DOUBLE) double lowerQuantileBound, @SqlType(StandardTypes.DOUBLE) double upperQuantileBound)
    {
        checkCondition(lowerQuantileBound >= 0 && lowerQuantileBound <= 1, INVALID_FUNCTION_ARGUMENT, "Lower quantile bound should be in [0,1].");
        checkCondition(upperQuantileBound >= 0 && upperQuantileBound <= 1, INVALID_FUNCTION_ARGUMENT, "Upper quantile bound should be in [0,1].");
        TDigest digest = createTDigest(input);
        return digest.trimmedMean(lowerQuantileBound, upperQuantileBound);
    }

    /**
     * Builds a digest from explicitly supplied internal state and returns it serialized.
     *
     * <p>The two centroid arrays describe the same centroids position by position, so they must
     * have the same length and contain no null elements; either violation fails the call with
     * {@code INVALID_FUNCTION_ARGUMENT} instead of silently building a digest from misaligned or
     * placeholder values.
     *
     * @param centroidMeansBlock centroid means
     * @param centroidWeightsBlock centroid weights, aligned with {@code centroidMeansBlock}
     * @param compression compression factor passed to {@code createTDigest}
     * @param min minimum passed to {@code createTDigest}
     * @param max maximum passed to {@code createTDigest}
     * @param sum sum passed to {@code createTDigest}
     * @param count count passed to {@code createTDigest} after narrowing to {@code int}
     * @return the serialized digest
     * @throws ArithmeticException if {@code count} does not fit in an {@code int}
     */
    @ScalarFunction(value = "construct_tdigest", visibility = EXPERIMENTAL)
    @Description("Create a TDigest by passing in its internal state.")
    @SqlType("tdigest(double)")
    public static Slice constructTDigest(
            @SqlType("array(double)") Block centroidMeansBlock,
            @SqlType("array(double)") Block centroidWeightsBlock,
            @SqlType(StandardTypes.DOUBLE) double compression,
            @SqlType(StandardTypes.DOUBLE) double min,
            @SqlType(StandardTypes.DOUBLE) double max,
            @SqlType(StandardTypes.DOUBLE) double sum,
            @SqlType(StandardTypes.BIGINT) long count)
    {
        checkCondition(
                centroidMeansBlock.getPositionCount() == centroidWeightsBlock.getPositionCount(),
                INVALID_FUNCTION_ARGUMENT,
                "Centroid means and centroid weights should have the same number of elements.");

        double[] centroidMeans = readNonNullDoubles(centroidMeansBlock, "All centroid means should be non-null.");
        double[] centroidWeights = readNonNullDoubles(centroidWeightsBlock, "All centroid weights should be non-null.");

        TDigest tDigest = createTDigest(
                centroidMeans,
                centroidWeights,
                compression,
                min,
                max,
                sum,
                toIntExact(count));

        return tDigest.serialize();
    }

    /**
     * Merges every non-null digest of an array into one digest.
     *
     * <p>The first non-null element is deserialized and used as the accumulator; each later
     * non-null element is merged into it with {@link TDigest#merge}.
     *
     * @param input array of serialized digests; null elements are skipped
     * @return the serialized merged digest, or {@code null} when the array is empty or contains
     *         only nulls
     */
    @ScalarFunction(value = "merge_tdigest", visibility = EXPERIMENTAL)
    @Description("Merge an array of TDigests into a single TDigest")
    @SqlType("tdigest(double)")
    @SqlNullable
    public static Slice merge_tdigest(@SqlType("array(tdigest(double))") Block input)
    {
        if (input.getPositionCount() == 0) {
            return null;
        }
        TDigest output = null;
        for (int i = 0; i < input.getPositionCount(); i++) {
            if (input.isNull(i)) {
                continue;
            }
            TDigest tdigest = createTDigest(input.getSlice(i, 0, input.getSliceLength(i)));
            if (output == null) {
                output = tdigest;
            }
            else {
                output.merge(tdigest);
            }
        }
        return output == null ? null : output.serialize();
    }

    /**
     * Copies a block of doubles into a primitive array, failing with
     * {@code INVALID_FUNCTION_ARGUMENT} and the given message on the first null position.
     */
    private static double[] readNonNullDoubles(Block block, String nullElementMessage)
    {
        int positionCount = block.getPositionCount();
        double[] values = new double[positionCount];
        for (int i = 0; i < positionCount; i++) {
            checkCondition(!block.isNull(i), INVALID_FUNCTION_ARGUMENT, nullElementMessage);
            values[i] = DOUBLE.getDouble(block, i);
        }
        return values;
    }
}