FunctionsParserHelper.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.annotations;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.FunctionDescriptor;
import com.facebook.presto.spi.function.IsNull;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.LongVariableConstraint;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.facebook.presto.type.Constraint;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Stream;
import static com.facebook.presto.common.function.OperatorType.BETWEEN;
import static com.facebook.presto.common.function.OperatorType.CAST;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.HASH_CODE;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.type.StandardTypes.PARAMETRIC_TYPES;
import static com.facebook.presto.operator.annotations.ImplementationDependency.isImplementationDependencyAnnotation;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.reflect.Modifier.isPublic;
import static java.lang.reflect.Modifier.isStatic;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
public class FunctionsParserHelper
{
private static final Set<OperatorType> COMPARABLE_TYPE_OPERATORS = ImmutableSet.of(EQUAL, NOT_EQUAL, HASH_CODE);
private static final Set<OperatorType> ORDERABLE_TYPE_OPERATORS = ImmutableSet.of(LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL, BETWEEN);
private FunctionsParserHelper()
{}
public static boolean containsAnnotation(Annotation[] annotations, Predicate<Annotation> predicate)
{
return Arrays.stream(annotations).anyMatch(predicate);
}
public static boolean containsImplementationDependencyAnnotation(Annotation[] annotations)
{
return containsAnnotation(annotations, ImplementationDependency::isImplementationDependencyAnnotation);
}
public static List<TypeVariableConstraint> createTypeVariableConstraints(Iterable<TypeParameter> typeParameters, List<ImplementationDependency> dependencies)
{
Set<String> orderableRequired = new HashSet<>();
Set<String> comparableRequired = new HashSet<>();
for (ImplementationDependency dependency : dependencies) {
if (dependency instanceof OperatorImplementationDependency) {
OperatorType operator = ((OperatorImplementationDependency) dependency).getOperator();
if (operator == CAST) {
continue;
}
Set<String> argumentTypes = ((OperatorImplementationDependency) dependency).getArgumentTypes().stream()
.map(TypeSignature::getBase)
.collect(toImmutableSet());
checkArgument(argumentTypes.size() == 1, "Operator dependency must only have arguments of a single type");
String argumentType = Iterables.getOnlyElement(argumentTypes);
if (COMPARABLE_TYPE_OPERATORS.contains(operator)) {
comparableRequired.add(argumentType);
}
if (ORDERABLE_TYPE_OPERATORS.contains(operator)) {
orderableRequired.add(argumentType);
}
}
}
ImmutableList.Builder<TypeVariableConstraint> typeVariableConstraints = ImmutableList.builder();
for (TypeParameter typeParameter : typeParameters) {
String name = typeParameter.value();
String variadicBound = typeParameter.boundedBy().isEmpty() ? null : typeParameter.boundedBy();
checkArgument(variadicBound == null || PARAMETRIC_TYPES.contains(variadicBound), "boundedBy must be a parametric type, got %s", variadicBound);
if (orderableRequired.contains(name)) {
typeVariableConstraints.add(new TypeVariableConstraint(name, false, true, variadicBound, false));
}
else if (comparableRequired.contains(name)) {
typeVariableConstraints.add(new TypeVariableConstraint(name, true, false, variadicBound, false));
}
else {
typeVariableConstraints.add(new TypeVariableConstraint(name, false, false, variadicBound, false));
}
}
return typeVariableConstraints.build();
}
public static void validateSignaturesCompatibility(Optional<Signature> signatureOld, Signature signatureNew)
{
if (!signatureOld.isPresent()) {
return;
}
checkArgument(signatureOld.get().equals(signatureNew), "Implementations with type parameters must all have matching signatures. %s does not match %s", signatureOld.get(), signatureNew);
}
@SafeVarargs
public static Set<Method> findPublicStaticMethods(Class<?> clazz, Class<? extends Annotation>... includedAnnotations)
{
return findPublicStaticMethods(clazz, ImmutableSet.copyOf(asList(includedAnnotations)), ImmutableSet.of());
}
public static Set<Method> findPublicStaticMethods(Class<?> clazz, Set<Class<? extends Annotation>> includedAnnotations, Set<Class<? extends Annotation>> excludedAnnotations)
{
return findMethods(
clazz.getMethods(),
method -> checkArgument(isStatic(method.getModifiers()) && isPublic(method.getModifiers()), "Annotated method [%s] must be static and public", method.getName()),
includedAnnotations,
excludedAnnotations);
}
@SafeVarargs
public static Set<Method> findPublicMethods(Class<?> clazz, Class<? extends Annotation>... includedAnnotations)
{
return findPublicMethods(clazz, ImmutableSet.copyOf(asList(includedAnnotations)), ImmutableSet.of());
}
public static Set<Method> findPublicMethods(Class<?> clazz, Set<Class<? extends Annotation>> includedAnnotations, Set<Class<? extends Annotation>> excludedAnnotations)
{
return findMethods(
clazz.getDeclaredMethods(),
method -> checkArgument(isPublic(method.getModifiers()), "Annotated method [%s] must be public"),
includedAnnotations,
excludedAnnotations);
}
public static Set<Method> findMethods(
Method[] allMethods,
Consumer<Method> methodChecker,
Set<Class<? extends Annotation>> includedAnnotations,
Set<Class<? extends Annotation>> excludedAnnotations)
{
ImmutableSet.Builder<Method> methods = ImmutableSet.builder();
for (Method method : allMethods) {
boolean included = false;
boolean excluded = false;
for (Annotation annotation : method.getAnnotations()) {
for (Class<?> annotationClass : excludedAnnotations) {
if (annotationClass.isInstance(annotation)) {
excluded = true;
break;
}
}
if (excluded) {
break;
}
if (included) {
continue;
}
for (Class<?> annotationClass : includedAnnotations) {
if (annotationClass.isInstance(annotation)) {
included = true;
break;
}
}
}
if (included && !excluded) {
methodChecker.accept(method);
methods.add(method);
}
}
return methods.build();
}
public static Optional<Constructor<?>> findConstructor(Class<?> clazz)
{
Constructor<?>[] constructors = clazz.getConstructors();
checkArgument(constructors.length <= 1, "Class [%s] must have no more than 1 public constructor");
if (constructors.length == 0) {
return Optional.empty();
}
return Optional.of(constructors[0]);
}
public static Set<String> parseLiteralParameters(Method method)
{
LiteralParameters literalParametersAnnotation = method.getAnnotation(LiteralParameters.class);
if (literalParametersAnnotation == null) {
return ImmutableSet.of();
}
return ImmutableSet.copyOf(literalParametersAnnotation.value());
}
public static boolean containsLegacyNullable(Annotation[] annotations)
{
return Arrays.stream(annotations)
.map(Annotation::annotationType)
.map(Class::getName)
.anyMatch(name -> name.equals(Nullable.class.getName()));
}
public static boolean isPrestoAnnotation(Annotation annotation)
{
return isImplementationDependencyAnnotation(annotation) ||
annotation instanceof SqlType ||
annotation instanceof SqlNullable ||
annotation instanceof IsNull;
}
public static Optional<String> parseDescription(AnnotatedElement base, AnnotatedElement override)
{
Optional<String> overrideDescription = parseDescription(override);
if (overrideDescription.isPresent()) {
return overrideDescription;
}
return parseDescription(base);
}
public static Optional<String> parseDescription(AnnotatedElement base)
{
Description description = base.getAnnotation(Description.class);
return (description == null) ? Optional.empty() : Optional.of(description.value());
}
public static ComplexTypeFunctionDescriptor parseFunctionDescriptor(AnnotatedElement base)
{
FunctionDescriptor descriptor = base.getAnnotation(FunctionDescriptor.class);
if (descriptor == null) {
return ComplexTypeFunctionDescriptor.DEFAULT;
}
int pushdownSubfieldArgIndex = descriptor.pushdownSubfieldArgIndex();
Optional<Integer> descriptorPushdownIndex;
if (pushdownSubfieldArgIndex < 0) {
descriptorPushdownIndex = Optional.empty();
}
else {
descriptorPushdownIndex = Optional.of(pushdownSubfieldArgIndex);
}
return new ComplexTypeFunctionDescriptor(
true,
emptyList(),
Optional.of(emptySet()),
Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired),
descriptorPushdownIndex);
}
public static List<LongVariableConstraint> parseLongVariableConstraints(Method inputFunction)
{
return Stream.of(inputFunction.getAnnotationsByType(Constraint.class))
.map(annotation -> new LongVariableConstraint(annotation.variable(), annotation.expression()))
.collect(toImmutableList());
}
public static Map<String, Class<?>> getDeclaredSpecializedTypeParameters(Method method, Set<TypeParameter> typeParameters)
{
Map<String, Class<?>> specializedTypeParameters = new HashMap<>();
TypeParameterSpecialization[] typeParameterSpecializations = method.getAnnotationsByType(TypeParameterSpecialization.class);
ImmutableSet<String> typeParameterNames = typeParameters.stream()
.map(TypeParameter::value)
.collect(toImmutableSet());
for (TypeParameterSpecialization specialization : typeParameterSpecializations) {
checkArgument(typeParameterNames.contains(specialization.name()), "%s does not match any declared type parameters (%s) [%s]", specialization.name(), typeParameters, method);
Class<?> existingSpecialization = specializedTypeParameters.get(specialization.name());
checkArgument(existingSpecialization == null || existingSpecialization.equals(specialization.nativeContainerType()),
"%s has conflicting specializations %s and %s [%s]", specialization.name(), existingSpecialization, specialization.nativeContainerType(), method);
specializedTypeParameters.put(specialization.name(), specialization.nativeContainerType());
}
return specializedTypeParameters;
}
public static void checkPushdownSubfieldArgIndex(Method method, Signature signature, Optional<Integer> pushdownSubfieldArgIndex)
{
if (pushdownSubfieldArgIndex.isPresent()) {
Map<String, TypeVariableConstraint> typeConstraintMapping = new HashMap<>();
for (TypeVariableConstraint constraint : signature.getTypeVariableConstraints()) {
typeConstraintMapping.put(constraint.getName(), constraint);
}
checkCondition(signature.getArgumentTypes().size() > pushdownSubfieldArgIndex.get(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has out of range pushdown subfield arg index", method);
String typeVariableName = signature.getArgumentTypes().get(pushdownSubfieldArgIndex.get()).toString();
// The type variable must be directly a ROW type
// or (it is a type alias that is not bounded by a type)
// or (it is a type alias that maps to a row type)
boolean meetsTypeConstraint = (!typeConstraintMapping.containsKey(typeVariableName) && typeVariableName.equals(com.facebook.presto.common.type.StandardTypes.ROW)) ||
(typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound() == null && !typeConstraintMapping.get(typeVariableName).isNonDecimalNumericRequired()) ||
(typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound().equals(com.facebook.presto.common.type.StandardTypes.ROW));
checkCondition(meetsTypeConstraint, FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] does not have a struct or row type as pushdown subfield arg", method);
}
}
}