SfmSketch.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.noisyaggregation.sketch;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Murmur3Hash128;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import org.openjdk.jol.info.ClassLayout;
import javax.annotation.concurrent.NotThreadSafe;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
/**
 * SfmSketch is a sketch for distinct counting, very similar to HyperLogLog.
 * This sketch is introduced as the Sketch-Flip-Merge (SFM) summary in the paper
 * <a href="https://arxiv.org/pdf/2302.02056.pdf">Sketch-Flip-Merge: Mergeable Sketches for Private Distinct Counting</a>.
 * <p>
 * The primary differences between SfmSketch and HyperLogLog are that
 * (a) SfmSketch supports differential privacy, and
 * (b) where HyperLogLog tracks only the maximum value observed in each bucket, SfmSketch tracks every value observed in each bucket.
 * <p>
 * This means that SfmSketch is a larger sketch than HyperLogLog, but it can store fully differentially
 * private (DP) sketches with a fixed, public hash function while maintaining accurate cardinality estimates.
 * <p>
 * SfmSketch is created in non-private mode. Privacy must be enabled through the enablePrivacy() function.
 * Once made private, no new values can be added to the sketch (the add methods reject it), although it can
 * still be merged with other sketches. Privacy is quantified by the parameter epsilon.
 * <p>
 * When epsilon > 0, the sketch is epsilon-DP, and bits are randomized to preserve privacy.
 * When epsilon == NON_PRIVATE_EPSILON, the sketch is not private, and bits are set deterministically.
 * <p>
 * The best accuracy comes with NON_PRIVATE_EPSILON. For private epsilons, larger gives more accuracy,
 * while smaller gives more privacy.
 */
@NotThreadSafe
public class SfmSketch
{
    public static final double NON_PRIVATE_EPSILON = Double.POSITIVE_INFINITY;

    private static final byte FORMAT_TAG = 7;
    private static final int MAX_ESTIMATION_ITERATIONS = 1000;
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(SfmSketch.class).instanceSize();

    private final int indexBitLength;
    private final int precision;
    // Probability with which each bit has been flipped by randomized response; 0 means the sketch is non-private.
    private double randomizedResponseProbability;
    private final Bitmap bitmap;

    private SfmSketch(Bitmap bitmap, int indexBitLength, int precision, double randomizedResponseProbability)
    {
        requireNonNull(bitmap, "bitmap cannot be null");
        validatePrefixLength(indexBitLength);
        validatePrecision(precision, indexBitLength);
        validateRandomizedResponseProbability(randomizedResponseProbability);

        this.bitmap = bitmap;
        this.indexBitLength = indexBitLength;
        this.precision = precision;
        this.randomizedResponseProbability = randomizedResponseProbability;
    }

    /**
     * Creates a new, empty SfmSketch in non-private mode. To make it private,
     * call enablePrivacy() after populating the sketch.
     *
     * @param numberOfBuckets number of buckets; must be a positive power of 2
     * @param precision number of levels (bits) tracked per bucket
     * @throws IllegalArgumentException if numberOfBuckets is not a power of 2 or precision is invalid
     */
    public static SfmSketch create(int numberOfBuckets, int precision)
    {
        // Only create non-private sketches.
        // Private sketches cannot accept new values, so creating an empty one is of little use.
        double randomizedResponseProbability = getRandomizedResponseProbability(NON_PRIVATE_EPSILON);
        int indexBitLength = indexBitLength(numberOfBuckets);
        Bitmap bitmap = new Bitmap(numberOfBuckets * precision);
        return new SfmSketch(bitmap, indexBitLength, precision, randomizedResponseProbability);
    }

    /**
     * Reconstructs a sketch from the bytes produced by {@link #serialize()}.
     * <p>
     * Format: format tag (byte) | indexBitLength (int) | precision (int) |
     * randomizedResponseProbability (double) | bitmap byte length (int) | bitmap bytes
     *
     * @throws IllegalArgumentException if the format tag is wrong or a header value is invalid
     */
    public static SfmSketch deserialize(Slice serialized)
    {
        BasicSliceInput input = serialized.getInput();
        byte format = input.readByte();
        checkArgument(format == FORMAT_TAG, "Wrong format tag");

        int indexBitLength = input.readInt();
        int precision = input.readInt();
        double randomizedResponseProbability = input.readDouble();
        int bitmapByteLength = input.readInt();
        Bitmap bitmap = Bitmap.fromSliceInput(input, bitmapByteLength, numberOfBuckets(indexBitLength) * precision);

        return new SfmSketch(bitmap, indexBitLength, precision, randomizedResponseProbability);
    }

