// TestKHyperLogLog.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.type.khyperloglog;
import io.airlift.slice.Slice;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.testng.annotations.Test;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.LongStream;
import static io.airlift.slice.testing.SliceAssertions.assertSlicesEqual;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
/**
 * Tests for {@code KHyperLogLog}: cardinality estimation accuracy, merging,
 * serialization round-trips, and the uniqueness/reidentification estimators.
 */
public class TestKHyperLogLog
{
    /**
     * Verifies that the observed standard error of the cardinality estimate stays
     * within the theoretical HLL bound of 1.04 / sqrt(buckets), with a 10% margin.
     */
    @Test
    public void testCardinality()
            throws Exception
    {
        int trials = 1000;
        for (int indexBits = 4; indexBits <= 12; indexBits++) {
            Map<Integer, StandardDeviation> errors = new HashMap<>();
            int numberOfBuckets = 1 << indexBits;
            int maxCardinality = numberOfBuckets * 2;
            for (int trial = 0; trial < trials; trial++) {
                KHyperLogLog khll = new KHyperLogLog();
                for (int cardinality = 1; cardinality <= maxCardinality; cardinality++) {
                    khll.add(ThreadLocalRandom.current().nextLong(), 0L);
                    if (cardinality % (numberOfBuckets / 10) == 0) {
                        // only sample the error at a few cardinalities, since computing the
                        // cardinality is currently not as cheap as it should be
                        double error = (khll.cardinality() - cardinality) * 1.0 / cardinality;
                        StandardDeviation stdev = errors.computeIfAbsent(cardinality, k -> new StandardDeviation());
                        stdev.increment(error);
                    }
                }
            }
            // theoretical HLL standard error: 1.04 / sqrt(number of buckets)
            double expectedStandardError = 1.04 / Math.sqrt(numberOfBuckets);
            for (Map.Entry<Integer, StandardDeviation> entry : errors.entrySet()) {
                // Give an extra error margin. This is mostly a sanity check to catch egregious errors
                double realStandardError = entry.getValue().getResult();
                assertTrue(realStandardError <= expectedStandardError * 1.1,
                        format("Failed at p = %s, cardinality = %s. Expected std error = %s, actual = %s",
                                indexBits,
                                entry.getKey(),
                                expectedStandardError,
                                realStandardError));
            }
        }
    }

    /**
     * Verifies merging for the combinations of small/large sketches (exercising both
     * sides of any internal sparse/dense representation switch, if one exists).
     */
    @Test
    public void testMerge()
            throws Exception
    {
        // small vs small
        verifyMerge(LongStream.rangeClosed(0, 100), LongStream.rangeClosed(50, 150));
        // small vs big
        verifyMerge(LongStream.rangeClosed(0, 100), LongStream.rangeClosed(50, 5000));
        // big vs small
        verifyMerge(LongStream.rangeClosed(50, 5000), LongStream.rangeClosed(0, 100));
        // big vs big
        verifyMerge(LongStream.rangeClosed(0, 5000), LongStream.rangeClosed(3000, 8000));
    }

    /**
     * Asserts that merging two sketches built from {@code one} and {@code two} is
     * equivalent to a single sketch built from both streams.
     */
    private void verifyMerge(LongStream one, LongStream two)
    {
        KHyperLogLog khll1 = new KHyperLogLog();
        KHyperLogLog khll2 = new KHyperLogLog();
        KHyperLogLog expected = new KHyperLogLog();
        long uii;
        for (long value : one.toArray()) {
            uii = randomLong(100);
            khll1.add(value, uii);
            expected.add(value, uii);
        }
        for (long value : two.toArray()) {
            uii = randomLong(100);
            khll2.add(value, uii);
            expected.add(value, uii);
        }
        KHyperLogLog merged = khll1.mergeWith(khll2);
        assertEquals(merged.cardinality(), expected.cardinality());
        assertEquals(merged.reidentificationPotential(10), expected.reidentificationPotential(10));
        // Assert on the merge result directly. The original asserted on khll1, which
        // silently relied on mergeWith mutating its receiver rather than on the
        // documented return value.
        assertSlicesEqual(merged.serialize(), expected.serialize());
    }

    @Test
    public void testSerialization()
            throws Exception
    {
        // small
        verifySerialization(LongStream.rangeClosed(0, 1000));
        // large
        verifySerialization(LongStream.rangeClosed(0, 200000));
    }

