/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeParameter;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.tdigest.TDigest;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.GeometricDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.IntStream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TDigestParametricType.TDIGEST;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.scalar.TDigestFunctions.TDIGEST_CENTROIDS_ROW_TYPE;
import static com.facebook.presto.tdigest.TDigest.createTDigest;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.wrappedBuffer;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static java.util.Collections.sort;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;

public class TestTDigestFunctions
        extends AbstractTestFunctions
{
    private static final int NUMBER_OF_ENTRIES = 1_000_000;
    private static final int STANDARD_COMPRESSION_FACTOR = 100;
    private static final double STANDARD_ERROR = 0.01;
    private static final double TRIMMED_MEAN_ERROR_IN_DEVIATIONS = 0.05;
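    // Quantiles probed by the tests; sampled more densely near the tails, where t-digest accuracy matters most.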
    private static final double[] quantiles = {0.0001, 0.0200, 0.0300, 0.0400, 0.0500, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000,
            0.9000, 0.9500, 0.9600, 0.9700, 0.9800, 0.9999};
    private static final Type TDIGEST_DOUBLE = TDIGEST.createType(ImmutableList.of(TypeParameter.of(DOUBLE)));

    private static final Joiner ARRAY_JOINER = Joiner.on(",");
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();

    @Test
    public void testNullTDigestGetValueAtQuantile()
    {
        functionAssertions.assertFunction("value_at_quantile(CAST(NULL AS tdigest(double)), 0.3)", DOUBLE, null);
    }

    @Test
    public void testNullTDigestGetQuantileAtValue()
    {
        functionAssertions.assertFunction("quantile_at_value(CAST(NULL AS tdigest(double)), 0.3)", DOUBLE, null);
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testGetValueAtQuantileOverOne()
    {
        functionAssertions.assertFunction(format("value_at_quantile(CAST(X'%s' AS tdigest(double)), 1.5)",
                new SqlVarbinary(createTDigest(STANDARD_COMPRESSION_FACTOR).serialize().getBytes()).toString().replaceAll("\\s+", " ")),
                DOUBLE,
                null);
    }

    @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "All quantiles should be non-null.")
    public void testValuesAtQuantilesWithNullsThrowsError()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        for (int i = 0; i < 100; i++) {
            tDigest.add(i);
        }

        functionAssertions.assertFunction(
                format("values_at_quantiles(%s, ARRAY[0.25, NULL, 0.75])",
                        toSqlString(tDigest)),
                new ArrayType(DOUBLE),
                null);
    }

    @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "All values should be non-null.")
    public void testQuantilesAtValuesWithNullsThrowsError()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        for (int i = 0; i < 100; i++) {
            tDigest.add(i);
        }

        functionAssertions.assertFunction(
                format("quantiles_at_values(%s, ARRAY[25.0, NULL, 75.0])",
                        toSqlString(tDigest)),
                new ArrayType(DOUBLE),
                null);
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testGetValueAtQuantileBelowZero()
    {
        functionAssertions.assertFunction(format("value_at_quantile(CAST(X'%s' AS tdigest(double)), -0.2)",
                new SqlVarbinary(createTDigest(STANDARD_COMPRESSION_FACTOR).serialize().getBytes()).toString().replaceAll("\\s+", " ")),
                DOUBLE,
                null);
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidSerializationFormat()
    {
        functionAssertions.assertFunction(format("value_at_quantile(CAST(X'%s' AS tdigest(double)), 0.5)",
                new SqlVarbinary(createTDigest(STANDARD_COMPRESSION_FACTOR).serialize().getBytes()).toString().substring(0, 80).replaceAll("\\s+", " ")),
                DOUBLE,
                null);
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testEmptySerialization()
    {
        functionAssertions.assertFunction(format("value_at_quantile(CAST(X'%s' AS tdigest(double)), 0.5)",
                new SqlVarbinary(new byte[0])),
                DOUBLE,
                null);
    }

    @Test
    public void testMergeTwoNormalDistributionsGetQuantile()
    {
        TDigest tDigest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        TDigest tDigest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(0, 50);

        for (int i = 0; i < NUMBER_OF_ENTRIES / 2; i++) {
            double value1 = normal.sample();
            double value2 = normal.sample();
            tDigest1.add(value1);
            tDigest2.add(value2);
            list.add(value1);
            list.add(value2);
        }

        tDigest1.merge(tDigest2);
        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertValueWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest1);
        }
    }

    @Test
    public void testGetQuantileAtValueOutsideRange()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = Math.random() * NUMBER_OF_ENTRIES;
            tDigest.add(value);
        }

        functionAssertions.assertFunction(
                format("quantile_at_value(%s, %s) = 1",
                        toSqlString(tDigest),
                        1_000_000_000d),
                BOOLEAN,
                true);

        functionAssertions.assertFunction(
                format("quantile_at_value(%s, %s) = 0",
                        toSqlString(tDigest),
                        -500d),
                BOOLEAN,
                true);
    }

    @Test
    public void testNormalDistributionHighVarianceValuesArray()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        NormalDistribution normal = new NormalDistribution(0, 1);
        List<Double> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = normal.sample();
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        double[] values = new double[quantiles.length];

        for (int i = 0; i < quantiles.length; i++) {
            values[i] = list.get((int) (quantiles[i] * NUMBER_OF_ENTRIES));
        }

        assertBlockValues(values, STANDARD_ERROR, tDigest);
    }

    @Test
    public void testAddElementsInOrderQuantileArray()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            tDigest.add(i);
            list.add((double) i);
        }

        sort(list);

        assertBlockQuantiles(quantiles, STANDARD_ERROR, list, tDigest);
    }

    @Test
    public void testNormalDistributionHighVarianceQuantileArray()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(0, 1);

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = normal.sample();
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        assertBlockQuantiles(quantiles, STANDARD_ERROR, list, tDigest);
    }

    @Test
    public void testAddElementsInOrder()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Integer> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            tDigest.add(i);
            list.add(i);
        }

        for (int i = 0; i < quantiles.length; i++) {
            assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testMergeTwoDistributionsWithoutOverlap()
    {
        TDigest tDigest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        TDigest tDigest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Integer> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES / 2; i++) {
            tDigest1.add(i);
            tDigest2.add(i + NUMBER_OF_ENTRIES / 2);
            list.add(i);
            list.add(i + NUMBER_OF_ENTRIES / 2);
        }

        tDigest1.merge(tDigest2);
        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest1);
        }
    }

    @Test
    public void testMergeTwoDistributionsWithOverlap()
    {
        TDigest tDigest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        TDigest tDigest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Integer> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES / 2; i++) {
            tDigest1.add(i);
            tDigest2.add(i);
            list.add(i);
            list.add(i);
        }

        tDigest2.merge(tDigest1);
        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest2);
        }
    }

    @Test
    public void testAddElementsRandomized()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = Math.random() * NUMBER_OF_ENTRIES;
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testNormalDistributionLowVariance()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(1000, 1);

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = normal.sample();
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testTrimmedMean()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR * 2);
        RealDistribution distribution = new UniformRealDistribution(0.0d, NUMBER_OF_ENTRIES);
        List<Double> list = new ArrayList<>();

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = distribution.sample();
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        List<Double> lowQuantiles = new ArrayList<>();
        List<Double> highQuantiles = new ArrayList<>();
        for (int i = 0; i < quantiles.length; i++) {
            for (int j = i + 1; j < quantiles.length; j++) {
                lowQuantiles.add(quantiles[i]);
                highQuantiles.add(quantiles[j]);
            }
        }
        assertTrimmedMeanValues(lowQuantiles, highQuantiles, Math.sqrt(distribution.getNumericalVariance()), TRIMMED_MEAN_ERROR_IN_DEVIATIONS, list, tDigest);
    }

    @Test
    public void testNormalDistributionHighVariance()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(0, 1);

        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            double value = normal.sample();
            tDigest.add(value);
            list.add(value);
        }

        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testMergeTwoNormalDistributions()
    {
        TDigest tDigest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        TDigest tDigest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(0, 50);

        for (int i = 0; i < NUMBER_OF_ENTRIES / 2; i++) {
            double value1 = normal.sample();
            double value2 = normal.sample();
            tDigest1.add(value1);
            tDigest2.add(value2);
            list.add(value1);
            list.add(value2);
        }

        tDigest1.merge(tDigest2);
        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest1);
        }
    }

    @Test
    public void testMergeManySmallNormalDistributions()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(500, 20);
        int digests = 100_000;

        for (int k = 0; k < digests; k++) {
            TDigest current = createTDigest(STANDARD_COMPRESSION_FACTOR);
            for (int i = 0; i < 10; i++) {
                double value = normal.sample();
                current.add(value);
                list.add(value);
            }
            tDigest.merge(current);
        }

        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testMergeManyLargeNormalDistributions()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> list = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(500, 20);
        int digests = 1000;

        for (int k = 0; k < digests; k++) {
            TDigest current = createTDigest(STANDARD_COMPRESSION_FACTOR);
            for (int i = 0; i < NUMBER_OF_ENTRIES / digests; i++) {
                double value = normal.sample();
                current.add(value);
                list.add(value);
            }
            tDigest.merge(current);
        }

        sort(list);

        for (int i = 0; i < quantiles.length; i++) {
            assertContinuousQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
        }
    }

    @Test
    public void testConstructTDigest()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        ImmutableList<Double> values = ImmutableList.of(0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d);
        values.stream().forEach(tDigest::add);

        List<Double> weights = Collections.nCopies(values.size(), 1.0);
        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        int count = values.size();

        String sql = format("construct_tdigest(ARRAY%s, ARRAY%s, %s, %s, %s, %s, %s)",
                values,
                weights,
                compression,
                min,
                max,
                sum,
                count);

        functionAssertions.selectSingleValue(
                sql,
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
    }

    @Test
    public void testDestructureTDigest()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        ImmutableList<Double> values = ImmutableList.of(0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d);
        values.stream().forEach(tDigest::add);

        List<Integer> weights = Collections.nCopies(values.size(), 1);
        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        long count = values.size();

        String sql = format("destructure_tdigest(%s)", toSqlString(tDigest));

        functionAssertions.assertFunction(
                sql,
                TDIGEST_CENTROIDS_ROW_TYPE,
                ImmutableList.of(values, weights, compression, min, max, sum, count));

        functionAssertions.assertFunction(format("%s.compression", sql), DOUBLE, compression);
        functionAssertions.assertFunction(format("%s.min", sql), DOUBLE, min);
        functionAssertions.assertFunction(format("%s.max", sql), DOUBLE, max);
        functionAssertions.assertFunction(format("%s.sum", sql), DOUBLE, sum);
        functionAssertions.assertFunction(format("%s.count", sql), BIGINT, count);
        functionAssertions.assertFunction(
                format("%s.centroid_means", sql),
                new ArrayType(DOUBLE),
                values);
        functionAssertions.assertFunction(
                format("%s.centroid_weights", sql),
                new ArrayType(INTEGER),
                weights);
    }

    @Test
    public void testConstructTDigestLarge()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> values = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            values.add((double) i);
        }

        values.stream().forEach(tDigest::add);

        List<Double> weights = Collections.nCopies(values.size(), 1.0);
        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        long count = values.size();

        String sql = format("construct_tdigest(ARRAY%s, ARRAY%s, %s, %s, %s, %s, %s)",
                values,
                weights,
                compression,
                min,
                max,
                sum,
                count);

        functionAssertions.selectSingleValue(
                sql,
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
    }

    @Test
    public void testDestructureTDigestLarge()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> values = new ArrayList<>();
        for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
            values.add((double) i);
        }

        values.stream().forEach(tDigest::add);

        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        long count = values.size();

        String sql = format("destructure_tdigest(%s)", toSqlString(tDigest));

        functionAssertions.assertFunction(format("%s.compression", sql), DOUBLE, compression);
        functionAssertions.assertFunction(format("%s.min", sql), DOUBLE, min);
        functionAssertions.assertFunction(format("%s.max", sql), DOUBLE, max);
        functionAssertions.assertFunction(format("%s.sum", sql), DOUBLE, sum);
        functionAssertions.assertFunction(format("%s.count", sql), BIGINT, count);
    }

    @Test
    public void testConstructTDigestNormalDistribution()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        List<Double> values = new ArrayList<>();
        NormalDistribution normal = new NormalDistribution(500, 20);
        int samples = 100;

        for (int k = 0; k < samples; k++) {
            double value = normal.sample();
            tDigest.add(value);
            values.add(value);
        }

        List<Double> weights = Collections.nCopies(values.size(), 1.0);
        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        long count = values.size();

        String sql = format("construct_tdigest(ARRAY%s, ARRAY%s, %s, %s, %s, %s, %s)",
                values,
                weights,
                compression,
                min,
                max,
                sum,
                count);

        functionAssertions.selectSingleValue(
                sql,
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
    }

    @Test
    public void testConstructTDigestInverse()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        ImmutableList<Double> values = ImmutableList.of(0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d);
        values.stream().forEach(tDigest::add);

        List<Integer> weights = Collections.nCopies(values.size(), 1);
        double compression = STANDARD_COMPRESSION_FACTOR;
        double min = values.stream().reduce(Double.POSITIVE_INFINITY, Double::min);
        double max = values.stream().reduce(Double.NEGATIVE_INFINITY, Double::max);
        double sum = values.stream().reduce(0.0d, Double::sum);
        long count = values.size();

        SqlVarbinary sqlVarbinary = new SqlVarbinary(tDigest.serialize().getBytes());

        String destructureTdigestSql = format("destructure_tdigest(%s)", toSqlString(tDigest));

        // Asserting that calling destructure_tdigest on the generated tdigest
        // produces values that equal those declared above
        functionAssertions.assertFunction(
                destructureTdigestSql,
                TDIGEST_CENTROIDS_ROW_TYPE,
                ImmutableList.of(values, weights, compression, min, max, sum, count));

        functionAssertions.assertFunction(format("%s.compression", destructureTdigestSql), DOUBLE, compression);
        functionAssertions.assertFunction(format("%s.min", destructureTdigestSql), DOUBLE, min);
        functionAssertions.assertFunction(format("%s.max", destructureTdigestSql), DOUBLE, max);
        functionAssertions.assertFunction(format("%s.sum", destructureTdigestSql), DOUBLE, sum);
        functionAssertions.assertFunction(format("%s.count", destructureTdigestSql), BIGINT, count);
        functionAssertions.assertFunction(
                format("%s.centroid_means", destructureTdigestSql),
                new ArrayType(DOUBLE),
                values);
        functionAssertions.assertFunction(
                format("%s.centroid_weights", destructureTdigestSql),
                new ArrayType(INTEGER),
                weights);

        String constructTdigestSql = format("construct_tdigest(ARRAY%s, ARRAY%s, %s, %s, %s, %s, %s)",
                values,
                weights,
                compression,
                min,
                max,
                sum,
                count);

        // Asserting that calling construct_tdigest with the raw values
        // produces a varbinary that equals the generated tdigest declared above
        SqlVarbinary constructedSqlVarbinary = functionAssertions.selectSingleValue(
                constructTdigestSql,
                TDIGEST_DOUBLE,
                SqlVarbinary.class);

        // If this is true then by definition calling construct_tdigest(destructure_tdigest(...)...)
        // will work
        assertEquals(constructedSqlVarbinary, sqlVarbinary);
    }

    @Test
    public void testMergeTDigestNullInput()
    {
        functionAssertions.assertFunction("merge_tdigest(null)", TDIGEST_DOUBLE, null);
    }

    @Test
    public void testMergeTDigestEmptyArray()
    {
        functionAssertions.assertFunction("merge_tdigest(array[])", TDIGEST_DOUBLE, null);
    }

    @Test
    public void testMergeTDigestEmptyArrayOfNull()
    {
        functionAssertions.assertFunction("merge_tdigest(array[null])", TDIGEST_DOUBLE, null);
    }

    @Test
    public void testMergeTDigestEmptyArrayOfNulls()
    {
        functionAssertions.assertFunction("merge_tdigest(array[null, null, null])", TDIGEST_DOUBLE, null);
    }

    @Test
    public void testMergeTDigests()
    {
        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest1, 0.1);
        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest2, 0.2);
        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
                format("merge_tdigest(cast(array[%s, %s] as array(tdigest(double))))",
                        toSqlString(digest1),
                        toSqlString(digest2)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
        digest1.merge(digest2);
        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
    }

    @Test
    public void testMergeTDigestOneNull()
    {
        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest1, 0.1);
        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
                format("merge_tdigest(cast(array[%s, null] as array(tdigest(double))))",
                        toSqlString(digest1)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
    }

    @Test
    public void testMergeTDigestOneNullFirst()
    {
        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest1, 0.1);
        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest2, 0.2);
        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
                format("merge_tdigest(cast(array[null, %s, %s] as array(tdigest(double))))",
                        toSqlString(digest1),
                        toSqlString(digest2)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
        digest1.merge(digest2);
        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
    }

    @Test
    public void testMergeTDigestOneNullMiddle()
    {
        TDigest digest1 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest1, 0.1);
        TDigest digest2 = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(digest2, 0.2);
        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
                format("merge_tdigest(cast(array[%s, null, %s] as array(tdigest(double))))",
                        toSqlString(digest1),
                        toSqlString(digest2)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
        digest1.merge(digest2);
        assertEquals(sqlVarbinary, new SqlVarbinary(digest1.serialize().getBytes()));
    }

    // disabled because test takes almost 10s
    @Test(enabled = false)
    public void testBinomialDistribution()
    {
        int trials = 10;
        for (int k = 1; k < trials; k++) {
            TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
            BinomialDistribution binomial = new BinomialDistribution(trials, k * 0.1);
            List<Integer> list = new ArrayList<>();

            for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
                int sample = binomial.sample();
                tDigest.add(sample);
                list.add(sample);
            }

            sort(list);

            for (int i = 0; i < quantiles.length; i++) {
                assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
            }
        }
    }

    @Test(enabled = false)
    public void testGeometricDistribution()
    {
        int trials = 10;
        for (int k = 1; k < trials; k++) {
            TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
            GeometricDistribution geometric = new GeometricDistribution(k * 0.1);
            List<Integer> list = new ArrayList<>();

            for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
                int sample = geometric.sample();
                tDigest.add(sample);
                list.add(sample);
            }

            sort(list);

            for (int i = 0; i < quantiles.length; i++) {
                assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
            }
        }
    }

    @Test(enabled = false)
    public void testPoissonDistribution()
    {
        int trials = 10;
        for (int k = 1; k < trials; k++) {
            TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
            PoissonDistribution poisson = new PoissonDistribution(k * 0.1);
            List<Integer> list = new ArrayList<>();

            for (int i = 0; i < NUMBER_OF_ENTRIES; i++) {
                int sample = poisson.sample();
                tDigest.add(sample);
                list.add(sample);
            }

            sort(list);

            for (int i = 0; i < quantiles.length; i++) {
                assertDiscreteQuantileWithinBound(quantiles[i], STANDARD_ERROR, list, tDigest);
            }
        }
    }

    @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Scale factor should be positive\\.")
    public void testScaleNegative()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(tDigest, 0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d);

        functionAssertions.selectSingleValue(
                format(
                        "scale_tdigest(%s, -1)",
                        toSqlString(tDigest)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);
    }

    @Test
    public void testScale()
    {
        TDigest tDigest = createTDigest(STANDARD_COMPRESSION_FACTOR);
        addAll(tDigest, 0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d);

        // Before scaling.
        List<Double> unscaledFrequencies = getFrequencies(tDigest, asList(2.0d, 4.0d, 6.0d, 8.0d));

        // Scale up.
        SqlVarbinary sqlVarbinary = functionAssertions.selectSingleValue(
                format("scale_tdigest(%s, 2)", toSqlString(tDigest)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);

        TDigest scaledTdigest = createTDigest(wrappedBuffer(sqlVarbinary.getBytes()));
        List<Double> scaledDigestFrequencies = getFrequencies(scaledTdigest, asList(2.0d, 4.0d, 6.0d, 8.0d));
        List<Double> scaledUpFrequencies = new ArrayList<>();
        unscaledFrequencies.forEach(frequency -> scaledUpFrequencies.add(frequency * 2));
        assertEquals(scaledDigestFrequencies, scaledUpFrequencies);

        // Scale down.
        sqlVarbinary = functionAssertions.selectSingleValue(
                format("scale_tdigest(%s, 0.5)", toSqlString(tDigest)),
                TDIGEST_DOUBLE,
                SqlVarbinary.class);

        scaledTdigest = createTDigest(wrappedBuffer(sqlVarbinary.getBytes()));
        scaledDigestFrequencies = getFrequencies(scaledTdigest, asList(2.0d, 4.0d, 6.0d, 8.0d));
        List<Double> scaledDownFrequencies = new ArrayList<>();
        unscaledFrequencies.forEach(frequency -> scaledDownFrequencies.add(frequency * 0.5));
        assertEquals(scaledDigestFrequencies, scaledDownFrequencies);
    }

    private static void addAll(TDigest digest, double... values)
    {
        requireNonNull(values, "values is null");
        for (double value : values) {
            digest.add(value);
        }
    }

    private void assertValueWithinBound(double quantile, double error, List<Double> list, TDigest tDigest)
    {
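        // Check that quantile_at_value, evaluated at the sample with rank NUMBER_OF_ENTRIES * quantile,
        // returns a quantile within +/- error of the requested quantile (clamped to [0, 1]).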
        functionAssertions.assertFunction(
                format("quantile_at_value(%s, %s) <= %s",
                        toSqlString(tDigest),
                        list.get((int) (NUMBER_OF_ENTRIES * quantile)),
                        getUpperBoundQuantile(quantile, error)),
                BOOLEAN,
                true);
        functionAssertions.assertFunction(
                format("quantile_at_value(%s, %s) >= %s",
                        toSqlString(tDigest),
                        list.get((int) (NUMBER_OF_ENTRIES * quantile)),
                        getLowerBoundQuantile(quantile, error)),
                BOOLEAN,
                true);
    }

    private void assertDiscreteQuantileWithinBound(double quantile, double error, List<Integer> list, TDigest tDigest)
    {
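        // Check that round(value_at_quantile(...)) falls between the sorted samples at ranks
        // NUMBER_OF_ENTRIES * (quantile - error) and NUMBER_OF_ENTRIES * (quantile + error);
        // rounding is used because the underlying samples are integers.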
        functionAssertions.assertFunction(
                format("round(value_at_quantile(%s, %s)) <= %s",
                        toSqlString(tDigest),
                        quantile,
                        getUpperBoundValue(quantile, error, list)),
                BOOLEAN,
                true);
        functionAssertions.assertFunction(
                format("round(value_at_quantile(%s, %s)) >= %s",
                        toSqlString(tDigest),
                        quantile,
                        getLowerBoundValue(quantile, error, list)),
                BOOLEAN,
                true);
    }

    private void assertContinuousQuantileWithinBound(double quantile, double error, List<Double> list, TDigest tDigest)
    {
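        // Check that value_at_quantile(...) falls between the sorted samples at ranks
        // NUMBER_OF_ENTRIES * (quantile - error) and NUMBER_OF_ENTRIES * (quantile + error).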
        functionAssertions.assertFunction(
                format("value_at_quantile(%s, %s) <= %s",
                        toSqlString(tDigest),
                        quantile,
                        getUpperBoundValue(quantile, error, list)),
                BOOLEAN,
                true);
        functionAssertions.assertFunction(
                format("value_at_quantile(%s, %s) >= %s",
                        toSqlString(tDigest),
                        quantile,
                        getLowerBoundValue(quantile, error, list)),
                BOOLEAN,
                true);
    }

    private List<Double> getFrequencies(TDigest tdigest, List<Double> buckets)
    {
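        // CDF at each bucket boundary times the digest's total weight, i.e. the approximate
        // number of values at or below that boundary.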
        List<Double> histogram = new ArrayList<>();
        for (Double bin : buckets) {
            histogram.add(tdigest.getCdf(bin) * tdigest.getSize());
        }
        return histogram;
    }

    private double getLowerBoundValue(double quantile, double error, List<? extends Number> values)
    {
        return values.get((int) Math.max(NUMBER_OF_ENTRIES * (quantile - error), 0)).doubleValue();
    }

    private double getUpperBoundValue(double quantile, double error, List<? extends Number> values)
    {
        return values.get((int) Math.min(NUMBER_OF_ENTRIES * (quantile + error), values.size() - 1)).doubleValue();
    }

    private double getLowerBoundQuantile(double quantile, double error)
    {
        return Math.max(0, quantile - error);
    }

    private double getUpperBoundQuantile(double quantile, double error)
    {
        return Math.min(1, quantile + error);
    }

    private double getTrimmedMean(double l, double h, List<? extends Number> values)
    {
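        // Expected trimmed mean: average of the sorted samples whose ranks fall between
        // NUMBER_OF_ENTRIES * l and NUMBER_OF_ENTRIES * h.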
        return values
                .subList((int) (NUMBER_OF_ENTRIES * l), (int) (NUMBER_OF_ENTRIES * h))
                .stream()
                .mapToDouble(Number::doubleValue)
                .reduce(0.0d, Double::sum) / ((int) (NUMBER_OF_ENTRIES * h) - (int) (NUMBER_OF_ENTRIES * l));
    }

    private void assertBlockQuantiles(double[] percentiles, double error, List<? extends Number> rows, TDigest tDigest)
    {
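        // Verify values_at_quantiles via zip_with: each returned value must lie between the
        // rank-based lower and upper bounds derived from the sorted samples.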
        List<Double> boxedPercentiles = Arrays.stream(percentiles).sorted().boxed().collect(toImmutableList());
        List<Number> lowerBounds = boxedPercentiles.stream().map(percentile -> getLowerBoundValue(percentile, error, rows)).collect(toImmutableList());
        List<Number> upperBounds = boxedPercentiles.stream().map(percentile -> getUpperBoundValue(percentile, error, rows)).collect(toImmutableList());

        // Ensure that the lower bound of each item in the distribution is not greater than the chosen quantiles
        functionAssertions.assertFunction(
                format(
                        "zip_with(values_at_quantiles(%s, ARRAY[%s]), ARRAY[%s], (value, lowerbound) -> value >= lowerbound)",
                        toSqlString(tDigest),
                        ARRAY_JOINER.join(boxedPercentiles),
                        ARRAY_JOINER.join(lowerBounds)),
                METADATA.getType(parseTypeSignature("array(boolean)")),
                Collections.nCopies(percentiles.length, true));

        // Ensure that the upper bound of each item in the distribution is not less than the chosen quantiles
        functionAssertions.assertFunction(
                format(
                        "zip_with(values_at_quantiles(%s, ARRAY[%s]), ARRAY[%s], (value, upperbound) -> value <= upperbound)",
                        toSqlString(tDigest),
                        ARRAY_JOINER.join(boxedPercentiles),
                        ARRAY_JOINER.join(upperBounds)),
                METADATA.getType(parseTypeSignature("array(boolean)")),
                Collections.nCopies(percentiles.length, true));
    }

    private void assertTrimmedMeanValues(List<Double> lowerQuantiles, List<Double> upperQuantiles, double sd, double error, List<? extends Number> rows, TDigest tDigest)
    {
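        // Verify trimmed_mean via zip_with: for each (lower, upper) quantile pair, the difference
        // between the SQL result and the expected trimmed mean, divided by the standard deviation,
        // must be at most `error`.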
        List<Double> expectedTrimmedMeans = IntStream.range(0, lowerQuantiles.size())
                .mapToDouble(i -> getTrimmedMean(lowerQuantiles.get(i), upperQuantiles.get(i), rows))
                .boxed()
                .collect(toImmutableList());
        functionAssertions.assertFunction(
                format(
                        "zip_with(ARRAY[%s], zip_with(ARRAY[%s], ARRAY[%s], (l, u) -> (l, u)), (v, bounds) -> abs(trimmed_mean(%s, bounds[1], bounds[2]) - v)/%s <= %s)",
                        ARRAY_JOINER.join(expectedTrimmedMeans),
                        ARRAY_JOINER.join(lowerQuantiles),
                        ARRAY_JOINER.join(upperQuantiles),
                        toSqlString(tDigest),
                        sd,
                        error),
                METADATA.getType(parseTypeSignature("array(boolean)")),
                Collections.nCopies(lowerQuantiles.size(), true));
    }

    private void assertBlockValues(double[] values, double error, TDigest tDigest)
    {
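        // Verify quantiles_at_values via zip_with: each returned quantile must lie within
        // +/- error of the quantile implied by the given value.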
        List<Double> boxedValues = Arrays.stream(values).sorted().boxed().collect(toImmutableList());
        List<Double> boxedPercentiles = Arrays.stream(quantiles).sorted().boxed().collect(toImmutableList());
        List<Number> lowerBounds = boxedPercentiles.stream().map(percentile -> getLowerBoundQuantile(percentile, error)).collect(toImmutableList());
        List<Number> upperBounds = boxedPercentiles.stream().map(percentile -> getUpperBoundQuantile(percentile, error)).collect(toImmutableList());

        // Ensure that the lower bound of each item in the distribution is not greater than the chosen quantiles
        functionAssertions.assertFunction(
                format(
                        "zip_with(quantiles_at_values(%s, ARRAY[%s]), ARRAY[%s], (value, lowerbound) -> value >= lowerbound)",
                        toSqlString(tDigest),
                        ARRAY_JOINER.join(boxedValues),
                        ARRAY_JOINER.join(lowerBounds)),
                METADATA.getType(parseTypeSignature("array(boolean)")),
                Collections.nCopies(values.length, true));

        // Ensure that the upper bound of each item in the distribution is not less than the chosen quantiles
        functionAssertions.assertFunction(
                format(
                        "zip_with(quantiles_at_values(CAST(X'%s' AS tdigest(%s)), ARRAY[%s]), ARRAY[%s], (value, upperbound) -> value <= upperbound)",
                        new SqlVarbinary(tDigest.serialize().getBytes()).toString().replaceAll("\\s+", " "),
                        "double",
                        ARRAY_JOINER.join(boxedValues),
                        ARRAY_JOINER.join(upperBounds)),
                METADATA.getType(parseTypeSignature("array(boolean)")),
                Collections.nCopies(values.length, true));
    }

    private String toSqlString(TDigest tDigest)
    {
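        // Render the digest as a SQL expression: its serialized bytes as a hex varbinary literal,
        // cast to tdigest(double).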
        return format("CAST(X'%s' AS tdigest(%s))",
                new SqlVarbinary(tDigest.serialize().getBytes()).toString().replaceAll("\\s+", " "),
                DOUBLE);
    }
}