    /**
     * Adds a long value, hashed with Murmur3 (64-bit).
     *
     * @throws IllegalArgumentException if privacy has been enabled on this sketch
     */
    public void add(long value)
    {
        addHash(Murmur3Hash128.hash64(value));
    }

    /**
     * Adds a Slice value, hashed with Murmur3 (64-bit).
     *
     * @throws IllegalArgumentException if privacy has been enabled on this sketch
     */
    public void add(Slice value)
    {
        addHash(Murmur3Hash128.hash64(value));
    }

    /**
     * Adds a precomputed 64-bit hash. The top indexBitLength bits select the bucket, and the number of
     * trailing zeros (capped at precision - 1) selects the level whose bit is turned on.
     *
     * @throws IllegalArgumentException if privacy has been enabled on this sketch
     */
    public void addHash(long hash)
    {
        int index = computeIndex(hash, indexBitLength);
        // Cap zeros at precision - 1.
        // Essentially, we're looking at a (precision - 1)-bit hash.
        int zeros = Math.min(precision - 1, numberOfTrailingZeros(hash, indexBitLength));
        flipBitOn(index, zeros);
    }

    /**
     * Adds an observation given directly as a bucket index and a trailing-zero count.
     *
     * @param index bucket index in [0, numberOfBuckets - 1]
     * @param zeros number of trailing zeros in [0, 64]; values above precision - 1 are capped
     * @throws IllegalArgumentException if an argument is out of range or privacy has been enabled
     */
    public void addIndexAndZeros(long index, long zeros)
    {
        long buckets = numberOfBuckets(indexBitLength);
        checkArgument(index >= 0 && index < buckets,
                "index %s must be between zero (inclusive) and the number of buckets-1 %s", index, buckets - 1);
        checkArgument(zeros >= 0 && zeros <= 64,
                "zeros %s must be between 0 and 64", zeros);

        // Cap zeros at precision - 1.
        // Essentially, we're looking at a (precision - 1)-bit hash.
        zeros = Math.min(precision - 1, zeros);
        flipBitOn((int) index, (int) zeros);
    }

    /**
     * Estimates cardinality via maximum pseudolikelihood, using Newton's method on the log-pseudolikelihood.
     * Iteration stops once a Newton step changes the estimate by at most 0.1, or after
     * MAX_ESTIMATION_ITERATIONS steps. The result is rounded and clamped to be non-negative.
     */
    public long cardinality()
    {
        // The initial guess of 1 may seem awful, but this converges quickly, and starting small returns better results for small cardinalities.
        // This generally takes <= 40 iterations, even for cardinalities as large as 10^33.
        double guess = 1;
        double changeInGuess = Double.POSITIVE_INFINITY;
        int iterations = 0;
        while (Math.abs(changeInGuess) > 0.1 && iterations < MAX_ESTIMATION_ITERATIONS) {
            changeInGuess = -logLikelihoodFirstDerivative(guess) / logLikelihoodSecondDerivative(guess);
            guess += changeInGuess;
            iterations += 1;
        }
        return Math.max(0, Math.round(guess));
    }

    /**
     * Returns the bucket index stored in the top indexBitLength bits of the hash.
     */
    public static int computeIndex(long hash, int indexBitLength)
    {
        return (int) (hash >>> (Long.SIZE - indexBitLength));
    }

    /**
     * Enables privacy on a non-privacy-enabled sketch, using a SecureRandomizationStrategy.
     * <p>
     * Per Lemma 4.7, <a href="https://arxiv.org/pdf/2302.02056.pdf">arXiv:2302.02056</a>,
     * flipping every bit with probability 1/(e^epsilon + 1) achieves differential privacy.
     *
     * @param epsilon privacy parameter; must be positive (NON_PRIVATE_EPSILON yields a flip probability of 0)
     * @throws IllegalArgumentException if privacy is already enabled or epsilon is not positive
     */
    public void enablePrivacy(double epsilon)
    {
        enablePrivacy(epsilon, getDefaultRandomizationStrategy());
    }

    /**
     * Enables privacy on a non-privacy-enabled sketch, flipping every bit with probability
     * 1/(e^epsilon + 1) drawn from the given randomization strategy.
     *
     * @throws IllegalArgumentException if privacy is already enabled or epsilon is not positive
     */
    public void enablePrivacy(double epsilon, RandomizationStrategy randomizationStrategy)
    {
        requireNonNull(randomizationStrategy, "randomizationStrategy cannot be null");
        checkArgument(!isPrivacyEnabled(), "sketch is already privacy-enabled");
        validateEpsilon(epsilon);
        randomizedResponseProbability = getRandomizedResponseProbability(epsilon);

        // Flip every bit with fixed probability
        bitmap.flipAll(randomizedResponseProbability, randomizationStrategy);
    }

