ReservoirSample.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.reservoirsample;
import com.facebook.presto.common.block.ArrayBlock;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import io.airlift.slice.SizeOf;
import org.openjdk.jol.info.ClassLayout;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.max;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * A bounded sample of values drawn from a stream using reservoir sampling.
 * <br>
 * Until {@link #getMaxSampleSize()} values have been processed every value is
 * kept; afterwards each new value replaces a randomly chosen slot with
 * probability {@code sampleSize / processedCount}. Two samples with the same
 * maximum size can be combined with {@link #merge(ReservoirSample)}.
 * <br>
 * The sample must be given a maximum size through {@link #tryInitialize(int)}
 * (or the deserializing constructor) before values can be added.
 */
public class ReservoirSample
{
    // NOTE(review): the shallow size is measured on SingleReservoirSampleState rather than on
    // ReservoirSample itself; confirm this is intentional before relying on the figure.
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleReservoirSampleState.class).instanceSize();

    /** Marker returned by {@link #nextSlot()} when the incoming value must not be stored. */
    private static final int DISCARD = -1;

    private final Type type;
    private final Type arrayType;

    /**
     * Represents the list of sampled values.
     * <br>
     * We use an {@link ArrayList} instead of {@link Block} because the
     * algorithm that generates reservoir samples requires shuffling of elements
     * in the reservoir.
     * <br>
     * The {@link Block} interface doesn't have any method for setting values at
     * arbitrary positions, so we resort to internally representing the sample
     * as a list and then combining the samples into a single block later.
     */
    private ArrayList<Block> samples;

    /** Maximum number of values retained; a negative value means "not initialized yet". */
    private int maxSampleSize = -1;

    /** Number of values offered to this reservoir so far (including merged-in reservoirs). */
    private long processedCount;

    private Block initialSample;

    /** Count associated with {@link #initialSample}; a negative value means "not set yet". */
    private long initialProcessedCount = -1;

    /**
     * Creates an empty, uninitialized sample for values of the given type.
     *
     * @param type the type of the sampled values
     * @throws NullPointerException if {@code type} is null
     */
    public ReservoirSample(Type type)
    {
        this.type = requireNonNull(type, "type is null");
        this.arrayType = new ArrayType(type);
        this.samples = new ArrayList<>();
    }

    /**
     * Restores a sample from previously captured state.
     *
     * @param type the type of the sampled values
     * @param processedCount number of values already processed by the restored sample
     * @param maxSampleSize maximum number of values the sample retains
     * @param samples block holding one position per retained value
     * @param initialSample the initial sample, may be null
     * @param initialSeenCount the processed count associated with {@code initialSample}
     * @throws NullPointerException if {@code type} or {@code samples} is null
     */
    protected ReservoirSample(Type type, long processedCount, int maxSampleSize, Block samples, Block initialSample, long initialSeenCount)
    {
        this.type = requireNonNull(type, "type is null");
        this.arrayType = new ArrayType(type);
        this.processedCount = processedCount;
        this.samples = blockToList(requireNonNull(samples, "samples is null"));
        this.maxSampleSize = maxSampleSize;
        initializeInitialSample(initialSample, initialSeenCount);
    }

    /**
     * Splits a block into a list with one entry per position.
     */
    private static ArrayList<Block> blockToList(Block inputBlock)
    {
        // Scalar values such as bigint/double may arrive in blocks (e.g. LongArrayBlock) that do
        // not implement Block::getBlock, so those are split with getSingleValueBlock. For an
        // ArrayBlock, getSingleValueBlock would return another ArrayBlock of size 1, whereas the
        // element block itself is needed, hence getBlock is used in that case.
        Function<Integer, Block> extractor = inputBlock instanceof ArrayBlock ? inputBlock::getBlock : inputBlock::getSingleValueBlock;
        return IntStream.range(0, inputBlock.getPositionCount())
                .mapToObj(extractor::apply)
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Builds a merged reservoir of {@code samples1.size()} values out of two reservoirs.
     * <br>
     * Values are consumed from the front of each list, so callers are expected to have shuffled
     * both lists. Each slot is taken from the first reservoir with probability proportional to
     * the number of not-yet-accounted-for values behind it; the counts are decremented after
     * every draw so that the split between the two sources is sampled without replacement.
     * <br>
     * Callers must ensure both lists hold at least {@code samples1.size()} values and that each
     * seen count is at least the size of its list.
     *
     * @param samples1 values of the first reservoir
     * @param samples2 values of the second reservoir
     * @param seenCount1 number of values processed by the first reservoir
     * @param seenCount2 number of values processed by the second reservoir
     * @return a new list containing the merged reservoir
     */
    private static ArrayList<Block> mergeBlockSamples(ArrayList<Block> samples1, ArrayList<Block> samples2, long seenCount1, long seenCount2)
    {
        int nextIndex = 0;
        int otherNextIndex = 0;
        long remaining1 = seenCount1;
        long remaining2 = seenCount2;
        ArrayList<Block> merged = new ArrayList<>(samples1.size());
        for (int i = 0; i < samples1.size(); i++) {
            if (ThreadLocalRandom.current().nextLong(0, remaining1 + remaining2) < remaining1) {
                merged.add(samples1.get(nextIndex++));
                remaining1--;
            }
            else {
                merged.add(samples2.get(otherNextIndex++));
                remaining2--;
            }
        }
        return merged;
    }

    /**
     * Sets the maximum sample size if the sample has not been initialized yet; otherwise does
     * nothing. A negative {@code n} leaves the sample uninitialized.
     *
     * @param n the maximum number of values to retain
     */
    public void tryInitialize(int n)
    {
        if (sampleNotInitialized()) {
            samples = new ArrayList<>(max(n, 0));
            maxSampleSize = n;
        }
    }

    /**
     * Records the initial sample and its processed count. Only takes effect while no
     * non-negative initial processed count has been recorded; later calls are ignored.
     *
     * @param initialSample the initial sample, may be null
     * @param initialProcessedCount the processed count associated with the initial sample
     * @throws IllegalArgumentException if a non-empty initial sample has more positions than
     * {@code initialProcessedCount}
     */
    public void initializeInitialSample(@Nullable Block initialSample, long initialProcessedCount)
    {
        if (this.initialProcessedCount >= 0) {
            return;
        }
        if (initialSample != null && initialSample.getPositionCount() > 0) {
            checkArgument(initialProcessedCount >= initialSample.getPositionCount(),
                    "initialProcessedCount must be greater than or equal " +
                            "to the number of positions in the initial sample");
        }
        this.initialSample = initialSample;
        this.initialProcessedCount = initialProcessedCount;
    }

    /**
     * Merges another sample, including its initial sample, into this one. A null argument is
     * ignored.
     *
     * @param other the sample to merge, may be null
     */
    public void mergeWith(@Nullable ReservoirSample other)
    {
        if (other == null) {
            return;
        }
        merge(other);
        initializeInitialSample(other.initialSample, other.initialProcessedCount);
    }

    private boolean sampleNotInitialized()
    {
        return maxSampleSize < 0 || samples == null;
    }

    /**
     * @return the number of values currently held, or 0 when the sample is not initialized
     */
    public int getSampleSize()
    {
        if (sampleNotInitialized()) {
            return 0;
        }
        return samples.size();
    }

    /**
     * @return the maximum number of values retained; negative when not initialized
     */
    public int getMaxSampleSize()
    {
        return maxSampleSize;
    }

    /**
     * Potentially add a value from a block at a given position into the sample.
     *
     * @param block the block containing the potential sample
     * @param position the position in the block to potentially insert
     * @throws IllegalArgumentException if the sample has not been initialized
     */
    public void add(Block block, int position)
    {
        if (sampleNotInitialized()) {
            throw new IllegalArgumentException("reservoir sample not properly initialized");
        }
        int slot = nextSlot();
        if (slot == DISCARD) {
            return;
        }
        // Only copy the value out of the source block once we know it will be kept.
        BlockBuilder sampleBlock = type.createBlockBuilder(null, 1);
        type.appendTo(block, position, sampleBlock);
        store(slot, sampleBlock.build());
    }

    /**
     * Offers an already materialized single-value block to the reservoir.
     */
    private void addSingleBlock(Block block)
    {
        int slot = nextSlot();
        if (slot != DISCARD) {
            store(slot, block);
        }
    }

    /**
     * Accounts for one more processed value and decides where it goes.
     *
     * @return {@code samples.size()} when the value must be appended, the index of the slot to
     * overwrite when it replaces an existing value, or {@link #DISCARD} when it is dropped
     */
    private int nextSlot()
    {
        processedCount++;
        if (processedCount <= getMaxSampleSize()) {
            return samples.size();
        }
        long index = ThreadLocalRandom.current().nextLong(0, processedCount);
        return index < samples.size() ? (int) index : DISCARD;
    }

    /**
     * Writes a value into the slot chosen by {@link #nextSlot()}.
     */
    private void store(int slot, Block value)
    {
        if (slot == samples.size()) {
            samples.add(value);
        }
        else {
            samples.set(slot, value);
        }
    }

    /**
     * Merges the values of another sample into this one. The other sample is left unmodified.
     * <br>
     * If this sample is uninitialized it adopts the other's maximum size. If the other sample
     * is uninitialized nothing else happens.
     *
     * @param other the sample to merge into this one
     * @throws IllegalArgumentException if both samples are initialized with different maximum
     * sizes
     */
    public void merge(ReservoirSample other)
    {
        if (sampleNotInitialized()) {
            tryInitialize(other.getMaxSampleSize());
        }
        if (other.sampleNotInitialized()) {
            return;
        }
        int capacity = getMaxSampleSize();
        if (capacity != other.getMaxSampleSize()) {
            // Build the message only on failure instead of on every merge.
            throw new IllegalArgumentException(
                    format("maximum number of samples %s must be equal to that of other %s", capacity, other.getMaxSampleSize()));
        }
        if (other.processedCount < capacity) {
            // The other reservoir never filled up, so it holds every value it processed:
            // replay them one by one as regular input.
            for (int i = 0; i < other.samples.size(); i++) {
                addSingleBlock(other.samples.get(i));
            }
            return;
        }
        if (processedCount < capacity) {
            // Symmetric case: this reservoir holds every value it processed. Start from a copy
            // of the other reservoir's state and replay our own values into it.
            ArrayList<Block> ownSamples = samples;
            samples = new ArrayList<>(other.samples);
            processedCount = other.processedCount;
            for (Block ownSample : ownSamples) {
                addSingleBlock(ownSample);
            }
            return;
        }
        // Both reservoirs are full. Shuffle a copy of the other's values so that the argument's
        // internal list is neither reordered nor shared with this instance.
        ArrayList<Block> otherSamples = new ArrayList<>(other.samples);
        Collections.shuffle(samples);
        Collections.shuffle(otherSamples);
        samples = mergeBlockSamples(samples, otherSamples, processedCount, other.processedCount);
        processedCount += other.processedCount;
    }

    /**
     * @return the type of the sampled values
     */
    public Type getType()
    {
        return type;
    }

    /**
     * @return the array type whose element type is {@link #getType()}
     */
    public Type getArrayType()
    {
        return arrayType;
    }

    /**
     * @return the number of values processed so far
     */
    public long getProcessedCount()
    {
        return processedCount;
    }

    /**
     * @return the initial sample, or null if none was recorded
     */
    public Block getInitialSample()
    {
        return initialSample;
    }

    /**
     * @return the processed count recorded with the initial sample; negative if none was set
     */
    public long getInitialProcessedCount()
    {
        return initialProcessedCount;
    }

    /**
     * Estimates the memory held by this sample: the instance size, the size in bytes of the
     * initial sample (if any) and the reference array backing the reservoir.
     * <br>
     * NOTE(review): the sizes of the individual sample blocks are not included; confirm whether
     * that is accounted for elsewhere.
     *
     * @return the estimated size in bytes
     */
    public long estimatedInMemorySize()
    {
        return INSTANCE_SIZE +
                (initialSample != null ? initialSample.getSizeInBytes() : 0) +
                SizeOf.sizeOfObjectArray(samples.size());
    }

    /**
     * Writes the state of this sample to {@code out} in the following order: the initial sample
     * (or null), the initial processed count, the processed count, the maximum sample size and
     * finally the array of sampled values.
     *
     * @param out the builder receiving the serialized fields
     */
    public void serialize(BlockBuilder out)
    {
        BlockBuilder sampleBlock = getSampleBlockBuilder();
        if (initialSample == null) {
            out.appendNull();
        }
        else {
            out.appendStructure(initialSample);
        }
        BIGINT.writeLong(out, initialProcessedCount);
        BIGINT.writeLong(out, processedCount);
        INTEGER.writeLong(out, maxSampleSize);
        arrayType.appendTo(sampleBlock.build(), 0, out);
    }

    /**
     * Collects the sampled values into a single array entry.
     *
     * @return a block builder of the array type containing exactly one entry whose elements are
     * the current samples, in reservoir order
     */
    BlockBuilder getSampleBlockBuilder()
    {
        int sampleSize = getSampleSize();
        BlockBuilder sampleBlock = arrayType.createBlockBuilder(null, sampleSize);
        BlockBuilder sampleEntryBuilder = sampleBlock.beginBlockEntry();
        for (int i = 0; i < sampleSize; i++) {
            // every stored sample is a single-value block, so its value lives at position 0
            type.appendTo(samples.get(i), 0, sampleEntryBuilder);
        }
        sampleBlock.closeEntry();
        return sampleBlock;
    }
}