TestMultimapAggAggregation.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.RowPageBuilder;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.operator.UpdateMemory;
import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInput;
import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInputBuilder;
import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestOutput;
import com.facebook.presto.operator.aggregation.groupByAggregations.GroupByAggregationTestUtils;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.primitives.Ints;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory;
import static com.facebook.presto.operator.aggregation.multimapagg.MultimapAggregationFunction.NAME;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.util.StructuralTestUtil.mapType;
import static com.google.common.base.Preconditions.checkState;
import static org.testng.Assert.assertTrue;
public class TestMultimapAggAggregation
{
private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = createTestMetadataManager().getFunctionAndTypeManager();
@Test
public void testSingleValueMap()
{
testMultimapAgg(DOUBLE, ImmutableList.of(1.0), VARCHAR, ImmutableList.of("a"));
testMultimapAgg(VARCHAR, ImmutableList.of("a"), BIGINT, ImmutableList.of(1L));
}
@Test
public void testMultiValueMap()
{
testMultimapAgg(DOUBLE, ImmutableList.of(1.0, 1.0, 1.0), VARCHAR, ImmutableList.of("a", "b", "c"));
testMultimapAgg(DOUBLE, ImmutableList.of(1.0, 1.0, 2.0), VARCHAR, ImmutableList.of("a", "b", "c"));
}
@Test
public void testOrderValueMap()
{
testMultimapAgg(VARCHAR, ImmutableList.of("a", "a", "a"), BIGINT, ImmutableList.of(1L, 2L, 3L));
testMultimapAgg(VARCHAR, ImmutableList.of("a", "a", "a"), BIGINT, ImmutableList.of(2L, 1L, 3L));
testMultimapAgg(VARCHAR, ImmutableList.of("a", "a", "a"), BIGINT, ImmutableList.of(3L, 2L, 1L));
}
@Test
public void testDuplicateValueMap()
{
testMultimapAgg(VARCHAR, ImmutableList.of("a", "a", "a"), BIGINT, ImmutableList.of(1L, 1L, 1L));
testMultimapAgg(VARCHAR, ImmutableList.of("a", "b", "a", "b", "c"), BIGINT, ImmutableList.of(1L, 1L, 1L, 1L, 1L));
}
@Test
public void testNullMap()
{
testMultimapAgg(DOUBLE, ImmutableList.<Double>of(), VARCHAR, ImmutableList.<String>of());
}
@Test
public void testDoubleMapMultimap()
{
Type mapType = mapType(VARCHAR, BIGINT);
List<Double> expectedKeys = ImmutableList.of(1.0, 2.0, 3.0);
List<Map<String, Long>> expectedValues = ImmutableList.of(ImmutableMap.of("a", 1L), ImmutableMap.of("b", 2L, "c", 3L, "d", 4L), ImmutableMap.of("a", 1L));
testMultimapAgg(DOUBLE, expectedKeys, mapType, expectedValues);
}
@Test
public void testDoubleArrayMultimap()
{
Type arrayType = new ArrayType(VARCHAR);
List<Double> expectedKeys = ImmutableList.of(1.0, 2.0, 3.0);
List<List<String>> expectedValues = ImmutableList.of(ImmutableList.of("a", "b"), ImmutableList.of("c"), ImmutableList.of("d", "e", "f"));
testMultimapAgg(DOUBLE, expectedKeys, arrayType, expectedValues);
}
@Test
public void testDoubleRowMap()
{
RowType innerRowType = RowType.from(ImmutableList.of(
RowType.field("f1", BIGINT),
RowType.field("f2", DOUBLE)));
testMultimapAgg(DOUBLE, ImmutableList.of(1.0, 2.0, 3.0), innerRowType, ImmutableList.of(ImmutableList.of(1L, 1.0), ImmutableList.of(2L, 2.0), ImmutableList.of(3L, 3.0)));
}
@Test
public void testMultiplePages()
{
JavaAggregationFunctionImplementation aggFunction = getInternalAggregationFunction(BIGINT, BIGINT);
GroupedAccumulator groupedAccumulator = getGroupedAccumulator(aggFunction);
testMultimapAggWithGroupBy(aggFunction, groupedAccumulator, 0, BIGINT, ImmutableList.of(1L, 1L), BIGINT, ImmutableList.of(2L, 3L));
}
@Test
public void testMultiplePagesAndGroups()
{
JavaAggregationFunctionImplementation aggFunction = getInternalAggregationFunction(BIGINT, BIGINT);
GroupedAccumulator groupedAccumulator = getGroupedAccumulator(aggFunction);
testMultimapAggWithGroupBy(aggFunction, groupedAccumulator, 0, BIGINT, ImmutableList.of(1L, 1L), BIGINT, ImmutableList.of(2L, 3L));
testMultimapAggWithGroupBy(aggFunction, groupedAccumulator, 300, BIGINT, ImmutableList.of(7L, 7L), BIGINT, ImmutableList.of(8L, 9L));
}
@Test
public void testManyValues()
{
JavaAggregationFunctionImplementation aggFunction = getInternalAggregationFunction(BIGINT, BIGINT);
GroupedAccumulator groupedAccumulator = getGroupedAccumulator(aggFunction);
int numGroups = 30000;
int numKeys = 10;
int numValueArraySize = 2;
Random random = new Random();
for (int group = 0; group < numGroups; group++) {
ImmutableList.Builder<Long> keyBuilder = ImmutableList.builder();
ImmutableList.Builder<Long> valueBuilder = ImmutableList.builder();
for (int i = 0; i < numKeys; i++) {
long key = random.nextLong();
for (int j = 0; j < numValueArraySize; j++) {
long value = random.nextLong();
keyBuilder.add(key);
valueBuilder.add(value);
}
}
testMultimapAggWithGroupBy(aggFunction, groupedAccumulator, group, BIGINT, keyBuilder.build(), BIGINT, valueBuilder.build());
}
}
@Test
public void testEmptyStateOutputIsNull()
{
JavaAggregationFunctionImplementation aggregationFunction = getInternalAggregationFunction(BIGINT, BIGINT);
GroupedAccumulator groupedAccumulator = generateAccumulatorFactory(aggregationFunction, Ints.asList(), Optional.empty()).createGroupedAccumulator(UpdateMemory.NOOP);
BlockBuilder blockBuilder = groupedAccumulator.getFinalType().createBlockBuilder(null, 1);
groupedAccumulator.evaluateFinal(0, blockBuilder);
assertTrue(blockBuilder.isNull(0));
}
private static <K, V> void testMultimapAgg(Type keyType, List<K> expectedKeys, Type valueType, List<V> expectedValues)
{
checkState(expectedKeys.size() == expectedValues.size(), "expectedKeys and expectedValues should have equal size");
JavaAggregationFunctionImplementation aggFunc = getInternalAggregationFunction(keyType, valueType);
testMultimapAgg(aggFunc, keyType, expectedKeys, valueType, expectedValues);
}
private static JavaAggregationFunctionImplementation getInternalAggregationFunction(Type keyType, Type valueType)
{
return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction(NAME, fromTypes(keyType, valueType)));
}
private static <K, V> void testMultimapAgg(JavaAggregationFunctionImplementation aggFunc, Type keyType, List<K> expectedKeys, Type valueType, List<V> expectedValues)
{
Map<K, List<V>> map = new HashMap<>();
for (int i = 0; i < expectedKeys.size(); i++) {
if (!map.containsKey(expectedKeys.get(i))) {
map.put(expectedKeys.get(i), new ArrayList<>());
}
map.get(expectedKeys.get(i)).add(expectedValues.get(i));
}
RowPageBuilder builder = RowPageBuilder.rowPageBuilder(keyType, valueType);
for (int i = 0; i < expectedKeys.size(); i++) {
builder.row(expectedKeys.get(i), expectedValues.get(i));
}
assertAggregation(aggFunc, map.isEmpty() ? null : map, builder.build());
}
private static <K, V> void testMultimapAggWithGroupBy(
JavaAggregationFunctionImplementation aggregationFunction,
GroupedAccumulator groupedAccumulator,
int groupId,
Type keyType,
List<K> expectedKeys,
Type valueType,
List<V> expectedValues)
{
RowPageBuilder pageBuilder = RowPageBuilder.rowPageBuilder(keyType, valueType);
ImmutableMultimap.Builder<K, V> outputBuilder = ImmutableMultimap.builder();
for (int i = 0; i < expectedValues.size(); i++) {
pageBuilder.row(expectedKeys.get(i), expectedValues.get(i));
outputBuilder.put(expectedKeys.get(i), expectedValues.get(i));
}
Page page = pageBuilder.build();
AggregationTestInput input = new AggregationTestInputBuilder(
new Block[] {page.getBlock(0), page.getBlock(1)},
aggregationFunction).build();
AggregationTestOutput testOutput = new AggregationTestOutput(outputBuilder.build().asMap());
input.runPagesOnAccumulatorWithAssertion(groupId, groupedAccumulator, testOutput);
}
private GroupedAccumulator getGroupedAccumulator(JavaAggregationFunctionImplementation aggFunction)
{
return generateAccumulatorFactory(aggFunction, Ints.asList(GroupByAggregationTestUtils.createArgs(aggFunction)), Optional.empty()).createGroupedAccumulator(UpdateMemory.NOOP);
}
}