TopNRowNumberOperator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spiller.SpillerFactory;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
public class TopNRowNumberOperator
implements Operator
{
/**
 * Factory for {@link TopNRowNumberOperator} instances that all share one configuration.
 * <p>
 * The factory derives {@code generateRowNumber = !partial}, so operators created by a
 * partial factory do not append a row number channel to their output.
 */
public static class TopNRowNumberOperatorFactory
        implements OperatorFactory
{
    private final int operatorId;
    private final PlanNodeId planNodeId;
    private final List<Type> sourceTypes;
    private final List<Integer> outputChannels;
    private final List<Integer> partitionChannels;
    private final List<Type> partitionTypes;
    private final List<Integer> sortChannels;
    private final List<SortOrder> sortOrder;
    private final int maxRowCountPerPartition;
    private final boolean partial;
    private final Optional<Integer> hashChannel;
    private final int expectedPositions;
    private final boolean generateRowNumber;
    private final long unspillMemoryLimit;
    private final JoinCompiler joinCompiler;
    private final SpillerFactory spillerFactory;
    private final boolean spillEnabled;

    // Set by noMoreOperators(); createOperator() refuses to run afterwards.
    private boolean closed;

    /**
     * Captures the configuration handed to every operator this factory creates.
     * All list arguments are defensively copied into immutable lists.
     *
     * @param operatorId id used when registering the operator context
     * @param planNodeId plan node the operators belong to; must not be null
     * @param sourceTypes types of the input channels; must not be null
     * @param outputChannels input channels to emit; must not be null
     * @param partitionChannels channels that define a partition; must not be null
     * @param partitionTypes types of the partition channels; must not be null
     * @param sortChannels channels used for ordering rows; must not be null
     * @param sortOrder sort order passed along with {@code sortChannels}; must not be null
     * @param maxRowCountPerPartition row limit per partition; must be positive
     * @param partial when true, created operators do not generate a row number channel
     * @param hashChannel optional hash channel forwarded to the group-by hash; must not be null
     * @param expectedPositions sizing hint forwarded to the group-by hash; must be positive
     * @param unspillMemoryLimit forwarded unchanged to the operator
     * @param joinCompiler forwarded to the group-by hash; must not be null
     * @param spillerFactory forwarded unchanged to the operator (not validated here)
     * @param spillEnabled selects the spillable builder in the operator
     * @throws NullPointerException if a required argument is null
     * @throws IllegalArgumentException if a numeric argument is not positive
     */
    public TopNRowNumberOperatorFactory(
            int operatorId,
            PlanNodeId planNodeId,
            List<? extends Type> sourceTypes,
            List<Integer> outputChannels,
            List<Integer> partitionChannels,
            List<? extends Type> partitionTypes,
            List<Integer> sortChannels,
            List<SortOrder> sortOrder,
            int maxRowCountPerPartition,
            boolean partial,
            Optional<Integer> hashChannel,
            int expectedPositions,
            long unspillMemoryLimit,
            JoinCompiler joinCompiler,
            SpillerFactory spillerFactory,
            boolean spillEnabled)
    {
        this.operatorId = operatorId;
        this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
        // Every null check names its argument so a failure points at the offending parameter.
        this.sourceTypes = ImmutableList.copyOf(requireNonNull(sourceTypes, "sourceTypes is null"));
        this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null"));
        this.partitionChannels = ImmutableList.copyOf(requireNonNull(partitionChannels, "partitionChannels is null"));
        this.partitionTypes = ImmutableList.copyOf(requireNonNull(partitionTypes, "partitionTypes is null"));
        this.sortChannels = ImmutableList.copyOf(requireNonNull(sortChannels, "sortChannels is null"));
        this.sortOrder = ImmutableList.copyOf(requireNonNull(sortOrder, "sortOrder is null"));
        this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
        this.partial = partial;
        checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
        this.maxRowCountPerPartition = maxRowCountPerPartition;
        checkArgument(expectedPositions > 0, "expectedPositions must be > 0");
        this.generateRowNumber = !partial;
        this.expectedPositions = expectedPositions;
        this.unspillMemoryLimit = unspillMemoryLimit;
        this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
        this.spillerFactory = spillerFactory;
        this.spillEnabled = spillEnabled;
    }

    /**
     * Registers a new operator context on the driver and creates an operator bound to it.
     *
     * @throws IllegalStateException if {@link #noMoreOperators()} was already called
     */
    @Override
    public Operator createOperator(DriverContext driverContext)
    {
        checkState(!closed, "Factory is already closed");
        OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TopNRowNumberOperator.class.getSimpleName());
        return new TopNRowNumberOperator(
                operatorContext,
                sourceTypes,
                outputChannels,
                partitionChannels,
                partitionTypes,
                sortChannels,
                sortOrder,
                maxRowCountPerPartition,
                generateRowNumber,
                hashChannel,
                expectedPositions,
                unspillMemoryLimit,
                joinCompiler,
                spillerFactory,
                spillEnabled);
    }

    /**
     * Closes the factory; later calls to {@link #createOperator(DriverContext)} fail.
     */
    @Override
    public void noMoreOperators()
    {
        closed = true;
    }

    /**
     * Returns a new, open factory with the same configuration as this one.
     */
    @Override
    public OperatorFactory duplicate()
    {
        return new TopNRowNumberOperatorFactory(operatorId, planNodeId, sourceTypes, outputChannels, partitionChannels, partitionTypes, sortChannels, sortOrder, maxRowCountPerPartition, partial, hashChannel, expectedPositions, unspillMemoryLimit, joinCompiler, spillerFactory, spillEnabled);
    }
}
// Context supplied at construction; source of the session, memory contexts, spill context and yield signal.
private final OperatorContext operatorContext;
// Channels extracted from each result page in getOutput(): the requested output channels,
// followed by one extra trailing channel when a row number is generated.
private final int[] outputChannels;
// Accumulates input and produces the result; getOutput() closes it and sets it to null once output is exhausted.
private GroupedTopNBuilder groupedTopNBuilder;
// Set by finish(); from then on no more input is accepted.
private boolean finishing;
// Set by getOutput() when there is nothing left to return.
private boolean finished;
// Work returned by processPage() that has not yet reported completion; null when there is none.
private Work<?> unfinishedWork;
// Result pages from buildResult(); stays null until flushing starts in getOutput().
private WorkProcessor<Page> outputPages;
/**
 * Creates the operator and the grouped top-N builder that backs it.
 * <p>
 * The output projection consists of {@code outputChannels} in order, plus one trailing
 * channel (index {@code outputChannels.size()}) when {@code generateRowNumber} is set.
 * With {@code spillEnabled} a {@link SpillableGroupedTopNBuilder} is used, otherwise an
 * {@link InMemoryGroupedTopNBuilder} accounted against the user memory context.
 *
 * @throws NullPointerException if {@code operatorContext} or {@code outputChannels} is null
 * @throws IllegalArgumentException if {@code maxRowCountPerPartition} is not positive
 */
