TestSetDigest.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.type.setdigest;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.SingleMapBlock;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.MapType;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import static com.facebook.presto.common.block.MethodHandleUtil.compose;
import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle;
import static com.facebook.presto.type.setdigest.SetDigest.DEFAULT_MAX_HASHES;
import static com.facebook.presto.type.setdigest.SetDigest.NUMBER_OF_BUCKETS;
import static com.facebook.presto.type.setdigest.SetDigestFunctions.hashCounts;
import static com.facebook.presto.type.setdigest.SetDigestFunctions.intersectionCardinality;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
/**
 * Tests for {@link SetDigest}: intersection-cardinality estimates between pairs of digests and
 * the per-hash counts exposed through {@code SetDigestFunctions.hashCounts}.
 */
public class TestSetDigest
{
    // Equality / hash-code handles for BIGINT keys, used to build the BIGINT -> SMALLINT MapType
    // that hashCounts() writes its result into.
    private static final MethodHandle KEY_NATIVE_EQUALS = getOperatorMethodHandle(OperatorType.EQUAL, BIGINT, BIGINT);
    private static final MethodHandle KEY_BLOCK_EQUALS = compose(KEY_NATIVE_EQUALS, nativeValueGetter(BIGINT), nativeValueGetter(BIGINT));
    private static final MethodHandle KEY_NATIVE_HASH_CODE = getOperatorMethodHandle(OperatorType.HASH_CODE, BIGINT);
    private static final MethodHandle KEY_BLOCK_HASH_CODE = compose(KEY_NATIVE_HASH_CODE, nativeValueGetter(BIGINT));

    @Test
    public void testIntersectionCardinality()
            throws Exception
    {
        testIntersectionCardinality(DEFAULT_MAX_HASHES, NUMBER_OF_BUCKETS, DEFAULT_MAX_HASHES, NUMBER_OF_BUCKETS);
    }

    @Test
    public void testUnevenIntersectionCardinality()
            throws Exception
    {
        testIntersectionCardinality(DEFAULT_MAX_HASHES / 4, NUMBER_OF_BUCKETS, DEFAULT_MAX_HASHES, NUMBER_OF_BUCKETS);
    }

    /**
     * Fills two digests (configured independently) from one seeded random stream and checks that
     * the estimated intersection cardinality is within 10% of the exact count.
     *
     * <p>Each generated value is added to each digest independently with probability 0.5, so a
     * value is in the true intersection only when it was added to both.
     *
     * @param maxHashes1 max hashes for the first digest
     * @param numBuckets1 number of buckets for the first digest
     * @param maxHashes2 max hashes for the second digest
     * @param numBuckets2 number of buckets for the second digest
     */
    private static void testIntersectionCardinality(int maxHashes1, int numBuckets1, int maxHashes2, int numBuckets2)
            throws Exception
    {
        List<Integer> sizes = new ArrayList<>();
        Random rand = new Random(0);
        // One random size per power of ten i = 10, 100, ..., 10,000,000 (the loop stops before
        // 100,000,000); each size is uniform in [10, i + 9].
        for (int i = 10; i < 100_000_000; i *= 10) {
            sizes.add(rand.nextInt(i) + 10);
        }

        for (int size : sizes) {
            int expectedCardinality = 0;
            SetDigest digest1 = new SetDigest(maxHashes1, numBuckets1);
            SetDigest digest2 = new SetDigest(maxHashes2, numBuckets2);
            for (int j = 0; j < size; j++) {
                // Draw order (value, then one coin per digest) is kept fixed so the seeded stream is reproducible.
                long value = rand.nextLong();
                boolean inFirst = rand.nextDouble() < 0.5;
                boolean inSecond = rand.nextDouble() < 0.5;
                if (inFirst) {
                    digest1.add(value);
                }
                if (inSecond) {
                    digest2.add(value);
                }
                if (inFirst && inSecond) {
                    expectedCardinality++;
                }
            }

            long estimatedCardinality = intersectionCardinality(digest1.serialize(), digest2.serialize());
            // Divide by at least 1: with an empty true intersection, a division by zero would give NaN and
            // fail even for a correct estimate of 0. With the guard, an empty intersection requires exactly 0.
            double relativeError = Math.abs(expectedCardinality - estimatedCardinality) / (double) Math.max(expectedCardinality, 1);
            assertTrue(relativeError < 0.10,
                    format("Expected intersection cardinality %d +/- 10%%, got %d, for set of size %d", expectedCardinality, estimatedCardinality, size));
        }
    }

