FixedHistogramJacknifeStateStrategy.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.differentialentropy;
import com.facebook.presto.operator.aggregation.fixedhistogram.FixedDoubleBreakdownHistogram;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import java.util.Map;
import static com.facebook.presto.operator.aggregation.differentialentropy.EntropyCalculations.calculateEntropyFromHistogramAggregates;
import static com.facebook.presto.operator.aggregation.differentialentropy.FixedHistogramStateStrategyUtils.getXLogX;
import static com.google.common.collect.Streams.stream;
import static java.lang.Math.toIntExact;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.summingDouble;
/**
* Calculates sample entropy using jacknife estimates on a fixed histogram.
* See http://cs.brown.edu/~pvaliant/unseen_nips.pdf.
*/
public class FixedHistogramJacknifeStateStrategy
implements DifferentialEntropyStateStrategy
{
private final FixedDoubleBreakdownHistogram histogram;
public FixedHistogramJacknifeStateStrategy(long bucketCount, double min, double max)
{
FixedHistogramStateStrategyUtils.validateParameters(
bucketCount,
min,
max);
histogram = new FixedDoubleBreakdownHistogram(toIntExact(bucketCount), min, max);
}
private FixedHistogramJacknifeStateStrategy(FixedDoubleBreakdownHistogram histogram)
{
this.histogram = histogram;
}
private FixedHistogramJacknifeStateStrategy(FixedHistogramJacknifeStateStrategy other)
{
histogram = other.histogram.clone();
}
@Override
public void validateParameters(long bucketCount, double sample, double weight, double min, double max)
{
FixedHistogramStateStrategyUtils.validateParameters(
histogram.getBucketCount(),
histogram.getMin(),
histogram.getMax(),
bucketCount,
sample,
weight,
min,
max);
}
@Override
public void mergeWith(DifferentialEntropyStateStrategy other)
{
histogram.mergeWith(((FixedHistogramJacknifeStateStrategy) other).histogram);
}
@Override
public void add(double value, double weight)
{
histogram.add(value, weight);
}
@Override
public double getTotalPopulationWeight()
{
return stream(histogram.iterator())
.mapToDouble(FixedDoubleBreakdownHistogram.Bucket::getWeight)
.sum();
}
@Override
public double calculateEntropy()
{
Map<Double, Double> bucketWeights = stream(histogram.iterator()).collect(
groupingBy(
FixedDoubleBreakdownHistogram.Bucket::getLeft,
summingDouble(e -> e.getCount() * e.getWeight())));
double sumWeight = bucketWeights.values().stream().mapToDouble(Double::doubleValue).sum();
if (sumWeight == 0.0) {
return Double.NaN;
}
long n = stream(histogram.iterator())
.mapToLong(FixedDoubleBreakdownHistogram.Bucket::getCount)
.sum();
double sumWeightLogWeight =
bucketWeights.values().stream().mapToDouble(w -> w == 0.0 ? 0.0 : w * Math.log(w)).sum();
double entropy = n * calculateEntropyFromHistogramAggregates(histogram.getWidth(), sumWeight, sumWeightLogWeight);
for (FixedDoubleBreakdownHistogram.Bucket bucketWeight : histogram) {
double weight = bucketWeights.get(bucketWeight.getLeft());
if (weight > 0.0) {
entropy -= getHoldOutEntropy(
n,
bucketWeight.getRight() - bucketWeight.getLeft(),
sumWeight,
sumWeightLogWeight,
weight,
bucketWeight.getWeight(),
bucketWeight.getCount());
}
}
return entropy;
}
private static double getHoldOutEntropy(
long n,
double width,
double sumW,
double sumWeightLogWeight,
double bucketWeight,
double entryWeight,
long entryMultiplicity)
{
double holdoutBucketWeight = Math.max(bucketWeight - entryWeight, 0);
double holdoutSumWeight =
sumW - bucketWeight + holdoutBucketWeight;
double holdoutSumWeightLogWeight =
sumWeightLogWeight - getXLogX(bucketWeight) + getXLogX(holdoutBucketWeight);
double holdoutEntropy = entryMultiplicity * (n - 1) *
calculateEntropyFromHistogramAggregates(width, holdoutSumWeight, holdoutSumWeightLogWeight) /
n;
return holdoutEntropy;
}
@Override
public long getEstimatedSize()
{
return histogram.estimatedInMemorySize();
}
@Override
public int getRequiredBytesForSpecificSerialization()
{
return histogram.getRequiredBytesForSerialization();
}
public static FixedHistogramJacknifeStateStrategy deserialize(SliceInput input)
{
FixedDoubleBreakdownHistogram histogram = FixedDoubleBreakdownHistogram.deserialize(input);
return new FixedHistogramJacknifeStateStrategy(histogram);
}
@Override
public void serialize(SliceOutput out)
{
histogram.serialize(out);
}
@Override
public DifferentialEntropyStateStrategy clone()
{
return new FixedHistogramJacknifeStateStrategy(this);
}
@Override
public DifferentialEntropyStateStrategy cloneEmpty()
{
return new FixedHistogramJacknifeStateStrategy(histogram.getBucketCount(), histogram.getMin(), histogram.getMax());
}
}