ArgValueOfAllowedTypeChecker.java
package graphql.schema.idl;
import graphql.AssertException;
import graphql.GraphQLContext;
import graphql.GraphQLError;
import graphql.Internal;
import graphql.execution.CoercedVariables;
import graphql.language.Argument;
import graphql.language.ArrayValue;
import graphql.language.Directive;
import graphql.language.EnumTypeDefinition;
import graphql.language.EnumTypeExtensionDefinition;
import graphql.language.EnumValue;
import graphql.language.EnumValueDefinition;
import graphql.language.InputObjectTypeDefinition;
import graphql.language.InputObjectTypeExtensionDefinition;
import graphql.language.InputValueDefinition;
import graphql.language.ListType;
import graphql.language.Node;
import graphql.language.NonNullType;
import graphql.language.NullValue;
import graphql.language.ObjectField;
import graphql.language.ObjectValue;
import graphql.language.ScalarTypeDefinition;
import graphql.language.ScalarTypeExtensionDefinition;
import graphql.language.Type;
import graphql.language.TypeDefinition;
import graphql.language.TypeName;
import graphql.language.Value;
import graphql.schema.CoercingParseLiteralException;
import graphql.schema.GraphQLScalarType;
import graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import static graphql.Assert.assertShouldNeverHappen;
import static graphql.collect.ImmutableKit.emptyList;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.DUPLICATED_KEYS_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.EXPECTED_ENUM_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.EXPECTED_NON_NULL_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.EXPECTED_OBJECT_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.MISSING_REQUIRED_FIELD_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.MUST_BE_VALID_ENUM_VALUE_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.NOT_A_VALID_SCALAR_LITERAL_MESSAGE;
import static graphql.schema.idl.errors.DirectiveIllegalArgumentTypeError.UNKNOWN_FIELDS_MESSAGE;
import static java.lang.String.format;
import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
/**
 * Class to check whether a given directive argument value
 * matches a given directive definition.
 * <p>
 * Every problem found is appended to the caller-supplied error list as a
 * {@link DirectiveIllegalArgumentTypeError} naming the element, directive and argument
 * this checker was created for.
 */
@Internal
class ArgValueOfAllowedTypeChecker {

    private final Directive directive;
    private final Node<?> element;
    private final String elementName;
    private final Argument argument;
    private final TypeDefinitionRegistry typeRegistry;
    private final RuntimeWiring runtimeWiring;

    /**
     * @param directive     the directive usage whose argument value is checked
     * @param element       the schema element the directive is applied to (reported in errors)
     * @param elementName   the name of that element (reported in errors)
     * @param argument      the directive argument whose value is checked (reported in errors)
     * @param typeRegistry  registry used to resolve the allowed type definitions and their extensions
     * @param runtimeWiring wiring used to look up scalar implementations
     */
    ArgValueOfAllowedTypeChecker(final Directive directive,
                                 final Node<?> element,
                                 final String elementName,
                                 final Argument argument,
                                 final TypeDefinitionRegistry typeRegistry,
                                 final RuntimeWiring runtimeWiring) {
        this.directive = directive;
        this.element = element;
        this.elementName = elementName;
        this.argument = argument;
        this.typeRegistry = typeRegistry;
        this.runtimeWiring = runtimeWiring;
    }

