TestEntropyAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

/**
 * Tests for the {@code entropy(bigint)} aggregation.
 *
 * <p>Each input value is treated as a count. The expected result is the base-2 entropy
 * of the distribution those counts define. NULL inputs are ignored, and negative counts
 * are rejected with a {@link PrestoException}.
 */
public class TestEntropyAggregation
        extends AbstractTestAggregationFunction
{
    private static final String FUNCTION_NAME = "entropy";

    private JavaAggregationFunctionImplementation entropyFunction;

    /**
     * Resolves the {@code entropy(bigint)} implementation used by the hand-written tests in this class.
     */
    @BeforeClass
    public void setUp()
    {
        FunctionAndTypeManager functionAndTypeManager = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
        entropyFunction = functionAndTypeManager.getJavaAggregateFunctionImplementation(
                functionAndTypeManager.lookupFunction(FUNCTION_NAME, fromTypes(BIGINT)));
    }

    /**
     * A single positive count is a degenerate distribution, so its entropy is 0.
     */
    @Test
    public void entropyOfASingle()
    {
        assertAggregation(entropyFunction,
                0.0,
                createLongsBlock(Long.valueOf(1)));
    }

    /**
     * Two equal counts form a fair coin: exactly 1 bit, regardless of the count's magnitude.
     */
    @Test
    public void entropyOfTwoEquals()
    {
        Long count = Long.valueOf(1);
        assertAggregation(entropyFunction,
                1.0,
                createLongsBlock(count, count));

        count = Long.valueOf(20);
        assertAggregation(entropyFunction,
                1.0,
                createLongsBlock(count, count));
    }

    /**
     * NULL inputs are skipped, so they must not change the 1-bit result of two equal counts.
     */
    @Test
    public void entropyOfTwoEqualsWithNulls()
    {
        Long count = Long.valueOf(1);
        assertAggregation(entropyFunction,
                1.0,
                createLongsBlock(count, null, count));

        count = Long.valueOf(10);
        assertAggregation(entropyFunction,
                1.0,
                createLongsBlock(null, null, count, count));
    }

    /**
     * A 30/70 split has entropy -(0.3 log2 0.3 + 0.7 log2 0.7). The result must not depend on input order.
     */
    @Test
    public void entropyOfTwoSkewed()
    {
        Long lower = Long.valueOf(30);
        Long higher = Long.valueOf(70);
        Double expected = 0.8812908992306931;

        assertAggregation(entropyFunction,
                expected,
                createLongsBlock(lower, higher));
        assertAggregation(entropyFunction,
                expected,
                createLongsBlock(higher, lower));
    }

    /**
     * An input consisting only of NULLs yields an entropy of 0.
     */
    @Test
    public void entropyOfOnlyNulls()
    {
        assertAggregation(entropyFunction,
                0.0,
                createLongsBlock(null, null));
    }

    /**
     * A negative count must be rejected with a {@link PrestoException} whose message mentions "negative".
     */
    @Test
    public void entropyOfNegativeCount()
    {
        try {
            assertAggregation(entropyFunction,
                    0.0,
                    createLongsBlock(Long.valueOf(-1)));
        }
        catch (PrestoException e) {
            // getMessage() may be null; String.valueOf keeps the check from throwing an NPE instead of failing cleanly.
            String message = String.valueOf(e.getMessage());
            assertTrue(message.toLowerCase(Locale.ENGLISH).contains("negative"),
                    "unexpected error message for negative count: " + message);
            return;
        }
        // Reached only when no exception was thrown; fail() raises AssertionError, which the catch above does not intercept.
        fail("expected entropy of a negative count to be rejected");
    }

    /**
     * Builds a single BIGINT block holding {@code |i|} for each {@code i} in {@code [start, start + length)}.
     */
    @Override
    public Block[] getSequenceBlocks(int start, int length)
    {
        BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, length);
        for (int i = start; i < start + length; i++) {
            BIGINT.writeLong(blockBuilder, Math.abs(i));
        }
        return new Block[] {blockBuilder.build()};
    }

    /**
     * Computes the base-2 entropy of the counts produced by {@link #getSequenceBlocks(int, int)}.
     *
     * @param start first sequence value (inclusive)
     * @param length number of sequence values
     * @return the entropy in bits; {@code 0.0} when all counts are zero or the sequence is empty;
     *         {@code null} when a count is negative
     */
    @Override
    public Number getExpectedValue(int start, int length)
    {
        // Mirror getSequenceBlocks: every position contributes the count |i|.
        List<Integer> counts = IntStream.range(start, start + length)
                .map(Math::abs)
                .boxed()
                .collect(Collectors.toCollection(ArrayList::new));
        // Math.abs(Integer.MIN_VALUE) overflows and stays negative; there is no meaningful entropy then.
        if (counts.stream().anyMatch(count -> count < 0)) {
            return null;
        }
        double sum = counts.stream()
                .mapToDouble(Integer::doubleValue)
                .sum();
        if (sum == 0) {
            return 0.0;
        }
        // H = sum over positive counts of p * ln(1 / p), with p = count / sum. Dividing by ln 2 converts nats to bits.
        // Zero counts are skipped because p * ln(1 / p) -> 0 as p -> 0.
        double entropyInNats = counts.stream()
                .filter(count -> count > 0)
                .mapToDouble(count -> (count / sum) * Math.log(sum / count))
                .sum();
        return entropyInNats / Math.log(2);
    }

    @Override
    protected String getFunctionName()
    {
        return FUNCTION_NAME;
    }

    /**
     * Declares BIGINT to match both the {@code entropy(bigint)} lookup in {@link #setUp()} and the
     * BIGINT blocks built by {@link #getSequenceBlocks(int, int)}.
     */
    @Override
    protected List<String> getFunctionParameterTypes()
    {
        return ImmutableList.of(StandardTypes.BIGINT);
    }
}