    /**
     * Returns the number of bytes that serialize() is expected to write; used to presize the output buffer.
     */
    public int estimatedSerializedSize()
    {
        return SizeOf.SIZE_OF_BYTE + // type + version
                SizeOf.SIZE_OF_INT + // indexBitLength
                SizeOf.SIZE_OF_INT + // precision
                SizeOf.SIZE_OF_DOUBLE + // randomized response probability
                SizeOf.SIZE_OF_INT + // bitmap byte length
                (bitmap.byteLength() * SizeOf.SIZE_OF_BYTE); // bitmap
    }

    /**
     * Turns on the bit for (bucket, level); rejected once privacy is enabled.
     */
    private void flipBitOn(int bucket, int level)
    {
        checkArgument(!isPrivacyEnabled(), "privacy-enabled SfmSketch is immutable");

        int i = getBitLocation(bucket, level);
        bitmap.setBit(i, true);
    }

    /**
     * Bits are laid out level-major: all buckets of level 0 first, then all buckets of level 1, and so on.
     */
    @VisibleForTesting
    int getBitLocation(int bucket, int level)
    {
        return level * numberOfBuckets(indexBitLength) + bucket;
    }

    /**
     * Returns the underlying bitmap itself (not a copy).
     */
    public Bitmap getBitmap()
    {
        return bitmap;
    }

    private static RandomizationStrategy getDefaultRandomizationStrategy()
    {
        return new SecureRandomizationStrategy();
    }

    @VisibleForTesting
    double getOnProbability()
    {
        // probability of a 1-bit remaining a 1-bit under randomized response
        return 1 - randomizedResponseProbability;
    }

    static double getRandomizedResponseProbability(double epsilon)
    {
        // If non-private, we don't use randomized response.
        // Otherwise, flip bits with probability 1/(exp(epsilon) + 1).
        if (epsilon == NON_PRIVATE_EPSILON) {
            return 0;
        }
        return 1.0 / (Math.exp(epsilon) + 1);
    }

    @VisibleForTesting
    double getRandomizedResponseProbability()
    {
        // probability of a 0-bit flipping to a 1-bit under randomized response
        return randomizedResponseProbability;
    }

    /**
     * Returns the approximate heap footprint of this sketch, including its bitmap.
     */
    public long getRetainedSizeInBytes()
    {
        return INSTANCE_SIZE + bitmap.getRetainedSizeInBytes();
    }

    /**
     * Returns log2(numberOfBuckets).
     *
     * @throws IllegalArgumentException if numberOfBuckets is not a positive power of 2
     */
    public static int indexBitLength(int numberOfBuckets)
    {
        Preconditions.checkArgument(isPowerOf2(numberOfBuckets), "numberOfBuckets must be a power of 2, actual: %s", numberOfBuckets);
        // 2**N has N trailing zeros, and we've asserted numberOfBuckets == 2**N
        return Integer.numberOfTrailingZeros(numberOfBuckets);
    }

    /**
     * Returns whether value is a power of 2.
     *
     * @throws IllegalArgumentException if value is not positive
     */
    public static boolean isPowerOf2(long value)
    {
        Preconditions.checkArgument(value > 0, "value must be positive");
        return (value & (value - 1)) == 0;
    }

    public boolean isPrivacyEnabled()
    {
        return getRandomizedResponseProbability() > 0;
    }

    private double logLikelihoodFirstDerivative(double n)
    {
        // Technically, this is the first derivative of the log of a pseudolikelihood.
        int buckets = numberOfBuckets(indexBitLength);
        double result = 0;
        for (int level = 0; level < precision; level++) {
            double termOn = logLikelihoodTermFirstDerivative(level, true, n);
            double termOff = logLikelihoodTermFirstDerivative(level, false, n);
            for (int bucket = 0; bucket < buckets; bucket++) {
                result += bitmap.getBit(getBitLocation(bucket, level)) ? termOn : termOff;
            }
        }
        return result;
    }