    /**
     * Asserts that serialize/deserialize preserves the estimators and that
     * re-serializing yields byte-identical output.
     */
    private void verifySerialization(LongStream sequence)
    {
        KHyperLogLog khll = new KHyperLogLog();
        // primitive loop variable: iterating a long[] with Long boxed every element
        for (long value : sequence.toArray()) {
            khll.add(value, ThreadLocalRandom.current().nextLong(100));
        }
        Slice serialized = khll.serialize();
        KHyperLogLog deserialized = KHyperLogLog.newInstance(serialized);
        assertEquals(khll.cardinality(), deserialized.cardinality());
        assertEquals(khll.reidentificationPotential(10), deserialized.reidentificationPotential(10));
        Slice reserialized = deserialized.serialize();
        assertSlicesEqual(serialized, reserialized);
    }

    @Test
    public void testHistogram()
            throws Exception
    {
        // small
        buildHistogramAndVerify(256, 1000);
        // large
        buildHistogramAndVerify(256, 200000);
    }

    /**
     * Builds a sketch plus an exact value->UII map side by side, then checks the
     * sketch's uniqueness distribution and reidentification potential against the
     * exact answers.
     *
     * @param histogramSize number of uniqueness buckets (UII counts are capped here)
     * @param count number of (value, uii) pairs to insert
     */
    public void buildHistogramAndVerify(int histogramSize, int count)
    {
        KHyperLogLog khll = new KHyperLogLog();
        Map<Long, HashSet<Long>> map = new HashMap<>();
        long uii;
        long value;
        for (int i = 0; i < count; i++) {
            uii = randomLong(histogramSize);
            value = randomLong(count);
            khll.add(value, uii);
            map.computeIfAbsent(value, k -> new HashSet<>()).add(uii);
        }
        int size = map.size();
        // exact distribution: fraction of values having a given number of distinct UIIs
        Map<Long, Double> realHistogram = new HashMap<>();
        for (HashSet<Long> uiis : map.values()) {
            long bucket = Math.min(uiis.size(), histogramSize);
            realHistogram.merge(bucket, (double) 1 / size, Double::sum);
        }
        Map<Long, Double> khllHistogram = khll.uniquenessDistribution(histogramSize);
        verifyUniquenessDistribution(realHistogram, khllHistogram);
        verifyReidentificationPotential(map, khll);
    }

    /**
     * Compares the cumulative estimated distribution against the exact one,
     * requiring agreement within 10% at every bucket.
     */
    public void verifyUniquenessDistribution(Map<Long, Double> realHistogram, Map<Long, Double> khllHistogram)
    {
        double estimated = 0.0;
        double real = 0.0;
        int histogramSize = realHistogram.size();
        for (long i = 1; i < histogramSize; i++) {
            // getOrDefault on both sides: a missing bucket previously caused an NPE
            // during unboxing of khllHistogram.get(i)
            estimated += khllHistogram.getOrDefault(i, 0.0);
            real += realHistogram.getOrDefault(i, 0.0);
            assertTrue(Math.abs(estimated - real) <= 0.1 * real,
                    format("Expected histogram value %f +/- 10%%, got %f, for bucket %d", real, estimated, i));
        }
    }

    /**
     * Checks the estimated fraction of values with at most {@code threshold}
     * distinct UIIs against the exact fraction, within 10%.
     */
    public void verifyReidentificationPotential(Map<Long, HashSet<Long>> map, KHyperLogLog khll)
    {
        double estimated;
        double real;
        int size = map.size();
        for (int threshold = 1; threshold < 10; threshold++) {
            estimated = khll.reidentificationPotential(threshold);
            real = 0.0;
            for (HashSet<Long> uiis : map.values()) {
                if (uiis.size() <= threshold) {
                    real++;
                }
            }
            real /= size;
            assertTrue(Math.abs(estimated - real) <= 0.1 * real,
                    format("Expected reidentification potential %f +/- 10%%, got %f, for set of size %d", real, estimated, size));
        }
    }

    /**
     * Returns a skewed random long in [0, range): squaring a uniform [0, 1) sample
     * biases results toward small values, so some values collect few UIIs.
     */
    private long randomLong(int range)
    {
        return (long) (Math.pow(Math.random(), 2.0) * range);
    }
}