HiveAggregationFunctionImplementationFactory.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive.functions.aggregation;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.hive.functions.type.BlockInputDecoder;
import com.facebook.presto.hive.functions.type.BlockInputDecoders;
import com.facebook.presto.hive.functions.type.ObjectEncoders;
import com.facebook.presto.spi.function.Signature;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.facebook.presto.hive.functions.aggregation.HiveAccumulatorMethodHandles.getCombineFunction;
import static com.facebook.presto.hive.functions.aggregation.HiveAccumulatorMethodHandles.getInputFunction;
import static com.facebook.presto.hive.functions.aggregation.HiveAccumulatorMethodHandles.getOutputFunction;
import static java.util.Objects.requireNonNull;
/**
 * Builds {@link HiveAggregationFunctionImplementation} instances backed by a Hive
 * {@link GenericUDAFEvaluator}.
 *
 * <p>The factory captures the resolved {@link Signature}, the Presto input/intermediate/output
 * types and the Hive {@link ObjectInspector}s that correspond to them. {@link #create()} wires
 * these into the accumulator method handles, the state serializer and the state factory.
 */
public class HiveAggregationFunctionImplementationFactory
{
    private final Signature signature;
    private final List<Type> inputTypes;
    private final Type intermediateType;
    private final Type outputType;
    private final Supplier<GenericUDAFEvaluator> partialEvaluatorSupplier;
    private final Supplier<GenericUDAFEvaluator> finalEvaluatorSupplier;
    private final ObjectInspector[] inputInspectors;
    private final ObjectInspector intermediateInspector;
    private final ObjectInspector outputInspector;

    /**
     * Creates a factory for one resolved Hive aggregation function.
     *
     * @param signature resolved function signature; its name becomes the function name
     * @param inputTypes Presto types of the function arguments, in argument order
     * @param intermediateType Presto type used for the serialized intermediate state
     * @param outputType Presto type of the final result
     * @param partialEvaluatorSupplier supplies evaluators passed to the accumulator invoker and state serializer
     * @param finalEvaluatorSupplier supplies evaluators passed to the accumulator invoker and state serializer
     * @param inputInspectors Hive inspectors for the arguments; element {@code i} is paired with
     *         {@code inputTypes.get(i)}
     * @param intermediateInspector Hive inspector for the intermediate state
     * @param outputInspector Hive inspector for the final result
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if {@code inputTypes} and {@code inputInspectors} differ in length
     */
    public HiveAggregationFunctionImplementationFactory(
            Signature signature,
            List<Type> inputTypes,
            Type intermediateType,
            Type outputType,
            Supplier<GenericUDAFEvaluator> partialEvaluatorSupplier,
            Supplier<GenericUDAFEvaluator> finalEvaluatorSupplier,
            ObjectInspector[] inputInspectors,
            ObjectInspector intermediateInspector,
            ObjectInspector outputInspector)
    {
        this.signature = requireNonNull(signature, "signature is null");
        // Defensive copies: later changes to the caller's list/array must not affect create().
        this.inputTypes = ImmutableList.copyOf(requireNonNull(inputTypes, "inputTypes is null"));
        this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
        this.outputType = requireNonNull(outputType, "outputType is null");
        this.partialEvaluatorSupplier = requireNonNull(partialEvaluatorSupplier, "partialEvaluatorSupplier is null");
        this.finalEvaluatorSupplier = requireNonNull(finalEvaluatorSupplier, "finalEvaluatorSupplier is null");
        this.inputInspectors = requireNonNull(inputInspectors, "inputInspectors is null").clone();
        this.intermediateInspector = requireNonNull(intermediateInspector, "intermediateInspector is null");
        this.outputInspector = requireNonNull(outputInspector, "outputInspector is null");

        // create() pairs types and inspectors with Streams.zip, which silently stops at the
        // shorter input; reject mismatched lengths up front instead of dropping arguments.
        if (this.inputTypes.size() != this.inputInspectors.length) {
            throw new IllegalArgumentException(String.format(
                    "inputTypes and inputInspectors must have the same length, but got %s and %s",
                    this.inputTypes.size(),
                    this.inputInspectors.length));
        }
    }

    /**
     * Assembles a new aggregation implementation from the captured types and inspectors.
     *
     * <p>Each call creates a fresh {@link HiveAccumulatorInvoker}. That invoker is shared by the
     * input, combine and output method handles and by the state factory of the returned
     * implementation.
     *
     * @return a new implementation combining the function description, accumulator functions
     *         and accumulator state description
     */
    public HiveAggregationFunctionImplementation create()
    {
        HiveAggregationFunctionDescription metadata = new HiveAggregationFunctionDescription(
                signature.getName(),
                inputTypes,
                ImmutableList.of(intermediateType),
                outputType,
                true,
                false);

        HiveAccumulatorInvoker invocationContext = new HiveAccumulatorInvoker(
                partialEvaluatorSupplier,
                finalEvaluatorSupplier,
                ObjectEncoders.createEncoder(outputType, outputInspector),
                outputType);

        // One decoder per argument, pairing each Presto type with the Hive inspector at the
        // same position (lengths are validated equal in the constructor).
        List<BlockInputDecoder> inputDecoders = Streams.zip(
                        inputTypes.stream(),
                        Stream.of(inputInspectors),
                        (type, inspector) -> BlockInputDecoders.createBlockInputDecoder(inspector, type))
                .collect(Collectors.toList());

        HiveAccumulatorFunctions methods = new HiveAccumulatorFunctions(
                getInputFunction(invocationContext, inputDecoders),
                getCombineFunction(invocationContext),
                getOutputFunction(invocationContext));

        HiveAccumulatorStateDescription stateMetadata = new HiveAccumulatorStateDescription(
                HiveAccumulatorState.class,
                new HiveAccumulatorStateSerializer(
                        partialEvaluatorSupplier,
                        finalEvaluatorSupplier,
                        intermediateType,
                        intermediateInspector),
                new HiveAccumulatorStateFactory(invocationContext::newAggregationBuffer));

        return new HiveAggregationFunctionImplementation(metadata, methods, stateMetadata);
    }
}