    /**
     * d/dn log P(bit), where P(on) = c1 - c2 * (1-p)^n and P(off) = c1 + c2 * (1-p)^n
     * (c1 = probability of the observed value when the bucket saw nothing, c2 = 1 - 2 * randomizedResponseProbability).
     */
    private double logLikelihoodTermFirstDerivative(int level, boolean on, double n)
    {
        double p = observationProbability(level);
        int sign = on ? -1 : 1;
        double c1 = on ? getOnProbability() : 1 - getOnProbability();
        double c2 = getOnProbability() - getRandomizedResponseProbability();
        return Math.log1p(-p) * (1 - c1 / (c1 + sign * c2 * Math.pow(1 - p, n)));
    }

    private double logLikelihoodSecondDerivative(double n)
    {
        // Technically, this is the second derivative of the log of a pseudolikelihood.
        int buckets = numberOfBuckets(indexBitLength);
        double result = 0;
        for (int level = 0; level < precision; level++) {
            double termOn = logLikelihoodTermSecondDerivative(level, true, n);
            double termOff = logLikelihoodTermSecondDerivative(level, false, n);
            for (int bucket = 0; bucket < buckets; bucket++) {
                result += bitmap.getBit(getBitLocation(bucket, level)) ? termOn : termOff;
            }
        }
        return result;
    }

    /**
     * d^2/dn^2 log P(bit), with the same parameterization as logLikelihoodTermFirstDerivative.
     */
    private double logLikelihoodTermSecondDerivative(int level, boolean on, double n)
    {
        double p = observationProbability(level);
        int sign = on ? -1 : 1;
        double c1 = on ? getOnProbability() : 1 - getOnProbability();
        double c2 = getOnProbability() - getRandomizedResponseProbability();
        return sign * c1 * c2 * Math.pow(Math.log1p(-p), 2) * Math.pow(1 - p, n) * Math.pow(c1 + sign * c2 * Math.pow(1 - p, n), -2);
    }

    /**
     * Merging two sketches with randomizedResponseProbability values p1 and p2 is equivalent to
     * having created two non-private sketches, merged them, then enabled privacy with a
     * randomizedResponseProbability value of:
     * <p>
     * (p1 + p2 - 3 * p1 * p2) / (1 - 2 * p1 * p2)
     * <p>
     * This can be derived from the fact that two private sketches created with epsilon1 and epsilon2
     * merge to be equivalent to a single sketch created with epsilon:
     * <p>
     * -log(exp(-epsilon1) + exp(-epsilon2) - exp(-(epsilon1 + epsilon2)))
     * <p>
     * For details, see Theorem 4.8, <a href="https://arxiv.org/pdf/2302.02056.pdf">arXiv:2302.02056</a>.
     * For verification, see the unit tests.
     */
    @VisibleForTesting
    static double mergeRandomizedResponseProbabilities(double p1, double p2)
    {
        return (p1 + p2 - 3 * p1 * p2) / (1 - 2 * p1 * p2);
    }

    /**
     * Performs a merge of the other sketch into the current sketch, using a ThreadLocalRandomizationStrategy
     * for any randomization. See {@link #mergeWith(SfmSketch, RandomizationStrategy)}.
     */
    public void mergeWith(SfmSketch other)
    {
        // Using ThreadLocalRandomizationStrategy instead of
        // SecureRandomizationStrategy since combining sketches
        // does not need to be cryptographically secure
        mergeWith(other, new ThreadLocalRandomizationStrategy());
    }

