AggregationFromAnnotationsParser.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.operator.ParametricImplementationsGroup;
import com.facebook.presto.operator.annotations.FunctionsParserHelper;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationStateSerializerFactory;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import javax.annotation.Nullable;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.operator.aggregation.AggregationImplementation.Parser.parseImplementation;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
public class AggregationFromAnnotationsParser
{
private AggregationFromAnnotationsParser()
{
}
// This function should only be used for function matching for testing purposes.
// General purpose function matching is done through FunctionRegistry.
@VisibleForTesting
public static ParametricAggregation parseFunctionDefinitionWithTypesConstraint(Class<?> clazz, TypeSignature returnType, List<TypeSignature> argumentTypes)
{
requireNonNull(returnType, "returnType is null");
requireNonNull(argumentTypes, "argumentTypes is null");
for (ParametricAggregation aggregation : parseFunctionDefinitions(clazz)) {
if (aggregation.getSignature().getReturnType().equals(returnType) &&
aggregation.getSignature().getArgumentTypes().equals(argumentTypes)) {
return aggregation;
}
}
throw new IllegalArgumentException(String.format("No method with return type %s and arguments %s", returnType, argumentTypes));
}
public static List<ParametricAggregation> parseFunctionDefinitions(Class<?> aggregationDefinition)
{
AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class);
requireNonNull(aggregationAnnotation, "aggregationAnnotation is null");
ImmutableList.Builder<ParametricAggregation> builder = ImmutableList.builder();
for (Class<?> stateClass : getStateClasses(aggregationDefinition)) {
Method combineFunction = getCombineFunction(aggregationDefinition, stateClass);
Optional<Method> aggregationStateSerializerFactory = getAggregationStateSerializerFactory(aggregationDefinition, stateClass);
for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) {
for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) {
for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) {
AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory);
ParametricImplementationsGroup<AggregationImplementation> implementations = ParametricImplementationsGroup.of(onlyImplementation);
builder.add(new ParametricAggregation(implementations.getSignature(), header, implementations));
}
}
}
}
return builder.build();
}
public static ParametricAggregation parseFunctionDefinition(Class<?> aggregationDefinition)
{
ParametricImplementationsGroup.Builder<AggregationImplementation> implementationsBuilder = ParametricImplementationsGroup.builder();
AggregationHeader header = parseHeader(aggregationDefinition);
for (Class<?> stateClass : getStateClasses(aggregationDefinition)) {
Method combineFunction = getCombineFunction(aggregationDefinition, stateClass);
Optional<Method> aggregationStateSerializerFactory = getAggregationStateSerializerFactory(aggregationDefinition, stateClass);
Method outputFunction = getOnlyElement(getOutputFunctions(aggregationDefinition, stateClass));
for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) {
AggregationImplementation implementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory);
implementationsBuilder.addImplementation(implementation);
}
}
ParametricImplementationsGroup<AggregationImplementation> implementations = implementationsBuilder.build();
return new ParametricAggregation(implementations.getSignature(), header, implementations);
}
private static Optional<Method> getAggregationStateSerializerFactory(Class<?> aggregationDefinition, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> stateSerializerFactories = FunctionsParserHelper.findPublicStaticMethods(aggregationDefinition, AggregationStateSerializerFactory.class).stream()
.filter(method -> ((AggregationStateSerializerFactory) method.getAnnotation(AggregationStateSerializerFactory.class)).value().equals(stateClass))
.collect(toImmutableList());
if (stateSerializerFactories.isEmpty()) {
return Optional.empty();
}
checkArgument(stateSerializerFactories.size() == 1,
String.format(
"Expect at most 1 @AggregationStateSerializerFactory(%s.class) annotation, found %s in %s",
stateClass.toGenericString(),
stateSerializerFactories.size(),
aggregationDefinition.toGenericString()));
return Optional.of(getOnlyElement(stateSerializerFactories));
}
private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinition)
{
AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class);
requireNonNull(aggregationAnnotation, "aggregationAnnotation is null");
return new AggregationHeader(
aggregationAnnotation.value(),
parseDescription(aggregationDefinition),
aggregationAnnotation.decomposable(),
aggregationAnnotation.isOrderSensitive(),
aggregationAnnotation.visibility(),
aggregationAnnotation.isCalledOnNullInput());
}
private static List<AggregationHeader> parseHeaders(AnnotatedElement aggregationDefinition, AnnotatedElement toParse)
{
AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class);
return getNames(toParse, aggregationAnnotation).stream()
.map(name ->
new AggregationHeader(
name,
parseDescription(aggregationDefinition, toParse),
aggregationAnnotation.decomposable(),
aggregationAnnotation.isOrderSensitive(),
aggregationAnnotation.visibility(),
aggregationAnnotation.isCalledOnNullInput()))
.collect(toImmutableList());
}
private static List<String> getNames(@Nullable AnnotatedElement outputFunction, AggregationFunction aggregationAnnotation)
{
List<String> defaultNames = ImmutableList.<String>builder().add(aggregationAnnotation.value()).addAll(Arrays.asList(aggregationAnnotation.alias())).build();
if (outputFunction == null) {
return defaultNames;
}
AggregationFunction annotation = outputFunction.getAnnotation(AggregationFunction.class);
if (annotation == null) {
return defaultNames;
}
else {
return ImmutableList.<String>builder().add(annotation.value()).addAll(Arrays.asList(annotation.alias())).build();
}
}
public static Method getCombineFunction(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethods(clazz, CombineFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 0)] == stateClass)
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method, 1)] == stateClass)
.collect(toImmutableList());
checkArgument(combineFunctions.size() == 1, String.format("There must be exactly one @CombineFunction in class %s for the @AggregationState %s ", clazz.toGenericString(), stateClass.toGenericString()));
return getOnlyElement(combineFunctions);
}
private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethods(clazz, OutputFunction.class).stream()
.filter(method -> method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass)
.collect(toImmutableList());
checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions");
return outputFunctions;
}
private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass)
{
// Only include methods that match this state class
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethods(clazz, InputFunction.class).stream()
.filter(method -> (method.getParameterTypes()[AggregationImplementation.Parser.findAggregationStateParamId(method)] == stateClass))
.collect(toImmutableList());
checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
return inputFunctions;
}
private static Set<Class<?>> getStateClasses(Class<?> clazz)
{
ImmutableSet.Builder<Class<?>> builder = ImmutableSet.builder();
for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethods(clazz, InputFunction.class)) {
checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
Class<?> stateClass = AggregationImplementation.Parser.findAggregationStateParamType(inputFunction);
checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState");
builder.add(stateClass);
}
ImmutableSet<Class<?>> stateClasses = builder.build();
checkArgument(!stateClasses.isEmpty(), "No input functions found");
return stateClasses;
}
}