PrecisionRecallAggregation.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.operator.aggregation.fixedhistogram.FixedDoubleHistogram;
import com.facebook.presto.operator.aggregation.state.PrecisionRecallState;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.google.common.collect.Streams;
import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
/**
 * Shared machinery for aggregations over binary-classification predictions.
 * <p>
 * Every input row carries an outcome (the actual label), a prediction in [0, 1] and an optional
 * weight. Weights are accumulated into two {@link FixedDoubleHistogram}s spanning [0, 1] with
 * {@code bucketCount} buckets: one for rows whose outcome is {@code true}, one for rows whose
 * outcome is {@code false}. The class is abstract and its helpers are protected, so concrete
 * subclasses are expected to turn the histograms into metrics via
 * {@link #getResultsIterator(PrecisionRecallState)}.
 */
public abstract class PrecisionRecallAggregation
{
    private static final double DEFAULT_WEIGHT = 1.0;
    private static final double MIN_PREDICTION_VALUE = 0.0;
    private static final double MAX_PREDICTION_VALUE = 1.0;
    // Predictions are clamped to this value before being added to a histogram, so that a
    // prediction of exactly MAX_PREDICTION_VALUE never reaches the bin corresponding exactly to 1.
    private static final double MAX_PREDICTION_VALUE_FOR_HISTOGRAM = 0.99999999999;
    private static final String ILLEGAL_PREDICTION_VALUE_MESSAGE = String.format(
            "Prediction value must be between %s and %s",
            MIN_PREDICTION_VALUE,
            MAX_PREDICTION_VALUE);
    private static final String NEGATIVE_WEIGHT_MESSAGE = "Weights must be non-negative";
    private static final String INCONSISTENT_BUCKET_COUNT_MESSAGE = "Bucket count must be constant";

    protected PrecisionRecallAggregation() {}

    /**
     * Adds one weighted prediction to the aggregation state.
     *
     * @param state aggregation state; its two histograms are created on the first accepted row
     * @param bucketCount number of histogram buckets; must be identical for every row of a group
     * @param outcome actual label of the row; selects the histogram that receives the weight
     * @param pred predicted score, required to lie in [0, 1]
     * @param weight weight of the row, required to be non-negative
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code pred} is NaN or
     *         outside [0, 1], if {@code weight} is NaN or negative, or if {@code bucketCount}
     *         differs from the bucket count of the state's existing histograms
     */
    @InputFunction
    public static void input(
            @AggregationState PrecisionRecallState state,
            @SqlType(StandardTypes.BIGINT) long bucketCount,
            @SqlType(StandardTypes.BOOLEAN) boolean outcome,
            @SqlType(StandardTypes.DOUBLE) double pred,
            @SqlType(StandardTypes.DOUBLE) double weight)
    {
        // Both checks are written as negated "valid range" tests: every comparison against NaN is
        // false, so the previous `pred < MIN || pred > MAX` / `weight < 0` forms let NaN through.
        if (!(pred >= MIN_PREDICTION_VALUE && pred <= MAX_PREDICTION_VALUE)) {
            throw new PrestoException(
                    INVALID_FUNCTION_ARGUMENT,
                    ILLEGAL_PREDICTION_VALUE_MESSAGE);
        }
        if (!(weight >= 0)) {
            throw new PrestoException(
                    INVALID_FUNCTION_ARGUMENT,
                    NEGATIVE_WEIGHT_MESSAGE);
        }

        if (state.getTrueWeights() == null) {
            state.setTrueWeights(new FixedDoubleHistogram(
                    (int) (bucketCount),
                    MIN_PREDICTION_VALUE,
                    MAX_PREDICTION_VALUE));
            state.setFalseWeights(new FixedDoubleHistogram(
                    (int) (bucketCount),
                    MIN_PREDICTION_VALUE,
                    MAX_PREDICTION_VALUE));
        }
        if (bucketCount != state.getTrueWeights().getBucketCount()) {
            throw new PrestoException(
                    INVALID_FUNCTION_ARGUMENT,
                    INCONSISTENT_BUCKET_COUNT_MESSAGE);
        }

        double histogramPred = Math.min(pred, MAX_PREDICTION_VALUE_FOR_HISTOGRAM);
        FixedDoubleHistogram target = outcome ? state.getTrueWeights() : state.getFalseWeights();
        target.add(histogramPred, weight);
    }

    /**
     * Adds one prediction with the default weight of 1.0.
     *
     * @see #input(PrecisionRecallState, long, boolean, double, double)
     */
    @InputFunction
    public static void input(
            @AggregationState PrecisionRecallState state,
            @SqlType(StandardTypes.BIGINT) long bucketCount,
            @SqlType(StandardTypes.BOOLEAN) boolean outcome,
            @SqlType(StandardTypes.DOUBLE) double pred)
    {
        PrecisionRecallAggregation.input(state, bucketCount, outcome, pred, DEFAULT_WEIGHT);
    }

