StreamSummary.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.approxmostfrequent.stream;
import com.facebook.presto.common.array.IntBigArray;
import com.facebook.presto.common.array.LongBigArray;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.type.TypeUtils;
import com.google.common.annotations.VisibleForTesting;
import org.openjdk.jol.info.ClassLayout;
import java.util.List;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES;
import static com.google.common.base.Preconditions.checkArgument;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
import static java.lang.Math.toIntExact;
/**
 * Bounded-memory frequency summary over a stream of values of a single {@link Type}, used to answer
 * "most frequent values" queries approximately.
 * <p>
 * At most {@code heapCapacity} distinct values are tracked at once. When an untracked value arrives and
 * the heap is full, the tracked value with the lowest count (the older one on ties) is evicted. The new
 * value starts from the evicted value's count plus its own increment, which is the replacement rule of the
 * Space-Saving algorithm. Counts of values that entered through an eviction therefore may overestimate
 * their true frequency.
 * <p>
 * Internal layout:
 * <ul>
 * <li>{@code heapBlockBuilder} stores the values. It is append-only: evicted values stay in it until
 * {@code compact()} copies the live ones into a fresh builder.</li>
 * <li>{@code hashToBlockPosition} is an open-addressing, linear-probing table that maps a hash slot to a
 * position in {@code heapBlockBuilder}.</li>
 * <li>{@code blockPositionToCount} and {@code blockToHeapIndex} are indexed by block position.</li>
 * <li>{@code minHeap} holds one entry per tracked value, ordered by count and then by generation. Each
 * entry refers to its value through its hash slot.</li>
 * </ul>
 * Only {@code compact()} is {@code synchronized}; no other method here locks, so instances should be treated
 * as not thread-safe.
 */
