TestKHyperLogLogAggregationFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.type.khyperloglog;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;

import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.block.BlockAssertions.createSlicesBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;

public class TestKHyperLogLogAggregationFunction
{
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
    private static final String NAME = KHyperLogLogWithLimitAggregationFunction.getFunctionName();

    /**
     * Checks each (value type, UII type) pairing on 100 generated rows containing no nulls.
     */
    @Test
    public void testSimpleKHyperLogLog()
    {
        int sampleSize = 100;
        testAllTypePairs(generateLongs(sampleSize), generateStringSlices(sampleSize), generateDoubles(sampleSize));
    }

    /**
     * Checks each (value type, UII type) pairing on 100,000 generated rows containing no nulls.
     */
    @Test
    public void testBigKHyperLogLog()
    {
        int sampleSize = 100000;
        testAllTypePairs(generateLongs(sampleSize), generateStringSlices(sampleSize), generateDoubles(sampleSize));
    }

    /**
     * Checks each (value type, UII type) pairing after randomly replacing some entries with nulls.
     */
    @Test
    public void testKHyperLogLogWithSomeNulls()
    {
        // NOTE(review): with only 3 rows and a ~20% null rate driven by a fixed seed, whether any null is
        // actually injected depends on the seeded sequence -- confirm this test really exercises null rows.
        int sampleSize = 3;
        List<Long> longs = generateLongs(sampleSize);
        List<Slice> strings = generateStringSlices(sampleSize);
        List<Double> doubles = generateDoubles(sampleSize);

        // includeNulls mutates its argument in place; every list gets nulls at the same positions
        includeNulls(longs);
        includeNulls(strings);
        includeNulls(doubles);

        testAllTypePairs(longs, strings, doubles);
    }

    /**
     * Checks that an entirely-null value column or an entirely-null UII column yields a null result.
     */
    @Test
    public void testKHyperLogLogWithNullColumn()
    {
        int sampleSize = 3;
        List<Long> longs = generateLongs(sampleSize);
        List<Slice> strings = generateStringSlices(sampleSize);
        List<Double> doubles = generateDoubles(sampleSize);
        List<Object> nulls = generateNulls(sampleSize);

        testAggregation(BIGINT, nulls, BIGINT, longs);
        testAggregation(BIGINT, longs, BIGINT, nulls);

        testAggregation(BIGINT, nulls, VARCHAR, strings);
        testAggregation(BIGINT, longs, VARCHAR, nulls);

        testAggregation(VARCHAR, nulls, BIGINT, longs);
        testAggregation(VARCHAR, strings, BIGINT, nulls);

        testAggregation(VARCHAR, nulls, VARCHAR, strings);
        testAggregation(VARCHAR, strings, VARCHAR, nulls);

        testAggregation(DOUBLE, nulls, BIGINT, longs);
        testAggregation(DOUBLE, doubles, BIGINT, nulls);

        testAggregation(DOUBLE, nulls, VARCHAR, strings);
        testAggregation(DOUBLE, doubles, VARCHAR, nulls);
    }

    /**
     * Runs {@link #testAggregation} for BIGINT, VARCHAR and DOUBLE value columns, each paired with
     * BIGINT and VARCHAR UII columns.
     */
    private void testAllTypePairs(List<Long> longs, List<Slice> strings, List<Double> doubles)
    {
        testAggregation(BIGINT, longs, BIGINT, longs);
        testAggregation(BIGINT, longs, VARCHAR, strings);
        testAggregation(VARCHAR, strings, BIGINT, longs);
        testAggregation(VARCHAR, strings, VARCHAR, strings);
        testAggregation(DOUBLE, doubles, BIGINT, longs);
        testAggregation(DOUBLE, doubles, VARCHAR, strings);
    }