    /**
     * Recursively inspects an argument value given an allowed type.
     * <p>
     * Non-null wrappers are checked first, then list wrappers are descended into item by item,
     * and finally named types (scalar, enum, input object) are checked. Given the SDL below:
     * <pre>
     * directive @myDirective(arg: [[String]]) on FIELD_DEFINITION
     *
     * type Query {
     *     f: String @myDirective(arg: ["A String"])
     * }
     * </pre>
     * the outer value is a list, as {@code [[String]]} requires, and its single item
     * {@code "A String"} is then checked against the inner {@code [String]} type. Because that item
     * is neither a list nor null, it is coerced to a single-item list (the spec's list input
     * coercion rule). The string is then checked against the {@code String} scalar, so no list-shape
     * error is reported for this usage.
     * <p>
     * NOTE(review): the GraphQL specification's list coercion examples appear to treat a flat list
     * supplied for a nested list type (e.g. {@code [1, 2, 3]} for {@code [[Int]]}) as an error,
     * while this checker accepts it. Confirm this leniency is intended.
     *
     * @param errors         validation error collector
     * @param instanceValue  directive argument value
     * @param allowedArgType directive definition argument allowed type
     */
    void checkArgValueMatchesAllowedType(List<GraphQLError> errors, Value<?> instanceValue, Type<?> allowedArgType) {
        if (allowedArgType instanceof TypeName) {
            checkArgValueMatchesAllowedTypeName(errors, instanceValue, allowedArgType);
        } else if (allowedArgType instanceof ListType) {
            checkArgValueMatchesAllowedListType(errors, instanceValue, (ListType) allowedArgType);
        } else if (allowedArgType instanceof NonNullType) {
            checkArgValueMatchesAllowedNonNullType(errors, instanceValue, (NonNullType) allowedArgType);
        } else {
            assertShouldNeverHappen("Unsupported Type '%s' was added. ", allowedArgType);
        }
    }

    /**
     * Formats {@code message} with {@code args} and records it as an error for this checker's
     * element, directive and argument.
     */
    private void addValidationError(List<GraphQLError> errors, String message, Object... args) {
        errors.add(new DirectiveIllegalArgumentTypeError(element, elementName, directive.getName(), argument.getName(), format(message, args)));
    }

    /**
     * Checks a value against a named type by dispatching on the kind of its type definition.
     * A {@link NullValue} is accepted here; nullability is enforced by the non-null check.
     *
     * @throws AssertException if the type name is not known to the registry
     */
    private void checkArgValueMatchesAllowedTypeName(List<GraphQLError> errors, Value<?> instanceValue, Type<?> allowedArgType) {
        if (instanceValue instanceof NullValue) {
            return;
        }

        String allowedTypeName = ((TypeName) allowedArgType).getName();
        TypeDefinition<?> allowedTypeDefinition = typeRegistry.getTypeOrNull(allowedTypeName);
        if (allowedTypeDefinition == null) {
            throw new AssertException(format("Directive unknown argument type '%s'. This should have been validated before.", allowedTypeName));
        }

        if (allowedTypeDefinition instanceof ScalarTypeDefinition) {
            checkArgValueMatchesAllowedScalar(errors, instanceValue, (ScalarTypeDefinition) allowedTypeDefinition);
        } else if (allowedTypeDefinition instanceof EnumTypeDefinition) {
            checkArgValueMatchesAllowedEnum(errors, instanceValue, (EnumTypeDefinition) allowedTypeDefinition);
        } else if (allowedTypeDefinition instanceof InputObjectTypeDefinition) {
            checkArgValueMatchesAllowedInputType(errors, instanceValue, (InputObjectTypeDefinition) allowedTypeDefinition);
        } else {
            assertShouldNeverHappen("'%s' must be an input type. It is %s instead. ", allowedTypeName, allowedTypeDefinition.getClass());
        }
    }

