TestReservoirSampleAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation.reservoirsample;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Parameters;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.executeAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

/**
 * Tests for the {@code reservoir_sample} aggregation function.
 * <p>
 * Every test feeds the aggregation four argument blocks, in this order:
 * <ol>
 *     <li>the initial sample (an array, or null when there is none)</li>
 *     <li>the number of records seen while producing the initial sample</li>
 *     <li>the actual input values</li>
 *     <li>the desired sample size</li>
 * </ol>
 * The expected result is a two-element list holding the total seen count followed by the sample.
 */
public class TestReservoirSampleAggregation
{
    protected FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();

    /**
     * Aggregates five identical values with a null initial sample and a sample size of two.
     */
    @Test
    public void testNoInitialSample()
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of(5L, ImmutableList.of(1.0, 1.0)),
                // arguments
                copyBlock(BIGINT, nullBlock(), 5),
                copyBlock(BIGINT, bigintBlock(0), 5),
                doubleBlock(1, 1, 1, 1, 1),
                copyBlock(INTEGER, intBlock(2), 5));
    }

    /**
     * Aggregates 15,000 identical values into a sample of 5,000 with a null initial sample.
     */
    @Test
    public void testLarge()
    {
        int sampleSize = 5000;
        int inputSize = 15_000;
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of((long) inputSize, IntStream.range(0, sampleSize).mapToObj(x -> 1.0).collect(Collectors.toList())),
                // arguments
                copyBlock(BIGINT, nullBlock(), inputSize),
                copyBlock(BIGINT, bigintBlock(0), inputSize),
                doubleBlock(IntStream.range(0, inputSize).mapToDouble(x -> 1.0).toArray()),
                copyBlock(INTEGER, intBlock(sampleSize), inputSize));
    }

    /**
     * Supplies the sample sizes (zero and negative) that the aggregation is expected to reject.
     */
    @DataProvider(name = "invalidSampleSize")
    public Object[][] invalidSampleParameters()
    {
        return new Object[][] {{0}, {-1}};
    }

    /**
     * Throws exception when desired sample size is <= 0
     */
    @Test(dataProvider = "invalidSampleSize", expectedExceptions = IllegalArgumentException.class)
    @Parameters("sampleSize")
    public void testInvalidSampleSize(int sampleSize)
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample (never compared: the aggregation is expected to throw first)
                ImmutableList.of(-1L, ImmutableList.of(1.0, 1.0)),
                // arguments
                copyBlock(BIGINT, nullBlock(), 5),
                copyBlock(BIGINT, bigintBlock(0), 5),
                doubleBlock(1, 1, 1, 1, 1),
                copyBlock(INTEGER, intBlock(sampleSize), 5));
    }

    /**
     * Aggregates with an initial sample whose size already equals the desired sample size.
     */
    @Test
    public void testInitialSampleSameSize()
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of(15L, ImmutableList.of(1.0, 1.0)),
                // arguments
                // initial sample
                arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0), 5),
                // initial sample seen count
                copyBlock(BIGINT, bigintBlock(10), 5),
                // actual input values
                doubleBlock(1, 1, 1, 1, 1),
                // sample size
                copyBlock(INTEGER, intBlock(2), 5));
    }

    /**
     * Throws exception because the initial sample size is not equal to the desired sample size
     */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInitialSampleWrongSize()
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of(15L, ImmutableList.of(1.0, 1.0)),
                // arguments
                // initial sample
                arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0, 2.0), 5),
                // initial sample seen count
                copyBlock(BIGINT, bigintBlock(10), 5),
                // actual input values
                doubleBlock(1, 1, 1, 1, 1),
                // sample size
                copyBlock(INTEGER, intBlock(2), 5));
    }

    /**
     * Valid because when the initial sample was created there could have been fewer records than
     * the desired sample size.
     */
    @Test
    public void testInitialSampleSmallerThanMaxSize()
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of(6L, ImmutableList.of(1.0, 1.0)),
                // arguments
                // initial sample
                arrayOfBlock(DOUBLE, doubleArrayBlock(1.0), 5),
                // initial sample seen count
                copyBlock(BIGINT, bigintBlock(1), 5),
                // actual input values
                doubleBlock(1, 1, 1, 1, 1),
                // sample size
                copyBlock(INTEGER, intBlock(2), 5));
    }

    /**
     * Throws exception because the processed count is less than the size of the initial sample
     */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInitialSampleSeenCountSmallerThanInitialSample()
    {
        assertAggregation(
                getDoubleFunction(),
                // seen count, and sample
                ImmutableList.of(6L, ImmutableList.of(1.0, 1.0)),
                // arguments
                // initial sample
                arrayOfBlock(DOUBLE, doubleArrayBlock(1.0, 1.0), 5),
                // initial sample seen count
                copyBlock(BIGINT, bigintBlock(1), 5),
                // actual input values
                doubleBlock(1, 1, 1, 1, 1),
                // sample size
                copyBlock(INTEGER, intBlock(2), 5));
    }

    /**
     * Samples four of ten distinct values and verifies the reported seen count, the size of the
     * sample, and that every sampled value came from the input.
     */
    @Test
    public void testValidResults()
    {
        int sampleSize = 4;
        Object result = executeAggregation(
                getDoubleFunction(),
                // initial sample
                copyBlock(UNKNOWN, nullBlock(), 10),
                // initial sample seen count
                copyBlock(BIGINT, bigintBlock(0), 10),
                // actual input values
                doubleBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
                // sample size
                copyBlock(INTEGER, intBlock(sampleSize), 10));
        Set<Double> items = IntStream.range(0, 10).boxed().map(Integer::doubleValue).collect(Collectors.toSet());

        assertTrue(result instanceof List);
        List<?> resultItems = (List<?>) result;

        Long processedCount = (Long) resultItems.get(0);
        assertEquals(processedCount, Long.valueOf(items.size()));

        assertTrue(resultItems.get(1) instanceof List);
        List<?> sample = (List<?>) resultItems.get(1);
        // Without the size check an empty sample would satisfy containsAll vacuously.
        assertEquals(sample.size(), sampleSize);
        assertTrue(items.containsAll(sample));
    }

    /**
     * Resolves the {@code reservoir_sample} implementation for the given argument types.
     */
    private JavaAggregationFunctionImplementation getFunction(Type... arguments)
    {
        return functionAndTypeManager.getJavaAggregateFunctionImplementation(functionAndTypeManager.lookupFunction("reservoir_sample", fromTypes(arguments)));
    }

    /**
     * Resolves {@code reservoir_sample} for the signature {@code (array(double), bigint, double, integer)}.
     */
    private JavaAggregationFunctionImplementation getDoubleFunction()
    {
        return getFunction(new ArrayType(DOUBLE), BIGINT, DOUBLE, INTEGER);
    }

    /**
     * Builds a single-position BIGINT block holding {@code value}.
     */
    private static Block bigintBlock(long value)
    {
        return BIGINT.createBlockBuilder(null, 1).writeLong(value).build();
    }

    /**
     * Builds an INTEGER block with one position per supplied value.
     */
    private static Block intBlock(int... values)
    {
        BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, values.length);
        Arrays.stream(values).forEach(value -> INTEGER.writeLong(blockBuilder, value));
        return blockBuilder.build();
    }

    /**
     * Builds a DOUBLE block with one position per supplied value.
     */
    private static Block doubleBlock(double... values)
    {
        BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, values.length);
        Arrays.stream(values).forEach(value -> DOUBLE.writeDouble(blockBuilder, value));
        return blockBuilder.build();
    }

    /**
     * Builds the element block of a double array. The layout is the same as {@link #doubleBlock},
     * so this simply delegates; the separate name documents intent at the call sites.
     */
    private static Block doubleArrayBlock(double... values)
    {
        return doubleBlock(values);
    }

    /**
     * Builds an array block of {@code count} positions, each holding {@code value} as its elements.
     *
     * @param innerType element type of the array
     * @param value block containing the elements of a single array
     * @param count number of array positions to produce
     */
    private static Block arrayOfBlock(Type innerType, Block value, int count)
    {
        Type arrayType = new ArrayType(innerType);
        BlockBuilder builder = arrayType.createBlockBuilder(null, count);
        for (int i = 0; i < count; i++) {
            builder.appendStructure(value);
        }
        return builder.build();
    }

    /**
     * Builds a single-position block (created through the DOUBLE type) whose only entry is null.
     */
    private static Block nullBlock()
    {
        return DOUBLE.createBlockBuilder(null, 1).appendNull().build();
    }

    /**
     * Builds a block of {@code positionCount} positions by repeatedly appending position 0 of
     * {@code value} using {@code type}.
     */
    private static Block copyBlock(Type type, Block value, int positionCount)
    {
        BlockBuilder builder = type.createBlockBuilder(null, positionCount);
        for (int i = 0; i < positionCount; i++) {
            type.appendTo(value, 0, builder);
        }
        return builder.build();
    }
}