FixedHistogramMleStateStrategy.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation.differentialentropy;

import com.facebook.presto.operator.aggregation.fixedhistogram.FixedDoubleHistogram;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;

import static com.facebook.presto.operator.aggregation.differentialentropy.FixedHistogramStateStrategyUtils.getXLogX;
import static com.google.common.collect.Streams.stream;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * Differential-entropy state strategy backed by a {@link FixedDoubleHistogram}.
 * Weighted samples are accumulated into the histogram, and the sample entropy is
 * calculated from the bucket weights using MLE (maximum likelihood estimates).
 */
public class FixedHistogramMleStateStrategy
        implements DifferentialEntropyStateStrategy
{
    private final FixedDoubleHistogram histogram;

    /**
     * Builds a strategy around a newly constructed histogram.
     * The arguments are first passed to {@code FixedHistogramStateStrategyUtils.validateParameters};
     * {@code bucketCount} is then narrowed with {@code toIntExact}, which throws
     * {@link ArithmeticException} when the value does not fit in an {@code int}.
     *
     * @param bucketCount number of histogram buckets
     * @param min lower bound handed to the histogram
     * @param max upper bound handed to the histogram
     */
    public FixedHistogramMleStateStrategy(long bucketCount, double min, double max)
    {
        FixedHistogramStateStrategyUtils.validateParameters(bucketCount, min, max);
        this.histogram = new FixedDoubleHistogram(toIntExact(bucketCount), min, max);
    }

    /**
     * Copy constructor: the new instance owns a clone of the other instance's histogram.
     */
    private FixedHistogramMleStateStrategy(FixedHistogramMleStateStrategy other)
    {
        this.histogram = other.histogram.clone();
    }

    /**
     * Wraps an existing histogram (used by {@link #deserialize}); rejects {@code null}.
     */
    private FixedHistogramMleStateStrategy(FixedDoubleHistogram histogram)
    {
        this.histogram = requireNonNull(histogram, "histogram is null");
    }

    /**
     * Delegates to {@code FixedHistogramStateStrategyUtils.validateParameters}, passing this
     * histogram's bucket count, min and max followed by the incoming arguments.
     */
    @Override
    public void validateParameters(
            long bucketCount,
            double sample,
            double weight,
            double min,
            double max)
    {
        FixedHistogramStateStrategyUtils.validateParameters(
                this.histogram.getBucketCount(), this.histogram.getMin(), this.histogram.getMax(),
                bucketCount, sample, weight, min, max);
    }

    /**
     * Records one weighted sample in the histogram.
     */
    @Override
    public void add(double sample, double weight)
    {
        this.histogram.add(sample, weight);
    }

    /**
     * @return the sum of the weights of all histogram buckets
     */
    @Override
    public double getTotalPopulationWeight()
    {
        double totalWeight = stream(this.histogram.iterator())
                .mapToDouble(bucket -> bucket.getWeight())
                .sum();
        return totalWeight;
    }

    /**
     * Calculates the entropy estimate from the current bucket weights.
     * <p>
     * Each bucket weight is divided by the total weight and fed to {@code getXLogX}
     * (presumably x * ln(x) - confirm against FixedHistogramStateStrategyUtils); those
     * terms are subtracted from an accumulator, the natural log of the histogram's
     * width is added, and the result is divided by ln(2).
     *
     * @return the entropy estimate, or {@code NaN} when the total bucket weight is zero
     */
    @Override
    public double calculateEntropy()
    {
        double totalWeight = 0;
        for (FixedDoubleHistogram.Bucket current : this.histogram) {
            totalWeight += current.getWeight();
        }

        // Nothing has been accumulated, so there is no distribution to measure.
        if (totalWeight == 0.0) {
            return Double.NaN;
        }

        double negatedXLogXSum = 0;
        for (FixedDoubleHistogram.Bucket current : this.histogram) {
            double normalizedWeight = current.getWeight() / totalWeight;
            negatedXLogXSum -= getXLogX(normalizedWeight);
        }

        double logWidth = Math.log(this.histogram.getWidth());
        return (negatedXLogXSum + logWidth) / Math.log(2);
    }

    /**
     * @return the histogram's estimated in-memory size
     */
    @Override
    public long getEstimatedSize()
    {
        return this.histogram.estimatedInMemorySize();
    }

    /**
     * @return the number of bytes the histogram reports it needs for serialization
     */
    @Override
    public int getRequiredBytesForSpecificSerialization()
    {
        return this.histogram.getRequiredBytesForSerialization();
    }

    /**
     * Merges the other strategy's histogram into this one.
     * {@code other} is cast to {@code FixedHistogramMleStateStrategy} without a prior type check.
     */
    @Override
    public void mergeWith(DifferentialEntropyStateStrategy other)
    {
        FixedHistogramMleStateStrategy that = (FixedHistogramMleStateStrategy) other;
        this.histogram.mergeWith(that.histogram);
    }

    /**
     * Reads a {@link FixedDoubleHistogram} from {@code input} and wraps it in a new strategy.
     */
    public static FixedHistogramMleStateStrategy deserialize(SliceInput input)
    {
        return new FixedHistogramMleStateStrategy(FixedDoubleHistogram.deserialize(input));
    }

    /**
     * Writes the histogram to {@code out}.
     */
    @Override
    public void serialize(SliceOutput out)
    {
        this.histogram.serialize(out);
    }

    /**
     * @return a new strategy holding a clone of this strategy's histogram
     */
    @Override
    public DifferentialEntropyStateStrategy clone()
    {
        return new FixedHistogramMleStateStrategy(this);
    }

    /**
     * @return a new strategy constructed with this histogram's bucket count, min and max
     */
    @Override
    public DifferentialEntropyStateStrategy cloneEmpty()
    {
        return new FixedHistogramMleStateStrategy(
                this.histogram.getBucketCount(),
                this.histogram.getMin(),
                this.histogram.getMax());
    }
}