/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.type.setdigest;

import com.facebook.airlift.stats.cardinality.HyperLogLog;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import com.google.common.primitives.Shorts;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Murmur3Hash128;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import it.unimi.dsi.fastutil.longs.Long2ShortRBTreeMap;
import it.unimi.dsi.fastutil.longs.Long2ShortSortedMap;
import it.unimi.dsi.fastutil.longs.LongBidirectionalIterator;
import it.unimi.dsi.fastutil.longs.LongRBTreeSet;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import org.openjdk.jol.info.ClassLayout;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.SIZE_OF_BYTE;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.airlift.slice.SizeOf.SIZE_OF_SHORT;
import static java.util.Objects.requireNonNull;

/**
 * For the MinHash algorithm, see "On the resemblance and containment of documents" by Andrei Z. Broder,
 * and the Wikipedia page: http://en.wikipedia.org/wiki/MinHash#Variant_with_a_single_hash_function
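 *
 * <p>A minimal usage sketch (illustrative only; the values added below are
 * hypothetical and not taken from this codebase):
 * <pre>{@code
 * SetDigest digest = new SetDigest();
 * digest.add(42L);
 * digest.add(Slices.utf8Slice("hello"));
 * // Exact while fewer than maxHashes distinct values have been added,
 * // HyperLogLog estimate afterwards.
 * long estimate = digest.cardinality();
 * }</pre>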
 */
public class SetDigest
{
    private static final byte UNCOMPRESSED_FORMAT = 1;
    public static final int NUMBER_OF_BUCKETS = 2048;
    public static final int DEFAULT_MAX_HASHES = 8192;
    private static final int SIZE_OF_ENTRY = SIZE_OF_LONG + SIZE_OF_SHORT;
    private static final int SIZE_OF_SETDIGEST = ClassLayout.parseClass(SetDigest.class).instanceSize();
    private static final int SIZE_OF_RBTREEMAP = ClassLayout.parseClass(Long2ShortRBTreeMap.class).instanceSize();

    private final HyperLogLog hll;
    private final Long2ShortSortedMap minhash;
    private final int maxHashes;

    public SetDigest()
    {
        this(DEFAULT_MAX_HASHES, HyperLogLog.newInstance(NUMBER_OF_BUCKETS), new Long2ShortRBTreeMap());
    }

    public SetDigest(int maxHashes, int numHllBuckets)
    {
        this(maxHashes, HyperLogLog.newInstance(numHllBuckets), new Long2ShortRBTreeMap());
    }

    public SetDigest(int maxHashes, HyperLogLog hll, Long2ShortSortedMap minhash)
    {
        this.maxHashes = maxHashes;
        this.hll = requireNonNull(hll, "hll is null");
        this.minhash = requireNonNull(minhash, "minhash is null");
    }

    public static SetDigest newInstance(Slice serialized)
    {
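        // Serialized layout: format byte, HLL length (int) and HLL bytes, maxHashes (int),
        // minhash entry count (int), then all minhash keys (longs) followed by all values (shorts).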
        requireNonNull(serialized, "serialized is null");
        SliceInput input = serialized.getInput();
        checkArgument(input.readByte() == UNCOMPRESSED_FORMAT, "Unexpected version");

        int hllLength = input.readInt();
        Slice serializedHll = Slices.allocate(hllLength);
        input.readBytes(serializedHll, hllLength);
        HyperLogLog hll = HyperLogLog.newInstance(serializedHll);

        Long2ShortRBTreeMap minhash = new Long2ShortRBTreeMap();
        int maxHashes = input.readInt();
        int minhashLength = input.readInt();
        // The values are stored after the keys
        SliceInput valuesInput = serialized.getInput();
        valuesInput.setPosition(input.position() + minhashLength * SIZE_OF_LONG);

        for (int i = 0; i < minhashLength; i++) {
            minhash.put(input.readLong(), valuesInput.readShort());
        }

        return new SetDigest(maxHashes, hll, minhash);
    }

    public Slice serialize()
    {
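        // Writes the layout read back by newInstance(): format byte, HLL length and bytes,
        // maxHashes, entry count, all keys, then all values.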
        try (SliceOutput output = new DynamicSliceOutput(estimatedSerializedSize())) {
            output.appendByte(UNCOMPRESSED_FORMAT);
            Slice serializedHll = hll.serialize();
            output.appendInt(serializedHll.length());
            output.appendBytes(serializedHll);
            output.appendInt(maxHashes);
            output.appendInt(minhash.size());
            for (long key : minhash.keySet()) {
                output.appendLong(key);
            }
            for (short value : minhash.values()) {
                output.appendShort(value);
            }
            return output.slice();
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public HyperLogLog getHll()
    {
        return hll;
    }

    public int estimatedInMemorySize()
    {
        return hll.estimatedInMemorySize() + minhash.size() * SIZE_OF_ENTRY + SIZE_OF_SETDIGEST + SIZE_OF_RBTREEMAP;
    }

    public int estimatedSerializedSize()
    {
        return SIZE_OF_BYTE + SIZE_OF_INT + hll.estimatedSerializedSize() + 2 * SIZE_OF_INT + minhash.size() * SIZE_OF_ENTRY;
    }

    public boolean isExact()
    {
        // There's an ambiguity when minhash.size() == maxHashes, since this could either
        // be an exact set with maxHashes elements or an inexact one, which is why strict
        // inequality is used here.
        return minhash.size() < maxHashes;
    }

    public long cardinality()
    {
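        // While the digest is exact it holds one hash per distinct element, so its size
        // is the true cardinality; otherwise fall back to the HyperLogLog estimate.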
        if (isExact()) {
            return minhash.size();
        }
        return hll.cardinality();
    }

    public static long exactIntersectionCardinality(SetDigest a, SetDigest b)
    {
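        // Only meaningful while both digests still hold every distinct element's hash,
        // in which case the key-set intersection is the true intersection size.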
        checkState(a.isExact(), "exact intersection cannot operate on approximate sets");
        checkArgument(b.isExact(), "exact intersection cannot operate on approximate sets");

        return Sets.intersection(a.minhash.keySet(), b.minhash.keySet()).size();
    }

    public static double jaccardIndex(SetDigest a, SetDigest b)
    {
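        // Single-hash MinHash estimator: scan the smallest hashes of the union, up to
        // the size of the smaller signature, and count how many appear in both digests.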
        int sizeOfSmallerSet = Math.min(a.minhash.size(), b.minhash.size());
        LongSortedSet minUnion = new LongRBTreeSet(a.minhash.keySet());
        minUnion.addAll(b.minhash.keySet());

        int intersection = 0;
        int i = 0;
        for (long key : minUnion) {
            if (a.minhash.containsKey(key) && b.minhash.containsKey(key)) {
                intersection++;
            }
            i++;
            if (i >= sizeOfSmallerSet) {
                break;
            }
        }
        return intersection / (double) sizeOfSmallerSet;
    }

    public void add(long value)
    {
        addHash(Murmur3Hash128.hash64(value));
        hll.add(value);
    }

    public void add(Slice value)
    {
        addHash(Murmur3Hash128.hash64(value));
        hll.add(value);
    }

    private void addHash(long hash)
    {
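        // get() returns 0 for a hash we haven't seen, so a new hash starts at count 1;
        // counts saturate at Short.MAX_VALUE rather than overflowing.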
        short value = minhash.get(hash);
        if (value < Short.MAX_VALUE) {
            minhash.put(hash, (short) (value + 1));
        }
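        // MinHash keeps only the smallest maxHashes hash values; evict the largest
        // keys until the sketch is back within the limit.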
        while (minhash.size() > maxHashes) {
            minhash.remove(minhash.lastLongKey());
        }
    }

    public void mergeWith(SetDigest other)
    {
        hll.mergeWith(other.hll);
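        // Add the other digest's per-hash counts; get() returns 0 for keys this digest
        // doesn't have yet, and saturatedCast clamps each sum at Short.MAX_VALUE.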
        LongBidirectionalIterator iterator = other.minhash.keySet().iterator();
        while (iterator.hasNext()) {
            long key = iterator.nextLong();
            int count = minhash.get(key) + other.minhash.get(key);
            minhash.put(key, Shorts.saturatedCast(count));
        }
        while (minhash.size() > maxHashes) {
            minhash.remove(minhash.lastLongKey());
        }
    }

    public Map<Long, Short> getHashCounts()
    {
        return ImmutableMap.copyOf(minhash);
    }
}