    /**
     * Checks an object literal against an input object type, including the fields contributed
     * by its type extensions. Reports (and stops at) the first category of problem found, in
     * this order: not an object, duplicated keys, unknown keys. Otherwise each defined field is
     * checked against its definition.
     */
    private void checkArgValueMatchesAllowedInputType(List<GraphQLError> errors, Value<?> instanceValue, InputObjectTypeDefinition allowedTypeDefinition) {
        if (!(instanceValue instanceof ObjectValue)) {
            addValidationError(errors, EXPECTED_OBJECT_MESSAGE, instanceValue.getClass().getSimpleName());
            return;
        }

        ObjectValue objectValue = ((ObjectValue) instanceValue);
        // duck typing validation, if it looks like the definition
        // then it must be the same type as the definition

        List<ObjectField> fields = objectValue.getObjectFields();
        List<InputObjectTypeExtensionDefinition> inputObjExt = typeRegistry.inputObjectTypeExtensions().getOrDefault(allowedTypeDefinition.getName(), emptyList());
        Stream<InputValueDefinition> inputObjExtValues = inputObjExt.stream().flatMap(inputObj -> inputObj.getInputValueDefinitions().stream());
        List<InputValueDefinition> inputValueDefinitions = Stream.concat(allowedTypeDefinition.getInputValueDefinitions().stream(), inputObjExtValues).collect(toList());

        // check for duplicated fields
        Map<String, Long> fieldsToOccurrenceMap = fields.stream().map(ObjectField::getName)
                .collect(groupingBy(Function.identity(), counting()));

        if (fieldsToOccurrenceMap.values().stream().anyMatch(count -> count > 1)) {
            addValidationError(errors, DUPLICATED_KEYS_MESSAGE, fieldsToOccurrenceMap.entrySet().stream()
                    .filter(entry -> entry.getValue() > 1)
                    .map(Map.Entry::getKey)
                    .collect(joining(",")));
            return;
        }

        // check for unknown fields
        Map<String, InputValueDefinition> nameToInputValueDefMap = inputValueDefinitions.stream()
                .collect(toMap(InputValueDefinition::getName, inputValueDef -> inputValueDef,
                        // An extension may redefine a field of the base type. That is not this
                        // checker's concern (presumably reported by the type-extension validation
                        // - TODO confirm). Keep the first definition instead of letting toMap
                        // throw an IllegalStateException.
                        (first, redefinition) -> first));

        List<ObjectField> unknownFields = fields.stream()
                .filter(field -> !nameToInputValueDefMap.containsKey(field.getName()))
                .collect(toList());

        if (!unknownFields.isEmpty()) {
            addValidationError(errors, UNKNOWN_FIELDS_MESSAGE,
                    unknownFields.stream()
                            .map(ObjectField::getName)
                            .collect(joining(",")),
                    allowedTypeDefinition.getName());
            return;
        }

        // fields to map for easy access; keys are unique because duplicates were rejected above
        Map<String, ObjectField> nameToFieldsMap = fields.stream()
                .collect(toMap(ObjectField::getName, objectField -> objectField));
        // check each single field with its definition
        inputValueDefinitions.forEach(allowedValueDef -> {
            ObjectField objectField = nameToFieldsMap.get(allowedValueDef.getName());
            checkArgInputObjectValueFieldMatchesAllowedDefinition(errors, objectField, allowedValueDef);
        });
    }

    /**
     * Checks that the value is an enum literal naming one of the values of the enum type or of
     * its type extensions.
     */
    private void checkArgValueMatchesAllowedEnum(List<GraphQLError> errors, Value<?> instanceValue, EnumTypeDefinition allowedTypeDefinition) {
        if (!(instanceValue instanceof EnumValue)) {
            addValidationError(errors, EXPECTED_ENUM_MESSAGE, instanceValue.getClass().getSimpleName());
            return;
        }

        EnumValue enumValue = ((EnumValue) instanceValue);

        List<EnumTypeExtensionDefinition> enumExtensions = typeRegistry.enumTypeExtensions().getOrDefault(allowedTypeDefinition.getName(), emptyList());
        Stream<EnumValueDefinition> enumExtStream = enumExtensions.stream().flatMap(enumExt -> enumExt.getEnumValueDefinitions().stream());
        List<EnumValueDefinition> enumValueDefinitions = Stream.concat(allowedTypeDefinition.getEnumValueDefinitions().stream(), enumExtStream).collect(toList());

        boolean noneMatchAllowedEnumValue = enumValueDefinitions.stream()
                .noneMatch(enumAllowedValue -> enumAllowedValue.getName().equals(enumValue.getName()));

        if (noneMatchAllowedEnumValue) {
            addValidationError(errors, MUST_BE_VALID_ENUM_VALUE_MESSAGE, enumValue.getName(), enumValueDefinitions.stream()
                    .map(EnumValueDefinition::getName)
                    .collect(joining(",")));
        }
    }