public class StreamSummary
        implements PriorityQueueDataChangeListener
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(StreamSummary.class).instanceSize();
    // compaction is considered only once the value block has reached at least this many bytes...
    private static final int COMPACT_THRESHOLD_BYTES = 32768;
    // target load factor of the hash table
    private static final float FILL_RATIO = 0.75f;
    // ...and only when the block holds at least this many positions per live heap entry
    private static final int COMPACT_THRESHOLD_RATIO = 3;
    // hash slot that has never held a value since the last rehash; a probe sequence stops here
    private static final int EMPTY = -1;
    // hash slot whose value was evicted; a probe sequence must continue past it
    private static final int DELETE_MARKER = -2;

    private final Type type;
    // maximum number of distinct values tracked at once (heap size limit)
    private final int heapCapacity;
    // number of highest-count entries requested from the heap when results are read
    private final int maxBuckets;
    // block position count at which the hash table is grown
    private int maxFill;
    // hashCapacity - 1, used to wrap hash slots
    private int mask;
    // incremented on every insert/update; breaks count ties so older entries compare smaller
    private int generation;
    // count per position in heapBlockBuilder (0 for evicted positions)
    private LongBigArray blockPositionToCount;
    // hash slot -> position in heapBlockBuilder, or EMPTY / DELETE_MARKER
    private IntBigArray hashToBlockPosition;
    private int hashCapacity;
    // append-only storage for values; evicted values remain until compact()
    private BlockBuilder heapBlockBuilder;
    private final IndexedPriorityQueue minHeap;
    // position in heapBlockBuilder -> index in minHeap (EMPTY for evicted positions)
    private IntBigArray blockToHeapIndex;

    /**
     * Creates an empty summary.
     *
     * @param type type of the values being counted
     * @param maxBuckets number of highest-count entries requested from the heap by {@link #topK},
     * {@link #readAllValues} and {@link #serialize}
     * @param heapCapacity maximum number of distinct values tracked at once
     */
    public StreamSummary(
            Type type,
            int maxBuckets,
            int heapCapacity)
    {
        this.type = type;
        this.maxBuckets = maxBuckets;
        this.heapCapacity = heapCapacity;
        this.blockPositionToCount = new LongBigArray();
        this.blockToHeapIndex = new IntBigArray();
        this.hashToBlockPosition = new IntBigArray(EMPTY);
        this.hashCapacity = arraySize(heapCapacity, FILL_RATIO);
        this.hashToBlockPosition.ensureCapacity(hashCapacity);
        this.heapBlockBuilder = type.createBlockBuilder(null, heapCapacity);
        // the comparator (this::compare) reads blockPositionToCount and hashToBlockPosition, both created above
        this.minHeap = new IndexedPriorityQueue(heapCapacity, this::compare, this);
        this.mask = hashCapacity - 1;
        this.maxFill = calculateMaxFill(hashCapacity);
        this.blockPositionToCount.ensureCapacity(hashCapacity);
        this.blockToHeapIndex.ensureCapacity(hashCapacity);
    }

    /**
     * Adds {@code incrementCount} occurrences of the value at {@code blockPosition} in {@code block}.
     * If the value is already tracked, its count is increased. Otherwise it is inserted, evicting the
     * lowest-count entry when the heap is full.
     */
    public void add(Block block, int blockPosition, long incrementCount)
    {
        int hashPosition = getBucketId(TypeUtils.hashPosition(type, block, blockPosition), mask);
        // Linear probing: stop at an EMPTY slot (value not tracked) or at the slot holding this value.
        // DELETE_MARKER slots are skipped so values whose probe chain ran through an evicted slot stay reachable.
        while (true) {
            int bucketPosition = hashToBlockPosition.get(hashPosition);
            if (bucketPosition == EMPTY) {
                break;
            }
            if (bucketPosition != DELETE_MARKER && type.equalTo(block, blockPosition, heapBlockBuilder, bucketPosition)) {
                blockPositionToCount.add(bucketPosition, incrementCount);
                int heapIndex = blockToHeapIndex.get(bucketPosition);
                minHeap.get(heapIndex).setGeneration(generation++);
                // with incrementCount >= 0 (assumed), a larger count and newer generation can only move the
                // entry away from the heap root
                minHeap.percolateDown(heapIndex);
                return;
            }
            // increment position and mask to handle wrap around
            hashPosition = (hashPosition + 1) & mask;
        }
        addNewGroup(block, blockPosition, hashPosition, incrementCount);
    }

    /**
     * Inserts a value that is not currently tracked into the EMPTY hash slot {@code hashPosition}.
     * When the heap is full, the current minimum is evicted and its count is carried over to the new value.
     */
    private void addNewGroup(Block block, int blockPosition, int hashPosition, long incrementCount)
    {
        int newElementBlockPosition = heapBlockBuilder.getPositionCount();
        if (minHeap.isFull()) {
            // replace min: the new value inherits the evicted count
            StreamDataEntity min = minHeap.getMin();
            int removedBlock = getBlockPosition(min);
            long minCount = blockPositionToCount.get(removedBlock);
            handleDelete(removedBlock, min.getHashPosition());
            hashToBlockPosition.set(hashPosition, newElementBlockPosition);
            blockPositionToCount.set(newElementBlockPosition, minCount + incrementCount);
            minHeap.replaceMin(new StreamDataEntity(hashPosition, generation++));
        }
        else {
            hashToBlockPosition.set(hashPosition, newElementBlockPosition);
            blockPositionToCount.set(newElementBlockPosition, incrementCount);
            minHeap.add(new StreamDataEntity(hashPosition, generation++));
        }
        type.appendTo(block, blockPosition, heapBlockBuilder);
        compactAndRehashIfNeeded();
    }

    /**
     * Forgets an evicted value. The hash slot becomes a tombstone rather than EMPTY so that
     * probe chains passing through it are not cut short.
     */
    private void handleDelete(int removedBlock, int removedHashPosition)
    {
        blockPositionToCount.set(removedBlock, 0);
        blockToHeapIndex.set(removedBlock, EMPTY);
        hashToBlockPosition.set(removedHashPosition, DELETE_MARKER);
    }

    /**
     * Compacts the value block when it holds mostly evicted values. Otherwise it grows the hash table
     * once the number of block positions reaches {@code maxFill}.
     */
    private void compactAndRehashIfNeeded()
    {
        if (shouldCompact(heapBlockBuilder.getSizeInBytes(), heapBlockBuilder.getPositionCount())) {
            compact();
        }
        else {
            if (heapBlockBuilder.getPositionCount() >= maxFill) {
                rehash();
            }
        }
    }

    /**
     * Returns whether the value block is large (at least {@code COMPACT_THRESHOLD_BYTES}) and holds at least
     * {@code COMPACT_THRESHOLD_RATIO} positions per live heap entry (integer division).
     * Callers must ensure the heap is non-empty.
     */
    protected boolean shouldCompact(long sizeInBytes, int numberOfPositionInBlock)
    {
        return sizeInBytes >= COMPACT_THRESHOLD_BYTES && numberOfPositionInBlock / getHeapSize() >= COMPACT_THRESHOLD_RATIO;
    }

    /**
     * Returns the number of values currently tracked.
     */
    @VisibleForTesting
    public int getHeapSize()
    {
        return minHeap.getSize();
    }

    /**
     * Copies the live values (one per heap entry) into a fresh block builder. New block position i holds
     * heap entry i. All block-position-indexed structures are rebuilt, and the hash table is then rebuilt
     * from the base capacity.
     */
    private synchronized void compact()
    {
        BlockBuilder newHeapBlockBuilder = type.createBlockBuilder(null, heapBlockBuilder.getPositionCount());
        //since block positions are changed, we need to update all data structures which are using block position as reference
        LongBigArray newBlockPositionToCount = new LongBigArray();
        hashCapacity = arraySize(heapCapacity, FILL_RATIO);
        maxFill = calculateMaxFill(hashCapacity);
        newBlockPositionToCount.ensureCapacity(hashCapacity);
        IntBigArray newBlockToHeapIndex = new IntBigArray();
        newBlockToHeapIndex.ensureCapacity(hashCapacity);
        for (int heapPosition = 0; heapPosition < getHeapSize(); heapPosition++) {
            int newBlockPos = newHeapBlockBuilder.getPositionCount();
            StreamDataEntity heapEntry = minHeap.get(heapPosition);
            int oldBlockPosition = getBlockPosition(heapEntry);
            type.appendTo(heapBlockBuilder, oldBlockPosition, newHeapBlockBuilder);
            newBlockPositionToCount.set(newBlockPos, blockPositionToCount.get(oldBlockPosition));
            newBlockToHeapIndex.set(newBlockPos, heapPosition);
            // point the entry's existing hash slot at its new block position so rehash() can find it
            hashToBlockPosition.set(heapEntry.getHashPosition(), newBlockPos);
        }
        blockPositionToCount = newBlockPositionToCount;
        heapBlockBuilder = newHeapBlockBuilder;
        blockToHeapIndex = newBlockToHeapIndex;
        rehash();
    }

    /**
     * Rebuilds the hash table at twice {@code hashCapacity}. Only live heap entries are reinserted, which
     * discards DELETE_MARKER tombstones, and each heap entry's stored hash slot is updated.
     *
     * @throws PrestoException with GENERIC_INSUFFICIENT_RESOURCES if the new capacity would exceed the int range
     */
    private void rehash()
    {
        long newCapacityLong = hashCapacity * 2L;
        if (newCapacityLong > Integer.MAX_VALUE) {
            throw new PrestoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries");
        }
        int newCapacity = (int) newCapacityLong;
        int newMask = newCapacity - 1;
        IntBigArray newHashToBlockPosition = new IntBigArray(EMPTY);
        newHashToBlockPosition.ensureCapacity(newCapacity);
        for (int heapPosition = 0; heapPosition < getHeapSize(); heapPosition++) {
            StreamDataEntity heapEntry = minHeap.get(heapPosition);
            // reads the old table; this entry's slot is updated only after the lookup
            int blockPosition = getBlockPosition(heapEntry);
            // find an empty slot for the address
            int hashPosition = getBucketId(TypeUtils.hashPosition(type, heapBlockBuilder, blockPosition), newMask);
            while (newHashToBlockPosition.get(hashPosition) != EMPTY) {
                hashPosition = (hashPosition + 1) & newMask;
            }
            // record the mapping
            newHashToBlockPosition.set(hashPosition, blockPosition);
            heapEntry.setHashPosition(hashPosition);
        }
        hashCapacity = newCapacity;
        mask = newMask;
        maxFill = calculateMaxFill(newCapacity);
        this.hashToBlockPosition = newHashToBlockPosition;
        // block positions stay below maxFill before the next rehash/compact
        this.blockPositionToCount.ensureCapacity(maxFill);
        this.blockToHeapIndex.ensureCapacity(maxFill);
    }

    /**
     * Heap ordering: by count ascending, then by generation ascending.
     */
    private int compare(StreamDataEntity heapValue1, StreamDataEntity heapValue2)
    {
        int compare = Long.compare(
                getCount(heapValue1),
                getCount(heapValue2));
        if (compare == 0) {
            //When the counts are same, we want to consider the previously generated value as minimum
            // to prefer it over newly generated value with same count when remove min is called
            compare = Long.compare(heapValue1.getGeneration(), heapValue2.getGeneration());
        }
        return compare;
    }

    private long getCount(StreamDataEntity heapEntry)
    {
        return blockPositionToCount.get(getBlockPosition(heapEntry));
    }

    private int getBlockPosition(StreamDataEntity heapEntry)
    {
        return hashToBlockPosition.get(heapEntry.getHashPosition());
    }

    /**
     * Scrambles a raw hash with murmurHash3 and masks it into the table range.
     */
    private static int getBucketId(long rawHash, int mask)
    {
        return ((int) murmurHash3(rawHash)) & mask;
    }

    /**
     * Writes the top entries (see {@code getTopHeapEntries}) into one entry of {@code out}, alternating each
     * value with its BIGINT count.
     */
    public void topK(BlockBuilder out)
    {
        //get top k heap entries
        List<StreamDataEntity> sortedHeapEntries = getTopHeapEntries();
        //write data and count to output
        BlockBuilder valueBuilder = out.beginBlockEntry();
        for (StreamDataEntity heapEntry : sortedHeapEntries) {
            type.appendTo(heapBlockBuilder, getBlockPosition(heapEntry), valueBuilder);
            BIGINT.writeLong(valueBuilder, getCount(heapEntry));
        }
        out.closeEntry();
    }

    /**
     * Requests up to {@code maxBuckets} entries from the heap, ordered by count descending and then by
     * generation ascending.
     * NOTE(review): serialize/readAllValues/merge use this too, so entries beyond maxBuckets (when
     * heapCapacity > maxBuckets) are dropped from serialized/merged state. Confirm this truncation is intended.
     */
    private List<StreamDataEntity> getTopHeapEntries()
    {
        return minHeap.topK(maxBuckets, (heapEntry1, heapEntry2) -> {
            int compare = Long.compare(
                    getCount(heapEntry2),
                    getCount(heapEntry1));
            if (compare == 0) {
                return Integer.compare(heapEntry1.getGeneration(), heapEntry2.getGeneration());
            }
            return compare;
        });
    }

    /**
     * Adds every entry reported by {@code otherStreamSummary.readAllValues} into this summary.
     */
    public void merge(StreamSummary otherStreamSummary)
    {
        otherStreamSummary.readAllValues(this::add);
    }

    /**
     * Passes each top entry (value block, position, count) to {@code reader}, highest count first.
     */
    public void readAllValues(StreamSummaryReader reader)
    {
        List<StreamDataEntity> heapEntries = getTopHeapEntries();
        for (StreamDataEntity heapEntry : heapEntries) {
            reader.read(heapBlockBuilder, getBlockPosition(heapEntry), getCount(heapEntry));
        }
    }

    /**
     * Writes this summary as one entry of {@code out}: maxBuckets, heapCapacity, an array of values and an
     * array of counts. When the heap is empty, the entry is written with no content at all.
     */
    public void serialize(BlockBuilder out)
    {
        BlockBuilder blockBuilder = out.beginBlockEntry();
        if (getHeapSize() > 0) {
            BIGINT.writeLong(blockBuilder, maxBuckets);
            BIGINT.writeLong(blockBuilder, heapCapacity);
            List<StreamDataEntity> sortedHeap = getTopHeapEntries();
            BlockBuilder keyItems = blockBuilder.beginBlockEntry();
            for (StreamDataEntity heapEntry : sortedHeap) {
                type.appendTo(heapBlockBuilder, getBlockPosition(heapEntry), keyItems);
            }
            blockBuilder.closeEntry();
            BlockBuilder valueItems = blockBuilder.beginBlockEntry();
            for (StreamDataEntity heapEntry : sortedHeap) {
                BIGINT.writeLong(valueItems, getCount(heapEntry));
            }
            blockBuilder.closeEntry();
        }
        out.closeEntry();
    }

    /**
     * Rebuilds a summary from the layout written by {@link #serialize} by re-adding every value with its count.
     * NOTE(review): the empty entry that serialize writes for an empty heap has no header, so it cannot be
     * read here. Presumably callers never pass it; confirm.
     */
    public static StreamSummary deserialize(Type type, Block block)
    {
        int currentPosition = 0;
        int maxBuckets = toIntExact(BIGINT.getLong(block, currentPosition++));
        int heapCapacity = toIntExact(BIGINT.getLong(block, currentPosition++));
        StreamSummary streamSummary = new StreamSummary(type, maxBuckets, heapCapacity);
        Block keysBlock = new ArrayType(type).getObject(block, currentPosition++);
        Block valuesBlock = new ArrayType(BIGINT).getObject(block, currentPosition);
        for (int position = 0; position < keysBlock.getPositionCount(); position++) {
            streamSummary.add(keysBlock, position, valuesBlock.getLong(position));
        }
        return streamSummary;
    }

    /**
     * Estimates retained memory: this object, the value block, the heap and every big array it owns.
     */
    public long estimatedInMemorySize()
    {
        // blockToHeapIndex is grown alongside blockPositionToCount and must be accounted for as well
        return INSTANCE_SIZE +
                heapBlockBuilder.getRetainedSizeInBytes() +
                minHeap.estimatedInMemorySize() +
                blockPositionToCount.sizeOf() +
                hashToBlockPosition.sizeOf() +
                blockToHeapIndex.sizeOf();
    }

    /**
     * Returns ceil(hashSize * FILL_RATIO), lowered by one if needed so at least one slot stays free. This
     * guarantees that linear probing always reaches an EMPTY slot.
     */
    private static int calculateMaxFill(int hashSize)
    {
        checkArgument(hashSize > 0, "hashSize must be greater than 0");
        int maxFill = (int) Math.ceil(hashSize * FILL_RATIO);
        if (maxFill == hashSize) {
            maxFill--;
        }
        checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill");
        return maxFill;
    }

    /**
     * Heap callback: records the new heap index of the entry's value so {@link #add} can locate it.
     */
    @Override
    public void indexChanged(StreamDataEntity blockReferenceEntity, int newIndex)
    {
        blockToHeapIndex.set(getBlockPosition(blockReferenceEntity), newIndex);
    }
}