    /**
     * Builds the expected sketch by hand and asserts the aggregation function produces the same result.
     *
     * <p>Every row whose value and UII are both non-null is added to a fresh {@link KHyperLogLog}. The
     * expected output is that sketch's serialized bytes, or {@code null} when no row qualified. The
     * aggregation is then run over blocks built from the same two columns.
     *
     * @param valueType type of the value column (BIGINT, DOUBLE or VARCHAR)
     * @param values value column; may contain nulls
     * @param uiiType type of the UII column (BIGINT or VARCHAR)
     * @param uiis UII column, parallel to {@code values}; may contain nulls
     */
    private void testAggregation(Type valueType, List<?> values, Type uiiType, List<?> uiis)
    {
        JavaAggregationFunctionImplementation aggregationFunction = getAggregation(valueType, uiiType);
        KHyperLogLog khll = null;

        for (int i = 0; i < values.size(); i++) {
            Object rawValue = values.get(i);
            Object rawUii = uiis.get(i);
            // rows with a null in either column are excluded from the expected sketch
            if (rawValue == null || rawUii == null) {
                continue;
            }

            // created lazily so an input with no qualifying rows expects a null result
            if (khll == null) {
                khll = new KHyperLogLog();
            }

            long uii = toLong(rawUii, uiiType);
            if (valueType.equals(VARCHAR)) {
                // VARCHAR values use the Slice overload, so no pre-hashed long is needed
                khll.add((Slice) rawValue, uii);
            }
            else {
                khll.add(toLong(rawValue, valueType), uii);
            }
        }

        assertAggregation(
                aggregationFunction,
                (khll == null) ? null : new SqlVarbinary(khll.serialize().getBytes()),
                buildBlock(values, valueType),
                buildBlock(uiis, uiiType));
    }

    /**
     * Converts a non-null column value to a long.
     *
     * <p>DOUBLE values map to their raw bit pattern ({@link Double#doubleToLongBits}), and VARCHAR
     * slices map to their XxHash64 hash. Any other type is cast directly to {@code long}. This
     * presumably mirrors how the aggregation function maps its inputs; verify against the
     * implementation if it changes.
     */
    private long toLong(Object value, Type type)
    {
        if (type.equals(DOUBLE)) {
            return Double.doubleToLongBits((double) value);
        }
        if (type.equals(VARCHAR)) {
            return XxHash64.hash((Slice) value);
        }
        return (long) value;
    }

    /**
     * Wraps a column of DOUBLE, VARCHAR (Slice) or BIGINT (Long) values, nulls included, into a {@link Block}.
     */
    private Block buildBlock(List<?> values, Type type)
    {
        if (type.equals(DOUBLE)) {
            return createDoublesBlock(values.stream().map(Double.class::cast).collect(Collectors.toList()));
        }
        if (type.equals(VARCHAR)) {
            return createSlicesBlock(values.stream().map(Slice.class::cast).collect(Collectors.toList()));
        }
        return createLongsBlock(values.stream().map(Long.class::cast).collect(Collectors.toList()));
    }

    /**
     * Converts strings to UTF-8 slices, preserving nulls.
     *
     * <p>The returned list is a mutable {@link ArrayList}, because {@link #includeNulls} writes into it.
     */
    private List<Slice> buildStringSliceList(List<String> strings)
    {
        return strings.stream().map(this::stringToSlice).collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Returns the UTF-8 slice for {@code s}, or {@code null} when {@code s} is null.
     */
    private Slice stringToSlice(String s)
    {
        if (s == null) {
            return null;
        }
        return Slices.utf8Slice(s);
    }

    /**
     * Resolves the aggregation function named {@link #NAME} for the given argument types.
     */
    private static JavaAggregationFunctionImplementation getAggregation(Type... arguments)
    {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction(NAME, fromTypes(arguments)));
    }

    /**
     * Returns {@code size} pseudo-random longs from a fixed seed, in a mutable list.
     */
    private List<Long> generateLongs(int size)
    {
        Random generator = new Random(13);
        return generator.longs(size).boxed().collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Returns {@code size} UTF-8 slices of hex-formatted pseudo-random longs from a fixed seed, in a mutable list.
     */
    private List<Slice> generateStringSlices(int size)
    {
        Random generator = new Random(123);
        return buildStringSliceList(generator.longs(size).boxed().map(Long::toHexString).collect(Collectors.toList()));
    }

    /**
     * Returns {@code size} pseudo-random doubles from a fixed seed, in a mutable list.
     */
    private List<Double> generateDoubles(int size)
    {
        Random generator = new Random(123);
        return generator.doubles(size).boxed().collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Returns a fixed-size list of {@code size} nulls.
     */
    private List<Object> generateNulls(int size)
    {
        return Arrays.asList(new Object[size]);
    }

    /**
     * Sets each element of {@code values} to null with probability 0.2, in place, and returns the same list.
     *
     * <p>Each call uses a freshly seeded generator, so lists of equal size receive nulls at identical positions.
     */
    private <K> List<K> includeNulls(List<K> values)
    {
        Random generator = new Random(123);
        for (int i = 0; i < values.size(); i++) {
            if (generator.nextDouble() < 0.2) {
                values.set(i, null);
            }
        }
        return values;
    }
}