public TopNRowNumberOperator(
        OperatorContext operatorContext,
        List<? extends Type> sourceTypes,
        List<Integer> outputChannels,
        List<Integer> partitionChannels,
        List<Type> partitionTypes,
        List<Integer> sortChannels,
        List<SortOrder> sortOrders,
        int maxRowCountPerPartition,
        boolean generateRowNumber,
        Optional<Integer> hashChannel,
        int expectedPositions,
        long unspillMemoryLimit,
        JoinCompiler joinCompiler,
        SpillerFactory spillerFactory,
        boolean spillEnabled)
{
    this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");

    // Fill the projection array directly: requested channels first, row number channel last.
    requireNonNull(outputChannels, "outputChannels is null");
    int requestedChannelCount = outputChannels.size();
    int[] projection = new int[generateRowNumber ? requestedChannelCount + 1 : requestedChannelCount];
    int cursor = 0;
    for (int channel : outputChannels) {
        projection[cursor++] = channel;
    }
    if (generateRowNumber) {
        projection[cursor] = requestedChannelCount;
    }
    this.outputChannels = projection;

    checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");

    List<Type> comparatorTypes = toTypes(sourceTypes, outputChannels, generateRowNumber);

    // Evaluated lazily: every builder instance gets its own GroupByHash.
    Supplier<GroupByHash> hashFactory = () -> createGroupByHash(
            partitionTypes,
            partitionChannels,
            hashChannel,
            expectedPositions,
            joinCompiler,
            isDictionaryAggregationEnabled(operatorContext.getSession()),
            this::updateMemoryReservation);

    if (!spillEnabled) {
        this.groupedTopNBuilder = new InMemoryGroupedTopNBuilder(
                ImmutableList.copyOf(sourceTypes),
                new SimplePageWithPositionComparator(comparatorTypes, sortChannels, sortOrders),
                maxRowCountPerPartition,
                generateRowNumber,
                operatorContext.localUserMemoryContext(),
                hashFactory.get());
        return;
    }

    this.groupedTopNBuilder = new SpillableGroupedTopNBuilder(
            ImmutableList.copyOf(sourceTypes),
            partitionTypes,
            partitionChannels,
            // builder accounted against revocable memory
            () -> new InMemoryGroupedTopNBuilder(
                    ImmutableList.copyOf(sourceTypes),
                    new SimplePageWithPositionComparator(comparatorTypes, sortChannels, sortOrders),
                    maxRowCountPerPartition,
                    generateRowNumber,
                    operatorContext.localRevocableMemoryContext(),
                    hashFactory.get()),
            // builder accounted against user memory
            () -> new InMemoryGroupedTopNBuilder(
                    ImmutableList.copyOf(sourceTypes),
                    new SimplePageWithPositionComparator(comparatorTypes, sortChannels, sortOrders),
                    maxRowCountPerPartition,
                    generateRowNumber,
                    operatorContext.localUserMemoryContext(),
                    hashFactory.get()),
            operatorContext::isWaitingForMemory,
            unspillMemoryLimit,
            operatorContext.localUserMemoryContext(),
            operatorContext.localRevocableMemoryContext(),
            operatorContext.aggregateSystemMemoryContext(),
            operatorContext.aggregateSystemMemoryContext(),
            operatorContext.getSpillContext(),
            operatorContext.getDriverContext().getYieldSignal(),
            spillerFactory);
}
/**
 * Creates the group-by hash used to assign rows to partitions.
 *
 * @return a {@link NoChannelGroupByHash} when there are no partition channels, otherwise
 *         the hash produced by {@code GroupByHash.createGroupByHash}
 * @throws IllegalArgumentException if partition channels exist and {@code expectedPositions} is not positive
 */