    /**
     * Merges {@code otherState} into {@code state}.
     * <p>
     * An empty {@code otherState} (no histograms) leaves {@code state} unchanged. When
     * {@code state} is empty it receives clones of the other state's histograms, so the two states
     * do not share mutable histograms. Otherwise the histograms are merged bucket-wise via
     * {@code mergeWith}.
     */
    @CombineFunction
    public static void combine(
            @AggregationState PrecisionRecallState state,
            @AggregationState PrecisionRecallState otherState)
    {
        if (otherState.getTrueWeights() == null) {
            return;
        }
        if (state.getTrueWeights() == null) {
            state.setTrueWeights(otherState.getTrueWeights().clone());
            state.setFalseWeights(otherState.getFalseWeights().clone());
        }
        else {
            state.getTrueWeights().mergeWith(otherState.getTrueWeights());
            state.getFalseWeights().mergeWith(otherState.getFalseWeights());
        }
    }

    /**
     * Confusion-matrix weights for one candidate threshold.
     * <p>
     * As produced by {@link #getResultsIterator(PrecisionRecallState)}, weight in the bucket whose
     * left edge is {@code threshold} and in all later buckets counts as predicted positive. Weight
     * in earlier buckets counts as predicted negative.
     */
    protected static class BucketResult
    {
        private final double threshold;
        private final double positive;
        private final double negative;
        private final double truePositive;
        private final double trueNegative;
        private final double falsePositive;
        private final double falseNegative;

        public BucketResult(
                double threshold,
                double positive,
                double negative,
                double truePositive,
                double trueNegative,
                double falsePositive,
                double falseNegative)
        {
            this.threshold = threshold;
            this.positive = positive;
            this.negative = negative;
            this.truePositive = truePositive;
            this.trueNegative = trueNegative;
            this.falsePositive = falsePositive;
            this.falseNegative = falseNegative;
        }

        /** Left edge of the bucket this result was computed for. */
        public double getThreshold()
        {
            return threshold;
        }

        /** Total weight of rows whose outcome is {@code true}. */
        public double getPositive()
        {
            return positive;
        }

        /** Total weight of rows whose outcome is {@code false}. */
        public double getNegative()
        {
            return negative;
        }

        public double getTruePositive()
        {
            return truePositive;
        }

        public double getTrueNegative()
        {
            return trueNegative;
        }

        public double getFalsePositive()
        {
            return falsePositive;
        }

        public double getFalseNegative()
        {
            return falseNegative;
        }
    }

    /**
     * Returns one {@link BucketResult} per histogram bucket, in the histograms' iteration order.
     * <p>
     * The metrics treat earlier buckets as lying below the threshold, so they presume that
     * iteration is in ascending prediction order. NOTE(review): confirm against
     * {@code FixedDoubleHistogram.iterator()}.
     * <p>
     * Iteration ends when the buckets run out, or as soon as no {@code true}-outcome weight remains
     * in the current or later buckets. It is empty when the state holds no histograms or when the
     * total {@code true}-outcome weight is zero.
     */
    protected static Iterator<BucketResult> getResultsIterator(@AggregationState PrecisionRecallState state)
    {
        if (state.getTrueWeights() == null) {
            return Collections.<BucketResult>emptyList().iterator();
        }

        double totalTrueWeight = sumWeights(state.getTrueWeights());
        double totalFalseWeight = sumWeights(state.getFalseWeights());

        return new Iterator<BucketResult>()
        {
            private final Iterator<FixedDoubleHistogram.Bucket> trueIterator = state.getTrueWeights().iterator();
            private final Iterator<FixedDoubleHistogram.Bucket> falseIterator = state.getFalseWeights().iterator();
            // Weight of the buckets already consumed, i.e. those before the current threshold.
            private double runningFalseWeight;
            private double runningTrueWeight;

            @Override
            public boolean hasNext()
            {
                // Stop once all positive weight lies below the next threshold.
                return trueIterator.hasNext() && totalTrueWeight > runningTrueWeight;
            }

            @Override
            public BucketResult next()
            {
                // Honour the Iterator contract: throw exactly when hasNext() reports exhaustion,
                // including the weight-exhausted case, not only when the buckets run out.
                if (!hasNext() || !falseIterator.hasNext()) {
                    throw new NoSuchElementException();
                }
                FixedDoubleHistogram.Bucket trueResult = trueIterator.next();
                FixedDoubleHistogram.Bucket falseResult = falseIterator.next();

                BucketResult result = new BucketResult(
                        trueResult.getLeft(),
                        totalTrueWeight,
                        totalFalseWeight,
                        totalTrueWeight - runningTrueWeight,
                        runningFalseWeight,
                        totalFalseWeight - runningFalseWeight,
                        runningTrueWeight);

                runningTrueWeight += trueResult.getWeight();
                runningFalseWeight += falseResult.getWeight();
                return result;
            }

            @Override
            public void remove()
            {
                throw new UnsupportedOperationException();
            }
        };
    }

    /** Sums the weights of all buckets of {@code histogram}. */
    private static double sumWeights(FixedDoubleHistogram histogram)
    {
        return Streams.stream(histogram.iterator())
                .mapToDouble(FixedDoubleHistogram.Bucket::getWeight)
                .sum();
    }
}