    @Test
    public void testHashCounts()
    {
        // digest1: value 0 added twice, value 1 added once -> counts {2, 1}
        SetDigest digest1 = new SetDigest();
        digest1.add(0);
        digest1.add(0);
        digest1.add(1);

        // digest2: values 0 and 2 each added twice
        SetDigest digest2 = new SetDigest();
        digest2.add(0);
        digest2.add(0);
        digest2.add(2);
        digest2.add(2);

        MapType mapType = new MapType(BIGINT, SMALLINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE);

        Block block = hashCounts(mapType, digest1.serialize());
        assertTrue(block instanceof SingleMapBlock);
        Set<Short> expected = ImmutableSet.of((short) 1, (short) 2);
        assertEquals(mapValues(block), expected);

        // After merging, value 0 was added 2 + 2 = 4 times, value 1 once, value 2 twice.
        digest1.mergeWith(digest2);
        block = hashCounts(mapType, digest1.serialize());
        assertTrue(block instanceof SingleMapBlock);
        expected = ImmutableSet.of((short) 1, (short) 2, (short) 4);
        assertEquals(mapValues(block), expected);
    }

    /**
     * Collects the SMALLINT entries stored at the odd positions of a map block produced by
     * {@code hashCounts}; these are the count values (keys occupy the even positions, as the
     * indexing here assumes for a {@link SingleMapBlock}).
     */
    private static Set<Short> mapValues(Block block)
    {
        Set<Short> values = new HashSet<>();
        for (int position = 1; position < block.getPositionCount(); position += 2) {
            values.add(block.getShort(position));
        }
        return values;
    }

    @Test
    public void testSmallLargeIntersections()
    {
        List<Integer> sizes = new ArrayList<>();
        Random rand = new Random(0);
        for (int i = 1000; i < 1_000_000; i *= 10) {
            sizes.add(rand.nextInt(i) + 10);
        }

        for (int size1 : sizes) {
            // NOTE(review): digest1 receives every value generated for every smaller set below, so its true
            // cardinality grows beyond size1; the error bound further down is nevertheless normalized by size1.
            SetDigest digest1 = new SetDigest(DEFAULT_MAX_HASHES, NUMBER_OF_BUCKETS);
            Map<SetDigest, Integer> smallerSets = new HashMap<>();
            for (int size2 : sizes) {
                // sizes are random draws and are not guaranteed to be ascending: skip (rather than stop at)
                // any size that is not strictly smaller than size1.
                if (size2 >= size1) {
                    continue;
                }
                for (int overlap = 2; overlap <= 10; overlap += 2) {
                    int expectedCardinality = 0;
                    SetDigest digest2 = new SetDigest(DEFAULT_MAX_HASHES, NUMBER_OF_BUCKETS);
                    // Each of the size1 values triggers an insert into digest2 with this probability, giving
                    // roughly size2 inserts; a fraction overlap / 10 of them reuse the shared value.
                    double inclusionProbability = size2 / (double) size1;
                    for (int j = 0; j < size1; j++) {
                        long value = rand.nextLong();
                        digest1.add(value);
                        if (rand.nextDouble() < inclusionProbability) {
                            if (rand.nextDouble() * 10 < overlap) {
                                digest2.add(value);
                                expectedCardinality++;
                            }
                            else {
                                digest2.add(rand.nextLong());
                            }
                        }
                    }
                    smallerSets.put(digest2, expectedCardinality);
                }
            }

            for (Map.Entry<SetDigest, Integer> entry : smallerSets.entrySet()) {
                SetDigest digest2 = entry.getKey();
                int expectedCardinality = entry.getValue();
                long estimatedIntersection = intersectionCardinality(digest1.serialize(), digest2.serialize());
                double estimatedSize2 = digest2.cardinality();
                assertTrue(estimatedIntersection <= estimatedSize2,
                        format("Estimated intersection %d exceeds estimated size %s of the smaller set", estimatedIntersection, estimatedSize2));
                assertTrue(Math.abs(expectedCardinality - estimatedIntersection) / (double) size1 < 0.05,
                        format("Expected intersection cardinality %d +/- 5%% of %d, got %d", expectedCardinality, size1, estimatedIntersection));
            }
        }
    }
}