// TestPrecisionRecallAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import static com.facebook.presto.block.BlockAssertions.createBooleansBlock;
import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public abstract class TestPrecisionRecallAggregation
        extends AbstractTestAggregationFunction
{
    // Bucket count written into every row of the bucket-count column and used
    // as the divisor when computing the expected bucket edges.
    private static final Integer NUM_BINS = 3;
    // Synthetic samples get a false outcome only when their prediction lies in
    // [MIN_FALSE_PRED, MAX_FALSE_PRED]; outside that range the outcome is true.
    private static final double MIN_FALSE_PRED = 0.2;
    private static final double MAX_FALSE_PRED = 0.5;

    // Name of the aggregation function under test, supplied by the subclass.
    private final String functionName;
    // Implementation resolved in setUp() with a (BIGINT, BOOLEAN, DOUBLE, DOUBLE) signature.
    private JavaAggregationFunctionImplementation precisionRecallFunction;

    /**
     * Resolves the four-argument (BIGINT, BOOLEAN, DOUBLE, DOUBLE) overload of the
     * function under test once per class; the input-validation tests share it.
     */
    @BeforeClass
    public void setUp()
    {
        FunctionAndTypeManager manager = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
        precisionRecallFunction = manager.getJavaAggregateFunctionImplementation(
                manager.lookupFunction(functionName, fromTypes(BIGINT, BOOLEAN, DOUBLE, DOUBLE)));
    }

    /**
     * A negative value in the fourth (weight) column must be rejected with a
     * PrestoException whose message mentions both "weight" and "negative".
     */
    @Test
    public void testNegativeWeight()
    {
        try {
            Block bucketCounts = createLongsBlock(Long.valueOf(200));
            Block outcomes = createBooleansBlock(Boolean.valueOf(true));
            Block predictions = createDoublesBlock(Double.valueOf(0.2));
            Block weights = createDoublesBlock(Double.valueOf(-0.2));
            assertAggregation(precisionRecallFunction, 0.0, bucketCounts, outcomes, predictions, weights);
            fail("Expected PrestoException");
        }
        catch (PrestoException e) {
            // Case-insensitive match so the test does not depend on message capitalization.
            String message = e.getMessage().toLowerCase(Locale.ENGLISH);
            assertTrue(message.contains("weight"));
            assertTrue(message.contains("negative"));
        }
    }

    /**
     * A prediction of 1.2 must be rejected with a PrestoException whose message
     * mentions "prediction".
     */
    @Test
    public void testTooHighPrediction()
    {
        try {
            Block bucketCounts = createLongsBlock(Long.valueOf(200));
            Block outcomes = createBooleansBlock(Boolean.valueOf(true));
            Block predictions = createDoublesBlock(Double.valueOf(1.2));
            Block weights = createDoublesBlock(Double.valueOf(0.2));
            assertAggregation(precisionRecallFunction, 0.0, bucketCounts, outcomes, predictions, weights);
            fail("Expected PrestoException");
        }
        catch (PrestoException e) {
            String message = e.getMessage().toLowerCase(Locale.ENGLISH);
            assertTrue(message.contains("prediction"));
        }
    }

    /**
     * A prediction of -1.2 must be rejected with a PrestoException whose message
     * mentions "prediction".
     */
    @Test
    public void testTooLowPrediction()
    {
        try {
            Block bucketCounts = createLongsBlock(Long.valueOf(200));
            Block outcomes = createBooleansBlock(Boolean.valueOf(true));
            Block predictions = createDoublesBlock(Double.valueOf(-1.2));
            Block weights = createDoublesBlock(Double.valueOf(0.2));
            assertAggregation(precisionRecallFunction, 0.0, bucketCounts, outcomes, predictions, weights);
            fail("Expected PrestoException");
        }
        catch (PrestoException e) {
            String message = e.getMessage().toLowerCase(Locale.ENGLISH);
            assertTrue(message.contains("prediction"));
        }
    }

    /**
     * Two rows with different bucket counts (200 and 300) must be rejected with a
     * PrestoException whose message mentions "bucket".
     */
    @Test
    public void testNonConstantBuckets()
    {
        try {
            Block bucketCounts = createLongsBlock(Long.valueOf(200), Long.valueOf(300));
            Block outcomes = createBooleansBlock(Boolean.valueOf(true), Boolean.valueOf(false));
            Block predictions = createDoublesBlock(Double.valueOf(0.2), Double.valueOf(0.3));
            Block weights = createDoublesBlock(Double.valueOf(1), Double.valueOf(1));
            assertAggregation(precisionRecallFunction, 0.0, bucketCounts, outcomes, predictions, weights);
            fail("Expected PrestoException");
        }
        catch (PrestoException e) {
            String message = e.getMessage().toLowerCase(Locale.ENGLISH);
            assertTrue(message.contains("bucket"));
        }
    }

    /**
     * Builds the three input columns (bucket count, outcome, prediction) for a
     * synthetic sequence of {@code length} rows.
     *
     * @param start sequence start; its absolute value is used
     * @param length number of rows to generate
     * @return blocks in the order bucket count, outcome, prediction
     */
    @Override
    public Block[] getSequenceBlocks(int start, int length)
    {
        final int first = Math.abs(start);
        final int end = first + length;

        BlockBuilder buckets = BIGINT.createBlockBuilder(null, length);
        BlockBuilder outcomes = BOOLEAN.createBlockBuilder(null, length);
        BlockBuilder predictions = DOUBLE.createBlockBuilder(null, length);

        int row = first;
        while (row < end) {
            Result sample = getResult(first, length, row);
            // Every row carries the same bucket count.
            BIGINT.writeLong(buckets, NUM_BINS);
            BOOLEAN.writeBoolean(outcomes, sample.outcome);
            DOUBLE.writeDouble(predictions, sample.prediction);
            row++;
        }

        return new Block[] {buckets.build(), outcomes.build(), predictions.build()};
    }

    /**
     * One synthetic sample: the outcome label and the prediction score that
     * {@code getResult} derives for a sequence position.
     */
    private static class Result
    {
        public final Boolean outcome;
        public final Double prediction;

        public Result(Boolean sampleOutcome, Double samplePrediction)
        {
            outcome = sampleOutcome;
            prediction = samplePrediction;
        }
    }

    /**
     * Expected per-bucket statistics produced by {@code getResultsIterator}.
     */
    protected static class BucketResult
    {
        /** Lower edge of the bucket. */
        public final Double left;
        /** Upper edge of the bucket. */
        public final Double right;
        /** Count of all samples with a true outcome. */
        public final Double totalTrueWeight;
        /** Count of all samples with a false outcome. */
        public final Double totalFalseWeight;
        /** Count of true-outcome samples whose prediction is at least {@code left}. */
        public final Double remainingTrueWeight;
        /** Count of false-outcome samples whose prediction is at least {@code left}. */
        public final Double remainingFalseWeight;

        public BucketResult(
                Double leftEdge,
                Double rightEdge,
                Double trueTotal,
                Double falseTotal,
                Double trueRemaining,
                Double falseRemaining)
        {
            left = leftEdge;
            right = rightEdge;
            totalTrueWeight = trueTotal;
            totalFalseWeight = falseTotal;
            remainingTrueWeight = trueRemaining;
            remainingFalseWeight = falseRemaining;
        }
    }

    /**
     * Returns an iterator over the expected per-bucket statistics for the
     * synthetic sequence generated by {@code getSequenceBlocks(start, length)}.
     *
     * <p>Buckets are visited in order of increasing left edge ({@code i / NUM_BINS}).
     * Iteration stops at the first bucket for which no true-outcome sample has a
     * prediction at or above the bucket's left edge. Every sample counts with
     * unit weight.
     *
     * @param start sequence start; its absolute value is used, as in {@code getSequenceBlocks}
     * @param length number of samples in the sequence
     * @return iterator of expected bucket results; {@code remove()} is unsupported
     */
    protected static Iterator<BucketResult> getResultsIterator(int start, int length)
    {
        // Normalize once and use the normalized value everywhere below, so that
        // both methods walk exactly the indices [effectiveStart, effectiveStart + length).
        final int effectiveStart = Math.abs(start);

        return new Iterator<BucketResult>()
        {
            // Index of the next bucket to report.
            int i;

            @Override
            public boolean hasNext()
            {
                final double left = (double) i / TestPrecisionRecallAggregation.NUM_BINS;
                for (int j = effectiveStart; j < effectiveStart + length; j++) {
                    final Result result =
                            TestPrecisionRecallAggregation.getResult(effectiveStart, length, j);
                    if (result.outcome && result.prediction >= left) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public BucketResult next()
            {
                final double left = (double) i / TestPrecisionRecallAggregation.NUM_BINS;
                final double right = (double) (i + 1) / TestPrecisionRecallAggregation.NUM_BINS;
                // Primitive accumulators avoid boxing on every increment.
                double totalTrue = 0.0;
                double totalFalse = 0.0;
                double remainingTrue = 0.0;
                double remainingFalse = 0.0;
                for (int j = effectiveStart; j < effectiveStart + length; j++) {
                    final Result result =
                            TestPrecisionRecallAggregation.getResult(effectiveStart, length, j);
                    // "Remaining" means the sample's prediction is at or above this bucket's left edge.
                    final boolean remaining = result.prediction >= left;
                    if (result.outcome) {
                        totalTrue += 1.0;
                        if (remaining) {
                            remainingTrue += 1.0;
                        }
                    }
                    else {
                        totalFalse += 1.0;
                        if (remaining) {
                            remainingFalse += 1.0;
                        }
                    }
                }
                i++;
                return new BucketResult(left, right, totalTrue, totalFalse, remainingTrue, remainingFalse);
            }

            @Override
            public void remove()
            {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * @param functionName name of the aggregation function under test; it is used
     *        for the lookup in {@code setUp()} and returned by {@code getFunctionName()}
     */
    protected TestPrecisionRecallAggregation(String functionName)
    {
        this.functionName = functionName;
    }

    /**
     * @return the function name passed to the constructor
     */
    @Override
    protected String getFunctionName()
    {
        return functionName;
    }

    /**
     * Declares the argument types of the three-argument form of the function:
     * bucket count, outcome, prediction.
     *
     * <p>The bucket count is BIGINT so that the declared signature agrees with
     * the BIGINT block built by {@code getSequenceBlocks} and with the BIGINT
     * lookup in {@code setUp()}.
     *
     * @return type names in argument order
     */
    @Override
    protected List<String> getFunctionParameterTypes()
    {
        return ImmutableList.of(
                StandardTypes.BIGINT,
                StandardTypes.BOOLEAN,
                StandardTypes.DOUBLE);
    }

    /**
     * Derives the synthetic sample at position {@code i} of a sequence.
     *
     * <p>The prediction is {@code (i - start) / (length + 1)}. The outcome is true
     * when the prediction is below {@code MIN_FALSE_PRED} or above
     * {@code MAX_FALSE_PRED}, and false otherwise.
     *
     * @param start first index of the sequence
     * @param length number of samples in the sequence
     * @param i index of the requested sample
     * @return the outcome and prediction for that index
     */
    protected static Result getResult(int start, int length, int i)
    {
        final double offset = i - start;
        final double prediction = offset / (length + 1);

        final boolean outcome;
        if (prediction < MIN_FALSE_PRED) {
            outcome = true;
        }
        else {
            outcome = prediction > MAX_FALSE_PRED;
        }
        return new Result(outcome, prediction);
    }
}