UnweightedDoubleReservoirSample.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.differentialentropy;
import io.airlift.slice.SizeOf;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import org.openjdk.jol.info.ClassLayout;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * A fixed-capacity uniform random sample ("reservoir") of {@code double} values.
 * <p>
 * Values are offered one at a time through {@link #add(double)}, which implements reservoir
 * sampling (Algorithm R): the first {@code maxSamples} values are stored directly, and each later
 * value replaces a uniformly chosen slot with probability {@code maxSamples / seenCount}.
 * Two reservoirs of equal capacity can be combined with
 * {@link #mergeWith(UnweightedDoubleReservoirSample)}.
 * <p>
 * Instances are mutable and perform no internal synchronization.
 */
public class UnweightedDoubleReservoirSample
        implements Cloneable
{
    /** Largest reservoir capacity accepted by the constructor and by {@link #deserialize(SliceInput)}. */
    public static final int MAX_SAMPLES_LIMIT = 1_000_000;

    private static final int INSTANCE_SIZE = ClassLayout.parseClass(UnweightedDoubleReservoirSample.class).instanceSize();

    // Number of values offered to this reservoir so far, including values merged in from other reservoirs.
    private int seenCount;
    // Reservoir storage; its length is the capacity. Only the first min(seenCount, samples.length) slots hold data.
    private double[] samples;

    /**
     * Creates an empty reservoir.
     *
     * @param maxSamples reservoir capacity; must be in {@code [1, MAX_SAMPLES_LIMIT]}
     * @throws IllegalArgumentException if {@code maxSamples} is out of range
     */
    public UnweightedDoubleReservoirSample(int maxSamples)
    {
        checkArgument(
                maxSamples > 0,
                format("Maximum number of samples must be positive: %s", maxSamples));
        checkArgument(
                maxSamples <= MAX_SAMPLES_LIMIT,
                format("Maximum number of samples must not exceed maximum: %s %s", maxSamples, MAX_SAMPLES_LIMIT));
        this.samples = new double[maxSamples];
    }

    // Deep copy: the sample array is duplicated so the copy can be mutated independently of the original.
    private UnweightedDoubleReservoirSample(UnweightedDoubleReservoirSample other)
    {
        this.seenCount = other.seenCount;
        this.samples = Arrays.copyOf(requireNonNull(other.samples, "samples is null"), other.samples.length);
    }

    // Takes ownership of the given array (no copy); used by deserialize().
    private UnweightedDoubleReservoirSample(int seenCount, double[] samples)
    {
        this.seenCount = seenCount;
        this.samples = samples;
    }

    /**
     * Returns the reservoir capacity.
     */
    public int getMaxSamples()
    {
        return samples.length;
    }

    /**
     * Offers one value to the reservoir.
     *
     * @param sample the value to offer
     * @throws ArithmeticException if the number of values seen would exceed {@link Integer#MAX_VALUE}
     */
    public void add(double sample)
    {
        // Fail loudly rather than wrapping to a negative count, which would corrupt the indexing below.
        seenCount = Math.incrementExact(seenCount);
        if (seenCount <= samples.length) {
            // Reservoir not yet full: keep every value.
            samples[seenCount - 1] = sample;
            return;
        }
        // Reservoir full: the seenCount-th value replaces a random slot with probability samples.length / seenCount.
        int index = ThreadLocalRandom.current().nextInt(0, seenCount);
        if (index < samples.length) {
            samples[index] = sample;
        }
    }

    /**
     * Merges {@code other} into this reservoir. Afterwards this reservoir is a uniform sample of
     * the union of both populations, and its population count is the sum of both counts.
     * <p>
     * When both reservoirs are full, {@code other}'s sample array is shuffled in place. Its
     * contents as a multiset and its population count are unchanged.
     *
     * @param other reservoir to merge in; must have the same capacity as this one
     * @throws IllegalArgumentException if the capacities differ
     * @throws ArithmeticException if the combined population count exceeds {@link Integer#MAX_VALUE}
     */
    public void mergeWith(UnweightedDoubleReservoirSample other)
    {
        checkArgument(
                samples.length == other.samples.length,
                format("Maximum number of samples %s must be equal to that of other %s", samples.length, other.samples.length));

        if (other.seenCount < other.samples.length) {
            // other still holds every value it has seen: replay them through the normal sampling path.
            for (int i = 0; i < other.seenCount; i++) {
                add(other.samples[i]);
            }
            return;
        }

        if (seenCount < samples.length) {
            // This reservoir still holds every value it has seen: replay them into a copy of the (full) other.
            UnweightedDoubleReservoirSample target = other.clone();
            for (int i = 0; i < seenCount; i++) {
                target.add(samples[i]);
            }
            seenCount = target.seenCount;
            samples = target.samples;
            return;
        }

        // Both reservoirs are full. Check the combined count up front so that all arithmetic below fits in an int.
        int mergedSeenCount = Math.addExact(seenCount, other.seenCount);

        // After shuffling, any prefix of each reservoir is itself a uniform sample of that reservoir's population.
        shuffleArray(samples);
        shuffleArray(other.samples);

        // Draw samples.length elements without replacement from the union of both populations. At each step,
        // take the next element from this side with probability remaining / (remaining + otherRemaining).
        // This makes the number taken from each side hypergeometric, as required for a uniform sample of the union.
        // Both populations are >= samples.length, so neither side's prefix is exhausted before merged is filled.
        int remaining = seenCount;
        int otherRemaining = other.seenCount;
        int nextIndex = 0;
        int otherNextIndex = 0;
        double[] merged = new double[samples.length];
        for (int i = 0; i < merged.length; i++) {
            if (ThreadLocalRandom.current().nextInt(0, remaining + otherRemaining) < remaining) {
                merged[i] = samples[nextIndex++];
                remaining--;
            }
            else {
                merged[i] = other.samples[otherNextIndex++];
                otherRemaining--;
            }
        }

        seenCount = mergedSeenCount;
        samples = merged;
    }

    /**
     * Returns the number of values offered to this reservoir, including merged-in values.
     */
    public int getTotalPopulationCount()
    {
        return seenCount;
    }

    /**
     * Returns an independent deep copy of this reservoir.
     */
    @Override
    public final UnweightedDoubleReservoirSample clone()
    {
        return new UnweightedDoubleReservoirSample(this);
    }

    /**
     * Returns a copy of the currently retained samples: {@code min(seenCount, maxSamples)} values.
     */
    public double[] getSamples()
    {
        return Arrays.copyOf(samples, Math.min(seenCount, samples.length));
    }

    // In-place Fisher-Yates shuffle of the whole array.
    private static void shuffleArray(double[] samples)
    {
        for (int i = samples.length - 1; i > 0; i--) {
            int index = ThreadLocalRandom.current().nextInt(0, i + 1);
            double sample = samples[index];
            samples[index] = samples[i];
            samples[i] = sample;
        }
    }

    /**
     * Reads a reservoir in the format written by {@link #serialize(SliceOutput)}.
     *
     * @throws IllegalArgumentException if the header contains a negative count or an out-of-range capacity
     */
    public static UnweightedDoubleReservoirSample deserialize(SliceInput input)
    {
        int seenCount = input.readInt();
        int maxSamples = input.readInt();
        // Validate the header before allocating, so corrupt input fails with a clear message
        // instead of NegativeArraySizeException or an enormous allocation.
        checkArgument(seenCount >= 0, format("Seen count must not be negative: %s", seenCount));
        checkArgument(
                maxSamples > 0 && maxSamples <= MAX_SAMPLES_LIMIT,
                format("Maximum number of samples must be in [1, %s]: %s", MAX_SAMPLES_LIMIT, maxSamples));
        double[] samples = new double[maxSamples];
        input.readBytes(Slices.wrappedDoubleArray(samples), Math.min(seenCount, samples.length) * SizeOf.SIZE_OF_DOUBLE);
        return new UnweightedDoubleReservoirSample(seenCount, samples);
    }

    /**
     * Writes this reservoir as: {@code int seenCount}, {@code int maxSamples}, then the
     * {@code min(seenCount, maxSamples)} retained samples as doubles.
     */
    public void serialize(SliceOutput output)
    {
        output.appendInt(seenCount);
        output.appendInt(samples.length);
        for (int i = 0; i < Math.min(seenCount, samples.length); i++) {
            output.appendDouble(samples[i]);
        }
    }

    /**
     * Returns the exact number of bytes {@link #serialize(SliceOutput)} will write.
     */
    public int getRequiredBytesForSerialization()
    {
        return SizeOf.SIZE_OF_INT + // seenCount
                SizeOf.SIZE_OF_INT + // maxSamples
                SizeOf.SIZE_OF_DOUBLE * Math.min(seenCount, samples.length); // samples
    }

    /**
     * Returns the estimated retained heap size of this object: the instance plus its sample array.
     */
    public long estimatedInMemorySize()
    {
        return INSTANCE_SIZE +
                SizeOf.sizeOf(samples);
    }
}