// WeightedDoubleReservoirSample.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation.differentialentropy;

import io.airlift.slice.SizeOf;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import org.openjdk.jol.info.ClassLayout;

import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * A bounded, weighted reservoir sample of {@code double} values.
 *
 * <p>Every added value is assigned an "adjusted weight" {@code u^(1 / weight)}, where {@code u}
 * is a uniform random draw in {@code [0, 1)} (the key used by the A-Res weighted reservoir
 * sampling scheme of Efraimidis and Spirakis). The reservoir retains the entries with the
 * largest adjusted weights. Entries are kept in two parallel arrays organized as a min-heap on
 * the adjusted weight, so the entry with the smallest adjusted weight is always in slot 0 and
 * is the one evicted when a better entry arrives.
 *
 * <p>NOTE(review): no synchronization is performed here; presumably instances are confined to
 * a single thread by the caller.
 */
public class WeightedDoubleReservoirSample
        implements Cloneable
{
    public static final int MAX_SAMPLES_LIMIT = 1_000_000;
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(WeightedDoubleReservoirSample.class).instanceSize();

    // Number of occupied slots; slots [0, count) of both arrays form the heap.
    private int count;
    // Retained sample values; samples[i] belongs with weights[i].
    private final double[] samples;
    // Adjusted weights (heap keys) of the retained samples, not the caller-supplied weights.
    private final double[] weights;
    // Sum of the caller-supplied weights of everything ever added or merged in.
    private double totalPopulationWeight;

    /**
     * Creates an empty reservoir.
     *
     * @param maxSamples capacity of the reservoir; must be in {@code [1, MAX_SAMPLES_LIMIT]}
     * @throws IllegalArgumentException if {@code maxSamples} is out of range
     */
    public WeightedDoubleReservoirSample(int maxSamples)
    {
        validateMaxSamples(maxSamples);

        this.samples = new double[maxSamples];
        this.weights = new double[maxSamples];
    }

    /**
     * Copy constructor backing {@link #clone()}; the arrays are copied, not shared.
     */
    private WeightedDoubleReservoirSample(WeightedDoubleReservoirSample other)
    {
        this.count = other.count;
        this.samples = Arrays.copyOf(other.samples, other.samples.length);
        this.weights = Arrays.copyOf(other.weights, other.weights.length);
        this.totalPopulationWeight = other.totalPopulationWeight;
    }

    /**
     * Adopts already-built state (used by {@link #deserialize(SliceInput)}). The arrays are
     * stored as given, without copying, and are expected to already be in heap order.
     */
    private WeightedDoubleReservoirSample(int count, double[] samples, double[] weights, double totalPopulationWeight)
    {
        this.count = count;
        this.samples = requireNonNull(samples, "samples is null");
        this.weights = requireNonNull(weights, "weights is null");
        this.totalPopulationWeight = totalPopulationWeight;
    }

    /**
     * Shared capacity check so the public constructor and deserialization enforce the same bounds.
     *
     * @throws IllegalArgumentException if {@code maxSamples} is not in {@code [1, MAX_SAMPLES_LIMIT]}
     */
    private static void validateMaxSamples(int maxSamples)
    {
        checkArgument(maxSamples > 0, format("Maximum number of samples must be positive: %s", maxSamples));
        checkArgument(
                maxSamples <= MAX_SAMPLES_LIMIT,
                format("Maximum number of samples must not exceed limit: %s %s", maxSamples, MAX_SAMPLES_LIMIT));
    }

    /**
     * @return the capacity of the reservoir (the length of the backing arrays)
     */
    public long getMaxSamples()
    {
        return samples.length;
    }

    /**
     * Offers a value to the reservoir.
     *
     * <p>The weight is always added to the total population weight, whether or not the value
     * ends up being retained.
     *
     * <p>NOTE(review): a weight of zero gives an exponent of {@code +Infinity} and therefore an
     * adjusted weight of 0; such an entry is still stored while the reservoir has free slots.
     *
     * @param sample the value to offer
     * @param weight the non-negative weight of the value
     * @throws IllegalArgumentException if {@code weight} is negative or NaN
     */
    public void add(double sample, double weight)
    {
        // Explicit guard rather than checkArgument(cond, format(...)): this method runs once per
        // input value, and the message must only be formatted when the check actually fails.
        // The negated comparison also rejects NaN, exactly as the original check did.
        if (!(weight >= 0)) {
            throw new IllegalArgumentException(format("Weight %s cannot be negative", weight));
        }
        totalPopulationWeight += weight;
        double adjustedWeight = Math.pow(
                ThreadLocalRandom.current().nextDouble(),
                1.0 / weight);
        addWithAdjustedWeight(sample, adjustedWeight);
    }

    /**
     * Inserts an entry whose adjusted weight has already been computed.
     *
     * <p>While there is free capacity the entry is always stored. Once full, the entry replaces
     * the current minimum (slot 0) only if its adjusted weight is strictly larger.
     */
    private void addWithAdjustedWeight(double sample, double adjustedWeight)
    {
        if (count < samples.length) {
            samples[count] = sample;
            weights[count] = adjustedWeight;
            count++;
            bubbleUp();
            return;
        }

        if (adjustedWeight <= weights[0]) {
            return;
        }

        samples[0] = sample;
        weights[0] = adjustedWeight;
        bubbleDown();
    }

    /**
     * Merges another reservoir into this one.
     *
     * <p>The other reservoir's total population weight is added to this one's, and each of its
     * retained entries is offered with its existing adjusted weight (no new random draw).
     * {@code other} itself is not modified.
     *
     * @param other the reservoir to merge in
     * @throws NullPointerException if {@code other} is null
     */
    public void mergeWith(WeightedDoubleReservoirSample other)
    {
        requireNonNull(other, "other is null");
        totalPopulationWeight += other.totalPopulationWeight;
        for (int i = 0; i < other.count; i++) {
            addWithAdjustedWeight(other.samples[i], other.weights[i]);
        }
    }

    /**
     * @return an independent copy of this reservoir (backing arrays are duplicated)
     */
    @Override
    public final WeightedDoubleReservoirSample clone()
    {
        return new WeightedDoubleReservoirSample(this);
    }

    /**
     * @return a copy of the currently retained sample values, in internal (heap) order
     */
    public double[] getSamples()
    {
        return Arrays.copyOf(samples, count);
    }

    /**
     * Exchanges the entries (value and adjusted weight together) at slots {@code i} and {@code j}.
     */
    private void swap(int i, int j)
    {
        double tmpElement = samples[i];
        double tmpWeight = weights[i];
        samples[i] = samples[j];
        weights[i] = weights[j];
        samples[j] = tmpElement;
        weights[j] = tmpWeight;
    }

    /**
     * Restores heap order after slot 0 has been overwritten, by sinking that entry while it is
     * larger than its smallest child.
     */
    private void bubbleDown()
    {
        int index = 0;
        while (leftChild(index) < count) {
            int smallestChildIndex = leftChild(index);

            if (rightChild(index) < count && weights[leftChild(index)] > weights[rightChild(index)]) {
                smallestChildIndex = rightChild(index);
            }

            // At index 0 the "left child" is slot 0 itself; if slot 1 is not smaller, this
            // compares the slot with itself, which is false and correctly terminates the loop.
            if (weights[index] > weights[smallestChildIndex]) {
                swap(index, smallestChildIndex);
            }
            else {
                break;
            }

            index = smallestChildIndex;
        }
    }

    /**
     * Restores heap order after an entry has been appended at slot {@code count - 1}, by
     * raising it while it is smaller than its parent.
     */
    private void bubbleUp()
    {
        int index = count - 1;
        while (index > 0 && weights[index] < weights[parent(index)]) {
            swap(index, parent(index));
            index = parent(index);
        }
    }

    // Heap layout: the index arithmetic below is the one usually used for 1-based heaps, applied
    // to a 0-based array. The effect is that slot 0 is the root with slot 1 as its only real
    // child, and every slot i >= 1 has children 2i and 2i + 1. The heap invariant still holds.
    // Do not "fix" this to the 0-based formulas: serialize() writes the arrays in heap order, so
    // changing the layout would misinterpret previously serialized state.

    private static int parent(int pos)
    {
        return pos / 2;
    }

    private static int leftChild(int pos)
    {
        return 2 * pos;
    }

    private static int rightChild(int pos)
    {
        return 2 * pos + 1;
    }

    /**
     * Reads a reservoir written by {@link #serialize(SliceOutput)}.
     *
     * <p>Layout: {@code int count}, {@code int maxSamples}, {@code count} sample doubles,
     * {@code count} adjusted-weight doubles, {@code double totalPopulationWeight}.
     *
     * @param input source positioned at the start of a serialized reservoir
     * @return the reconstructed reservoir, with arrays of length {@code maxSamples}
     * @throws IllegalArgumentException if {@code maxSamples} is not in
     *         {@code [1, MAX_SAMPLES_LIMIT]}, or {@code count} is negative or exceeds {@code maxSamples}
     */
    public static WeightedDoubleReservoirSample deserialize(SliceInput input)
    {
        int count = input.readInt();
        int maxSamples = input.readInt();
        // Validate both header fields before allocating or reading: a corrupt header must not
        // trigger a negative array size, a negative read length, or an arbitrarily large allocation.
        validateMaxSamples(maxSamples);
        checkArgument(count >= 0, "count must not be negative");
        checkArgument(count <= maxSamples, "count must not be larger than number of samples");
        double[] samples = new double[maxSamples];
        input.readBytes(Slices.wrappedDoubleArray(samples), count * SizeOf.SIZE_OF_DOUBLE);
        double[] weights = new double[maxSamples];
        input.readBytes(Slices.wrappedDoubleArray(weights), count * SizeOf.SIZE_OF_DOUBLE);
        double totalPopulationWeight = input.readDouble();
        return new WeightedDoubleReservoirSample(count, samples, weights, totalPopulationWeight);
    }

    /**
     * Writes this reservoir in the layout read by {@link #deserialize(SliceInput)}. Only the
     * {@code count} occupied slots of each array are written.
     *
     * @param output destination to append to
     */
    public void serialize(SliceOutput output)
    {
        output.appendInt(count);
        output.appendInt(samples.length);
        for (int i = 0; i < count; i++) {
            output.appendDouble(samples[i]);
        }
        for (int i = 0; i < count; i++) {
            output.appendDouble(weights[i]);
        }
        output.appendDouble(totalPopulationWeight);
    }

    /**
     * @return the number of bytes {@link #serialize(SliceOutput)} appends for the current state
     */
    public int getRequiredBytesForSerialization()
    {
        // count never exceeds samples.length, so it is the number of entries actually written.
        return SizeOf.SIZE_OF_INT + // count
                SizeOf.SIZE_OF_INT + // maxSamples
                2 * SizeOf.SIZE_OF_DOUBLE * count + // samples, weights
                SizeOf.SIZE_OF_DOUBLE; // totalPopulationWeight
    }

    /**
     * @return the instance size plus the sizes of the two backing arrays, in bytes
     */
    public long estimatedInMemorySize()
    {
        return INSTANCE_SIZE +
                SizeOf.sizeOf(samples) +
                SizeOf.sizeOf(weights);
    }

    /**
     * @return the sum of the weights of all values added directly or merged in
     */
    public double getTotalPopulationWeight()
    {
        return totalPopulationWeight;
    }
}