private GroupByHash createGroupByHash(
        List<? extends Type> partitionTypes,
        List<Integer> partitionChannels,
        Optional<Integer> inputHashChannel,
        int expectedPositions,
        JoinCompiler joinCompiler,
        boolean isDictionaryAggregationEnabled,
        UpdateMemory updateMemory)
{
    // No partitioning channels: nothing to hash on.
    if (partitionChannels.isEmpty()) {
        return new NoChannelGroupByHash();
    }

    checkArgument(expectedPositions > 0, "expectedPositions must be > 0");
    int[] channels = Ints.toArray(partitionChannels);
    return GroupByHash.createGroupByHash(
            partitionTypes,
            channels,
            inputHashChannel,
            expectedPositions,
            isDictionaryAggregationEnabled,
            joinCompiler,
            updateMemory);
}
/**
 * Returns the operator context this operator was constructed with.
 */
@Override
public OperatorContext getOperatorContext()
{
return operatorContext;
}
/**
 * Marks the operator as finishing. Afterwards needsInput() returns false, addInput()
 * fails its state check, and getOutput() starts flushing once pending work is done.
 */
@Override
public void finish()
{
finishing = true;
}
/**
 * Returns the {@code finished} flag, which getOutput() sets once there is nothing left to emit.
 */
@Override
public boolean isFinished()
{
// has no more input, has finished flushing, and has no unfinished work
return finished;
}
/**
 * Reports whether another input page may be added.
 *
 * @return true only while finish() has not been called, flushing has not started,
 *         and no previously added page is still being processed
 */
@Override
public boolean needsInput()
{
    if (finishing) {
        return false;
    }
    boolean flushing = outputPages != null;
    boolean hasPendingWork = unfinishedWork != null;
    return !flushing && !hasPendingWork;
}
/**
 * Feeds one page into the grouped top-N builder and refreshes the memory reservation.
 * If the page cannot be fully processed right away, the remaining work is kept and
 * resumed by getOutput().
 *
 * @throws IllegalStateException if the operator is finishing, still has unfinished work, or is flushing
 * @throws NullPointerException if {@code page} is null
 */
