TestKllSketchAggregationFunction.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.sketch.kll.KllSketchAggregationState;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.apache.datasketches.common.ArrayOfDoublesSerDe;
import org.apache.datasketches.kll.KllItemsSketch;
import org.apache.datasketches.memory.WritableMemory;
import org.intellij.lang.annotations.Language;
import org.testng.Assert.ThrowingRunnable;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateTimeEncoding.packDateTimeWithZone;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
/**
 * Tests for the {@code sketch_kll} and {@code sketch_kll_with_k} aggregation functions.
 * <p>
 * Most tests feed values through the Presto aggregation, wrap the resulting varbinary back into a
 * {@link KllItemsSketch}, and compare it against a sketch built directly with the DataSketches
 * library from the same inputs.
 */
public class TestKllSketchAggregationFunction
        extends AbstractTestFunctions
{
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = METADATA.getFunctionAndTypeManager();
    private static final JavaAggregationFunctionImplementation DOUBLE_FUNCTION = getFunction(DOUBLE);
    private static final JavaAggregationFunctionImplementation DOUBLE_WITH_K_FUNCTION = getFunction("sketch_kll_with_k", DOUBLE, BIGINT);

    @Test
    public void testDouble()
    {
        double[] items = randomNonDecreasingDoubles(100);
        BlockBuilder out = DOUBLE.createBlockBuilder(null, items.length);
        KllItemsSketch<Double> sketch = KllItemsSketch.newHeapInstance(Double::compareTo, new ArrayOfDoublesSerDe());
        for (double item : items) {
            DOUBLE.writeDouble(out, item);
            sketch.update(item);
        }

        SqlVarbinary result = (SqlVarbinary) AggregationTestUtils.executeAggregation(DOUBLE_FUNCTION, out.build());

        checkSketchesEqual(boxed(items), sketch, wrapDoubleSketch(result));
    }

    @Test
    public void testDoubleWithK()
    {
        int k = 150;
        double[] items = randomNonDecreasingDoubles(100);
        BlockBuilder out = DOUBLE.createBlockBuilder(null, items.length);
        KllItemsSketch<Double> sketch = KllItemsSketch.newHeapInstance(k, Double::compareTo, new ArrayOfDoublesSerDe());
        for (double item : items) {
            DOUBLE.writeDouble(out, item);
            sketch.update(item);
        }

        // The aggregation takes k as a per-row argument, so supply the same k for every row.
        SqlVarbinary result = (SqlVarbinary) AggregationTestUtils.executeAggregation(
                DOUBLE_WITH_K_FUNCTION,
                out.build(),
                constantBigintBlock(k, items.length));

        checkSketchesEqual(boxed(items), sketch, wrapDoubleSketch(result));
    }

    @Test
    public void testInvalidK()
    {
        double[] items = randomNonDecreasingDoubles(10);
        BlockBuilder inputBuilder = DOUBLE.createBlockBuilder(null, items.length);
        for (double item : items) {
            DOUBLE.writeDouble(inputBuilder, item);
        }
        // Built once and shared by both cases: only the k argument differs between them.
        Block input = inputBuilder.build();
        // Built outside the asserted lambdas so only the aggregation itself is expected to throw.
        Block kTooLow = constantBigintBlock(7, items.length);
        Block kTooHigh = constantBigintBlock(65536, items.length);

        assertThrows(
                () -> AggregationTestUtils.executeAggregation(DOUBLE_WITH_K_FUNCTION, input, kTooLow),
                PrestoException.class,
                "k value must satisfy 8 <= k <= 65535: 7");
        assertThrows(
                () -> AggregationTestUtils.executeAggregation(DOUBLE_WITH_K_FUNCTION, input, kTooHigh),
                PrestoException.class,
                "k value must satisfy 8 <= k <= 65535: 65536");
    }

    /**
     * Supplies, per type: the Presto type, a random value generator producing that type's
     * stack representation, and a writer that appends such a value to a {@link BlockBuilder}.
     * <p>
     * {@link ThreadLocalRandom#current()} is called inside each supplier (not captured here)
     * so it is always bound to the thread that actually generates the values.
     */
    @DataProvider(name = "testTypes")
    public Object[][] testTypesProvider()
    {
        return new Object[][] {
                {TINYINT, (Supplier<Object>) () -> (long) ThreadLocalRandom.current().nextInt(0, Byte.MAX_VALUE),
                        (BiConsumer<BlockBuilder, Long>) TINYINT::writeLong},
                {SMALLINT, (Supplier<Object>) () -> (long) ThreadLocalRandom.current().nextInt(0, Short.MAX_VALUE),
                        (BiConsumer<BlockBuilder, Long>) SMALLINT::writeLong},
                {INTEGER, (Supplier<Object>) () -> (long) ThreadLocalRandom.current().nextInt(),
                        (BiConsumer<BlockBuilder, Long>) INTEGER::writeLong},
                {BIGINT, (Supplier<Object>) () -> (long) ThreadLocalRandom.current().nextLong(),
                        (BiConsumer<BlockBuilder, Long>) BIGINT::writeLong},
                // REAL values travel as the int bits of the float, widened to long.
                {REAL, (Supplier<Object>) () -> (long) Float.floatToIntBits(ThreadLocalRandom.current().nextFloat()),
                        (BiConsumer<BlockBuilder, Long>) REAL::writeLong},
                {DOUBLE, (Supplier<Object>) () -> ThreadLocalRandom.current().nextDouble(),
                        (BiConsumer<BlockBuilder, Double>) DOUBLE::writeDouble},
                {VARCHAR, (Supplier<Object>) () -> Slices.utf8Slice(String.valueOf("abcdefghijklmnopqrstuvwxyz".charAt(ThreadLocalRandom.current().nextInt(26)))),
                        (BiConsumer<BlockBuilder, Slice>) VARCHAR::writeSlice},
                {BOOLEAN, (Supplier<Object>) () -> ThreadLocalRandom.current().nextBoolean(),
                        (BiConsumer<BlockBuilder, Boolean>) BOOLEAN::writeBoolean},
                {DATE, (Supplier<Object>) () -> ThreadLocalRandom.current().nextLong(0, 100),
                        (BiConsumer<BlockBuilder, Long>) DATE::writeLong},
                {TIME, (Supplier<Object>) () -> ThreadLocalRandom.current().nextLong(0, 100),
                        (BiConsumer<BlockBuilder, Long>) TIME::writeLong},
                {TIMESTAMP, (Supplier<Object>) () -> ThreadLocalRandom.current().nextLong(0, 100),
                        (BiConsumer<BlockBuilder, Long>) TIMESTAMP::writeLong},
                {TIMESTAMP_WITH_TIME_ZONE, (Supplier<Object>) () -> packDateTimeWithZone(ThreadLocalRandom.current().nextLong(0, 100), TimeZoneKey.UTC_KEY),
                        (BiConsumer<BlockBuilder, Long>) TIMESTAMP_WITH_TIME_ZONE::writeLong},
                {INTERVAL_YEAR_MONTH, (Supplier<Object>) () -> ThreadLocalRandom.current().nextLong(0, 100),
                        (BiConsumer<BlockBuilder, Long>) INTERVAL_YEAR_MONTH::writeLong}
        };
    }

    /**
     * Aggregates 100 random values of {@code type} and checks the result against a sketch
     * built directly from the same values.
     * <p>
     * The sketch parameters are obtained from {@link KllSketchAggregationState#getSketchParameters}
     * as a raw type, so raw sketches and unchecked calls are unavoidable here.
     */
    @Test(dataProvider = "testTypes")
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void testTypes(Type type, Supplier<Object> values, BiConsumer<BlockBuilder, Object> writeBlockValue)
    {
        int length = 100;
        JavaAggregationFunctionImplementation function = getFunction(type);
        KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type);

        List<Object> addedValues = Stream.generate(values).limit(length).collect(toImmutableList());
        // The sketch stores converted values; the block stores the raw stack values.
        List<Object> sketchItems = addedValues.stream()
                .map(parameters.getConversion()::apply)
                .collect(Collectors.toList());

        BlockBuilder out = type.createBlockBuilder(null, length);
        KllItemsSketch sketch = KllItemsSketch.<Object>newHeapInstance(parameters.getComparator(), parameters.getSerde());
        for (int i = 0; i < length; i++) {
            writeBlockValue.accept(out, addedValues.get(i));
            sketch.update(sketchItems.get(i));
        }

        SqlVarbinary result = (SqlVarbinary) AggregationTestUtils.executeAggregation(function, out.build());
        KllItemsSketch recreated = KllItemsSketch.wrap(WritableMemory.writableWrap(result.getBytes()), parameters.getComparator(), parameters.getSerde());

        checkSketchesEqual(sketchItems, sketch, recreated);
    }

    @Test
    public void testEmptyInput()
    {
        // An empty input is expected to aggregate to NULL.
        assertAggregation(DOUBLE_FUNCTION,
                null,
                DOUBLE.createBlockBuilder(null, 0).build());
    }

    @Test
    public void testNulls()
    {
        // An all-NULL input must not throw and is expected to aggregate to NULL.
        assertAggregation(DOUBLE_FUNCTION,
                null,
                DOUBLE.createBlockBuilder(null, 2)
                        .appendNull()
                        .appendNull()
                        .build());
    }

    /**
     * Asserts that {@code runnable} throws an exception of exactly {@code exceptionType} whose
     * message fully matches {@code regex} (a null message is treated as empty).
     * <p>
     * The "nothing thrown" failure is raised outside the {@code try}. That way it cannot be
     * swallowed by the {@code catch} and misreported as a type mismatch.
     */
    private static void assertThrows(ThrowingRunnable runnable, Class<?> exceptionType, @Language("regexp") String regex)
    {
        Throwable thrown = null;
        try {
            runnable.run();
        }
        catch (Throwable e) {
            thrown = e;
        }

        if (thrown == null) {
            throw new AssertionError("no exception was thrown");
        }
        if (thrown.getClass() != exceptionType) {
            // Keep the unexpected exception as the cause so its stack trace is not lost.
            throw new AssertionError(
                    format("Expected exception of type %s but got %s", exceptionType.getName(), thrown.getClass().getName()),
                    thrown);
        }
        String message = Optional.ofNullable(thrown.getMessage()).orElse("");
        assertTrue(message.matches(regex), format("Error message: '%s' didn't match regex: '%s'", thrown.getMessage(), regex));
    }

    /**
     * Asserts that two sketches agree on k, on the rank of every input item, and on their
     * sorted views (cumulative weights and quantiles).
     * Arguments to TestNG's {@code assertEquals} are passed in (actual, expected) order.
     */
    private static <T> void checkSketchesEqual(List<T> items, KllItemsSketch<T> expected, KllItemsSketch<T> actual)
    {
        assertEquals(actual.getK(), expected.getK(), "k values are not equal");
        for (T item : items) {
            assertEquals(actual.getRank(item), expected.getRank(item), 1E-8, format("ranks are not equal for item %s", item));
        }
        assertEquals(actual.getSortedView().getCumulativeWeights(), expected.getSortedView().getCumulativeWeights(), "weights are not equal");
        assertEquals(actual.getSortedView().getQuantiles(), expected.getSortedView().getQuantiles(), "quantiles are not equal");
    }

    /**
     * Returns {@code count} values starting at 0, each one the previous value plus a random
     * increment in [0, 1), so the sequence is non-decreasing.
     */
    private static double[] randomNonDecreasingDoubles(int count)
    {
        return DoubleStream.iterate(0, previous -> previous + ThreadLocalRandom.current().nextDouble())
                .limit(count)
                .toArray();
    }

    /** Boxes a primitive double array into an immutable list. */
    private static List<Double> boxed(double[] values)
    {
        return Arrays.stream(values).boxed().collect(toImmutableList());
    }

    /** Wraps the serialized bytes of a double KLL sketch produced by the aggregation. */
    private static KllItemsSketch<Double> wrapDoubleSketch(SqlVarbinary serialized)
    {
        return KllItemsSketch.wrap(WritableMemory.writableWrap(serialized.getBytes()), Double::compareTo, new ArrayOfDoublesSerDe());
    }

    /** Builds a BIGINT block holding {@code value} at each of {@code positionCount} positions. */
    private static Block constantBigintBlock(long value, int positionCount)
    {
        BlockBuilder builder = BIGINT.createBlockBuilder(null, positionCount);
        for (int i = 0; i < positionCount; i++) {
            BIGINT.writeLong(builder, value);
        }
        return builder.build();
    }

    private static JavaAggregationFunctionImplementation getFunction(Type... types)
    {
        return getFunction("sketch_kll", types);
    }

    /** Resolves the named aggregation for the given argument types. */
    private static JavaAggregationFunctionImplementation getFunction(String name, Type... types)
    {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(
                FUNCTION_AND_TYPE_MANAGER.lookupFunction(name, fromTypes(types)));
    }
}