    /**
     * Performs a merge of the other sketch into the current sketch. This is performed
     * as a randomized merge as described in Theorem 4.8,
     * <a href="https://arxiv.org/pdf/2302.02056.pdf">arXiv:2302.02056</a>.
     * <p>
     * The formula used in this function is a simplification of the form presented in the original paper.
     * See also Section 3, <a href="https://arxiv.org/pdf/2306.09394.pdf">arXiv:2306.09394</a>.
     * <p>
     * After the merge, this sketch's randomizedResponseProbability is
     * mergeRandomizedResponseProbabilities(p1, p2). The other sketch is not modified.
     *
     * @throws IllegalArgumentException if the sketches differ in precision or indexBitLength
     */
    public void mergeWith(SfmSketch other, RandomizationStrategy randomizationStrategy)
    {
        requireNonNull(randomizationStrategy, "randomizationStrategy cannot be null");
        // Strictly speaking, we may be able to provide more general merging than suggested here.
        // It's not clear how useful this would be in practice.
        checkArgument(precision == other.precision, "cannot merge two SFM sketches with different precision: %s vs. %s", precision, other.precision);
        checkArgument(indexBitLength == other.indexBitLength, "cannot merge two SFM sketches with different indexBitLength: %s vs. %s",
                indexBitLength, other.indexBitLength);

        double p1 = randomizedResponseProbability;
        double p2 = other.randomizedResponseProbability;
        double p = mergeRandomizedResponseProbabilities(p1, p2);

        if (!isPrivacyEnabled() && !other.isPrivacyEnabled()) {
            // if neither sketch is private, we just take the OR of the sketches
            bitmap.or(other.getBitmap());
        }
        else {
            // If either sketch is private, we combine using a randomized merge
            // (the non-private case above is a special case of this more complicated math).
            //
            // For observed bits b1, b2, the product (1 - p1 - b1) * (1 - p2 - b2) / ((1 - 2 * p1) * (1 - 2 * p2))
            // is an unbiased estimate of "both underlying non-private bits are 0". Turning the output bit on with
            // probability 1 - p - (1 - 2 * p) * estimate makes it, in expectation, 1 - p when either underlying
            // bit is 1 and p otherwise -- exactly a non-private OR-merge randomized with probability p.
            double denominator = (1 - 2 * p1) * (1 - 2 * p2);
            // If either input is fully randomized (probability 0.5), its bits carry no information: the estimate
            // term drops out (avoiding 0/0 = NaN) and every output bit becomes a fair coin flip.
            double normalizer = denominator == 0 ? 0 : (1 - 2 * p) / denominator;

            for (int i = 0; i < bitmap.length(); i++) {
                double bit1 = bitmap.getBit(i) ? 1 : 0;
                double bit2 = other.bitmap.getBit(i) ? 1 : 0;
                double probability = 1 - p - normalizer * (1 - p1 - bit1) * (1 - p2 - bit2);
                // mathematically within [0, 1]; clamp only to absorb floating-point error
                probability = Math.min(1.0, Math.max(0.0, probability));
                bitmap.setBit(i, randomizationStrategy.nextBoolean(probability));
            }
        }

        randomizedResponseProbability = p;
    }

    public static int numberOfBuckets(int indexBitLength)
    {
        return 1 << indexBitLength;
    }

    /**
     * Counts the trailing zeros of the hash bits below the index prefix. The result is at most
     * Long.SIZE - indexBitLength, because a sentinel 1 is placed at the lowest bit of the prefix.
     */
    public static int numberOfTrailingZeros(long hash, int indexBitLength)
    {
        long value = hash | (1L << (Long.SIZE - indexBitLength)); // place a 1 in the final position of the prefix to avoid flowing into prefix when the hash happens to be 0
        return Long.numberOfTrailingZeros(value);
    }

    private double observationProbability(int level)
    {
        // probability of observing a run of zeros of length level in any single bucket
        // note: this is NOT (in general) the probability of having a 1 in the corresponding location in the sketch
        // (it is if bits are set deterministically, as in non-private mode, where epsilon == NON_PRIVATE_EPSILON)
        return Math.pow(2.0, -(level + 1)) / numberOfBuckets(indexBitLength);
    }

    /**
     * Serializes this sketch into a new Slice; see {@link #deserialize(Slice)} for the format.
     */
    public Slice serialize()
    {
        DynamicSliceOutput sliceOutput = new DynamicSliceOutput(estimatedSerializedSize());
        serialize(sliceOutput);
        return sliceOutput.slice();
    }

    /**
     * Appends the serialized form of this sketch to the given output.
     */
    public void serialize(DynamicSliceOutput sliceOutput)
    {
        byte[] bitmapBytes = bitmap.toBytes();
        sliceOutput.appendByte(FORMAT_TAG)
                .appendInt(indexBitLength)
                .appendInt(precision)
                .appendDouble(randomizedResponseProbability)
                .appendInt(bitmapBytes.length)
                .appendBytes(bitmapBytes);
    }

    private static void validateEpsilon(double epsilon)
    {
        checkArgument(epsilon > 0, "epsilon must be greater than zero or equal to NON_PRIVATE_EPSILON");
    }

    private static void validatePrecision(int precision, int indexBitLength)
    {
        checkArgument(precision > 0, "precision must be positive, actual: %s", precision);
        checkArgument(precision + indexBitLength <= Long.SIZE, "precision + indexBitLength cannot exceed %s", Long.SIZE);
    }

    private static void validatePrefixLength(int indexBitLength)
    {
        checkArgument(indexBitLength >= 1 && indexBitLength <= 32, "indexBitLength is out of range");
    }

    private static void validateRandomizedResponseProbability(double p)
    {
        checkArgument(p >= 0 && p <= 0.5, "randomizedResponseProbability should be in the interval [0, 0.5]");
    }
}