AggregationTestInput.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.groupByAggregations;
import com.facebook.presto.common.Page;
import com.facebook.presto.operator.UpdateMemory;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.spi.function.aggregation.GroupByIdBlock;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.base.Suppliers;
import org.testng.internal.collections.Ints;
import java.util.Optional;
import java.util.function.Supplier;
import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory;
/**
 * Test fixture that feeds a fixed set of pages into a {@link GroupedAccumulator} for a single group id.
 * <p>
 * The pages can be column-reversed and/or column-offset (see {@link #getPages()}). This lets the same
 * logical input be exercised under different channel layouts.
 */
public class AggregationTestInput
{
    private final Page[] pages;
    private final JavaAggregationFunctionImplementation function;
    // Argument channel list derived from the function; assigned once in the constructor.
    private final int[] args;
    private final int offset;
    private final boolean isReversed;

    /**
     * @param function aggregation function under test
     * @param pages input pages to feed to the accumulator
     * @param offset when {@code > 0}, pages are passed through {@code AggregationTestUtils.offsetColumns} with this value
     * @param isReversed when {@code true}, pages are passed through {@code AggregationTestUtils.reverseColumns}
     */
    public AggregationTestInput(JavaAggregationFunctionImplementation function, Page[] pages, int offset, boolean isReversed)
    {
        this.pages = pages;
        this.function = function;
        this.args = GroupByAggregationTestUtils.createArgs(function);
        this.offset = offset;
        this.isReversed = isReversed;
    }

    /**
     * Feeds every (possibly transformed) page into {@code groupedAccumulator} under {@code groupId},
     * then asks {@code expectedValue} to validate the accumulator for that group.
     */
    public void runPagesOnAccumulatorWithAssertion(long groupId, GroupedAccumulator groupedAccumulator, AggregationTestOutput expectedValue)
    {
        // Reuse the shared feeding path rather than duplicating the page loop.
        GroupedAccumulator accumulator = runPagesOnAccumulator(groupId, groupedAccumulator);
        expectedValue.validateAccumulator(accumulator, groupId);
    }

    /**
     * Feeds every (possibly transformed) page into the given accumulator under {@code groupId}.
     *
     * @return the same accumulator instance that was passed in
     */
    public GroupedAccumulator runPagesOnAccumulator(long groupId, GroupedAccumulator groupedAccumulator)
    {
        return runPagesOnAccumulator(groupId, Suppliers.ofInstance(groupedAccumulator));
    }

    /**
     * Obtains one accumulator from {@code accumulatorSupplier} and feeds every (possibly transformed)
     * page into it under {@code groupId}.
     *
     * @return the accumulator obtained from the supplier
     */
    public GroupedAccumulator runPagesOnAccumulator(long groupId, Supplier<GroupedAccumulator> accumulatorSupplier)
    {
        GroupedAccumulator accumulator = accumulatorSupplier.get();
        for (Page page : getPages()) {
            accumulator.addInput(getGroupIdBlock(groupId, page), page);
        }
        return accumulator;
    }

    /**
     * Builds a group-id block sized to {@code page}'s position count for {@code groupId}.
     * Presumably every position is assigned to {@code groupId}; TODO confirm against AggregationTestUtils.
     *
     * @throws ArithmeticException if {@code groupId} does not fit in an {@code int}
     */
    private GroupByIdBlock getGroupIdBlock(long groupId, Page page)
    {
        // toIntExact instead of a plain cast: an out-of-range id must fail rather than silently wrap to another group.
        return AggregationTestUtils.createGroupByIdBlock(Math.toIntExact(groupId), page.getPositionCount());
    }

    /**
     * Returns the input pages with the configured transformations applied: reverseColumns first
     * (if {@code isReversed}), then offsetColumns (if {@code offset > 0}).
     * The transformations are recomputed on every call.
     */
    private Page[] getPages()
    {
        Page[] pages = this.pages;
        if (isReversed) {
            pages = AggregationTestUtils.reverseColumns(pages);
        }
        if (offset > 0) {
            pages = AggregationTestUtils.offsetColumns(pages, offset);
        }
        return pages;
    }

    /**
     * Creates a fresh grouped accumulator for {@link #function} over {@link #args}.
     * No mask channel is set, and memory updates are ignored via {@code UpdateMemory.NOOP}.
     */
    public GroupedAccumulator createGroupedAccumulator()
    {
        return generateAccumulatorFactory(function, Ints.asList(args), Optional.empty())
                .createGroupedAccumulator(UpdateMemory.NOOP);
    }
}