@Override
public void addInput(Page page)
{
    checkState(!finishing, "Operator is already finishing");
    checkState(unfinishedWork == null, "Cannot add input with the operator when unfinished work is not empty");
    checkState(outputPages == null, "Cannot add input with the operator when flushing");
    requireNonNull(page, "page is null");

    Work<?> pageWork = groupedTopNBuilder.processPage(page);
    // Record the work before running it so it stays pending if it does not complete.
    unfinishedWork = pageWork;
    boolean completed = pageWork.process();
    if (completed) {
        unfinishedWork = null;
    }
    updateMemoryReservation();
}
/**
 * Delegates memory revocation to the grouped top-N builder.
 * <p>
 * getOutput() closes the builder and sets it to null once all output has been produced;
 * a revoke request that arrives afterwards has nothing to revoke, so an already
 * completed future is returned instead of dereferencing null.
 *
 * @return the builder's revoke future, or {@code Operator.NOT_BLOCKED} when the builder is gone
 */
@Override
public ListenableFuture<?> startMemoryRevoke()
{
    if (groupedTopNBuilder == null) {
        return Operator.NOT_BLOCKED;
    }
    return groupedTopNBuilder.startMemoryRevoke();
}
/**
 * Completes a memory revocation on the grouped top-N builder.
 * <p>
 * Does nothing when the builder has already been closed and set to null by getOutput(),
 * rather than failing with a NullPointerException.
 */
@Override
public void finishMemoryRevoke()
{
    if (groupedTopNBuilder != null) {
        groupedTopNBuilder.finishMemoryRevoke();
    }
}
/**
 * Advances the operator and returns the next output page, if one is available.
 * <p>
 * Pending input work is resumed first. Output is only produced after finish() was
 * called: the builder's result is requested on the first such call and then drained
 * page by page, each projected onto {@code outputChannels}. When the result is
 * exhausted the builder is closed and the operator becomes finished.
 *
 * @return the next projected page, or null when nothing can be returned right now
 */
@Override
public Page getOutput()
{
    if (finished) {
        return null;
    }

    // Resume a page that addInput() could not finish processing.
    if (unfinishedWork != null) {
        boolean workCompleted = unfinishedWork.process();
        updateMemoryReservation();
        if (!workCompleted) {
            return null;
        }
        unfinishedWork = null;
    }

    // Still accumulating input; nothing to emit yet.
    if (!finishing) {
        return null;
    }

    if (outputPages == null) {
        if (groupedTopNBuilder == null) {
            finished = true;
            return null;
        }
        // start flushing
        outputPages = groupedTopNBuilder.buildResult();
    }

    if (!outputPages.process()) {
        return null;
    }

    if (!outputPages.isFinished()) {
        Page resultPage = outputPages.getResult();
        Page projectedPage = resultPage.extractChannels(outputChannels);
        updateMemoryReservation();
        return projectedPage;
    }

    // Result exhausted: release the builder and mark the operator done.
    if (groupedTopNBuilder != null) {
        groupedTopNBuilder.close();
        groupedTopNBuilder = null;
    }
    finished = true;
    return null;
}
/**
 * Returns the capacity reported by the builder's group-by hash.
 *
 * @throws IllegalStateException if the builder has already been released by getOutput(),
 *         or if the builder has no group-by hash
 */
@VisibleForTesting
public int getCapacity()
{
    // The builder is set to null once output is exhausted; fail with a clear message instead of an NPE.
    checkState(groupedTopNBuilder != null, "groupedTopNBuilder has already been released");
    GroupByHash groupByHash = groupedTopNBuilder.getGroupByHash();
    checkState(groupByHash != null, "groupByHash is null");
    return groupByHash.getCapacity();
}
/**
 * Asks the builder to refresh its memory reservations and reports whether the
 * operator context's waiting-for-memory future is done.
 * <p>
 * This method is also handed to the group-by hash as a callback, created before the
 * builder field is assigned, and the field is set to null after flushing. The builder
 * update is therefore skipped when no builder is present.
 *
 * @return true when {@code operatorContext.isWaitingForMemory()} is done
 */
private boolean updateMemoryReservation()
{
    if (groupedTopNBuilder != null) {
        groupedTopNBuilder.updateMemoryReservations();
    }
    return operatorContext.isWaitingForMemory().isDone();
}
/**
 * Collects the types of the selected output channels, in order, followed by
 * {@code BIGINT} when a row number channel is generated.
 */
private static List<Type> toTypes(List<? extends Type> sourceTypes, List<Integer> outputChannels, boolean generateRowNumber)
{
    ImmutableList.Builder<Type> selectedTypes = ImmutableList.builder();
    outputChannels.forEach(channel -> selectedTypes.add(sourceTypes.get(channel)));
    if (generateRowNumber) {
        selectedTypes.add(BIGINT);
    }
    return selectedTypes.build();
}
}