    /**
     * Checks a value against a scalar by asking the scalar's implementation (from the wiring
     * factory if it provides one, otherwise from the runtime wiring's scalars) to parse the
     * literal.
     */
    private void checkArgValueMatchesAllowedScalar(List<GraphQLError> errors, Value<?> instanceValue, ScalarTypeDefinition allowedTypeDefinition) {
        // scalars are allowed to accept ANY literal value - its up to their coercion to decide if its valid or not
        List<ScalarTypeExtensionDefinition> extensions = typeRegistry.scalarTypeExtensions().getOrDefault(allowedTypeDefinition.getName(), emptyList());
        ScalarWiringEnvironment environment = new ScalarWiringEnvironment(typeRegistry, allowedTypeDefinition, extensions);
        WiringFactory wiringFactory = runtimeWiring.getWiringFactory();

        GraphQLScalarType scalarType;
        if (wiringFactory.providesScalar(environment)) {
            scalarType = wiringFactory.getScalar(environment);
        } else {
            scalarType = runtimeWiring.getScalars().get(allowedTypeDefinition.getName());
        }
        // Scalar implementation validation is expected to have run earlier. NOTE(review): that
        // validation may only record an error and let checking continue, so a missing
        // implementation is skipped here rather than failing with a NullPointerException.
        if (scalarType == null) {
            return;
        }

        if (!isArgumentValueScalarLiteral(scalarType, instanceValue)) {
            addValidationError(errors, NOT_A_VALID_SCALAR_LITERAL_MESSAGE, allowedTypeDefinition.getName());
        }
    }

    /**
     * Checks one input-object field definition against the matching field of the supplied
     * object literal. {@code objectField} is null when the literal omits the field.
     */
    private void checkArgInputObjectValueFieldMatchesAllowedDefinition(List<GraphQLError> errors, ObjectField objectField, InputValueDefinition allowedValueDef) {

        if (objectField != null) {
            checkArgValueMatchesAllowedType(errors, objectField.getValue(), allowedValueDef.getType());
            return;
        }

        // check if field definition is required and has no default value
        if (allowedValueDef.getType() instanceof NonNullType && allowedValueDef.getDefaultValue() == null) {
            addValidationError(errors, MISSING_REQUIRED_FIELD_MESSAGE, allowedValueDef.getName());
        }

        // other cases are
        // - field definition is marked as non-null but has a default value, so the default value can be used
        // - field definition is nullable hence null can be used
    }

    /**
     * Rejects an explicit {@link NullValue} for a non-null type; otherwise checks the value
     * against the wrapped type.
     */
    private void checkArgValueMatchesAllowedNonNullType(List<GraphQLError> errors, Value<?> instanceValue, NonNullType allowedArgType) {
        if (instanceValue instanceof NullValue) {
            addValidationError(errors, EXPECTED_NON_NULL_MESSAGE);
            return;
        }

        Type<?> unwrappedAllowedType = allowedArgType.getType();
        checkArgValueMatchesAllowedType(errors, instanceValue, unwrappedAllowedType);
    }

    /**
     * Checks a value against a list type. A null value is accepted and a non-list value is
     * wrapped into a single-item list; each item is then checked against the list's item type.
     */
    private void checkArgValueMatchesAllowedListType(List<GraphQLError> errors, Value<?> instanceValue, ListType allowedArgType) {
        // From the spec, on input coercion:
        // If the value passed as an input to a list type is not a list and not the null value,
        // then the result of input coercion is a list of size one where the single item value
        // is the result of input coercion for the list's item type on the provided value
        // (note this may apply recursively for nested lists).

        Value<?> coercedInstanceValue = instanceValue;
        if (!(instanceValue instanceof ArrayValue) && !(instanceValue instanceof NullValue)) {
            coercedInstanceValue = new ArrayValue(Collections.singletonList(instanceValue));
        }

        if (coercedInstanceValue instanceof NullValue) {
            return;
        }

        Type<?> unwrappedAllowedType = allowedArgType.getType();
        ArrayValue arrayValue = ((ArrayValue) coercedInstanceValue);
        arrayValue.getValues().forEach(value -> {
            checkArgValueMatchesAllowedType(errors, value, unwrappedAllowedType);
        });
    }

    /**
     * Returns true when the scalar's coercing parses the literal without throwing a
     * {@link CoercingParseLiteralException}; other exception types propagate to the caller.
     */
    private boolean isArgumentValueScalarLiteral(GraphQLScalarType scalarType, Value<?> instanceValue) {
        try {
            scalarType.getCoercing().parseLiteral(instanceValue, CoercedVariables.emptyVariables(), GraphQLContext.getDefault(), Locale.getDefault());
            return true;
        } catch (CoercingParseLiteralException ex) {
            return false;
        }
    }
}