FieldVisibilitySchemaTransformation.java
package graphql.schema.transform;
import com.google.common.collect.ImmutableList;
import graphql.PublicApi;
import graphql.introspection.Introspection;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLFieldsContainer;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputObjectType;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.GraphQLUnionType;
import graphql.schema.SchemaTraverser;
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static graphql.schema.SchemaTransformer.transformSchemaWithDeletes;
/**
* Transforms a schema by applying a visibility predicate to every field.
*/
@PublicApi
public class FieldVisibilitySchemaTransformation {
private final VisibleFieldPredicate visibleFieldPredicate;
private final Runnable beforeTransformationHook;
private final Runnable afterTransformationHook;
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate) {
this(visibleFieldPredicate, () -> {
}, () -> {
});
}
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate,
Runnable beforeTransformationHook,
Runnable afterTransformationHook) {
this.visibleFieldPredicate = visibleFieldPredicate;
this.beforeTransformationHook = beforeTransformationHook;
this.afterTransformationHook = afterTransformationHook;
}
public final GraphQLSchema apply(GraphQLSchema schema) {
beforeTransformationHook.run();
// Find root unused types BEFORE transformation
// These are types that exist in the schema but are NOT reachable from operation types + directives
Set<String> rootUnusedTypes = findRootUnusedTypes(schema);
// we delete all fields that should be deleted
// this assumes the field remove itself is semantically valid
GraphQLSchema interimSchema = transformSchemaWithDeletes(schema,
new FieldRemovalVisitor(visibleFieldPredicate));
// cleanup schema
// now we want to remove all types which are not reachable via root types, directives and the interface implements relationship
SchemaTraverser schemaTraverser = new SchemaTraverser(childrenWithInterfaceImplementations(interimSchema));
// first we observe all types we don't want to delete
Set<String> observedTypes = new LinkedHashSet<>();
TypeObservingVisitor typeObservingVisitor = new TypeObservingVisitor(observedTypes);
schemaTraverser.depthFirst(typeObservingVisitor, getRootTypes(interimSchema));
// Traverse from root unused types that still exist after transformation
// This preserves originally unused types and their dependencies
List<GraphQLSchemaElement> existingRootUnusedTypes = rootUnusedTypes.stream()
.map(interimSchema::getType)
.filter(Objects::nonNull)
.map(type -> (GraphQLSchemaElement) type)
.collect(Collectors.toList());
if (!existingRootUnusedTypes.isEmpty()) {
schemaTraverser.depthFirst(typeObservingVisitor, existingRootUnusedTypes);
}
// then we delete all the types which are not used anymore
GraphQLSchema finalSchema = transformSchemaWithDeletes(interimSchema,
new TypeRemovalVisitor(observedTypes));
afterTransformationHook.run();
return finalSchema;
}
/**
* Finds root unused types - types that exist in additional types but are NOT reachable
* from operation types (Query, Mutation, Subscription) and directives.
*/
private Set<String> findRootUnusedTypes(GraphQLSchema schema) {
// Collect all types reachable from operation roots + directives
// Use a traverser that includes interface implementations
Set<String> typesReachableFromRoots = new LinkedHashSet<>();
SchemaTraverser traverser = new SchemaTraverser(childrenWithInterfaceImplementations(schema));
TypeObservingVisitor visitor = new TypeObservingVisitor(typesReachableFromRoots);
traverser.depthFirst(visitor, getRootTypes(schema));
// Root unused types are additional types that are NOT reachable from roots
Set<String> rootUnusedTypes = new LinkedHashSet<>();
for (GraphQLNamedType type : schema.getAdditionalTypes()) {
String typeName = type.getName();
if (!typesReachableFromRoots.contains(typeName) && !isIntrospectionType(typeName)) {
rootUnusedTypes.add(typeName);
}
}
return rootUnusedTypes;
}
/**
* Checks if a type is an introspection type that should be protected from removal.
* This includes standard introspection types (starting with "__") and special types
* like _AppliedDirective (starting with "_") added by IntrospectionWithDirectivesSupport.
*/
private static boolean isIntrospectionType(String typeName) {
return Introspection.isIntrospectionTypes(typeName) || typeName.startsWith("_");
}
/**
* Creates a function that returns children of a schema element, including interface implementations.
* This ensures that when traversing from an interface, we also visit all types that implement it.
*/
private Function<GraphQLSchemaElement, List<GraphQLSchemaElement>> childrenWithInterfaceImplementations(GraphQLSchema schema) {
return schemaElement -> {
if (!(schemaElement instanceof GraphQLInterfaceType)) {
return schemaElement.getChildren();
}
ArrayList<GraphQLSchemaElement> children = new ArrayList<>(schemaElement.getChildren());
List<GraphQLObjectType> implementations = schema.getImplementations((GraphQLInterfaceType) schemaElement);
children.addAll(implementations);
return children;
};
}
private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
private final Set<String> observedTypes;
private TypeObservingVisitor(Set<String> observedTypes) {
this.observedTypes = observedTypes;
}
@Override
protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext<GraphQLSchemaElement> context) {
if (node instanceof GraphQLObjectType ||
node instanceof GraphQLEnumType ||
node instanceof GraphQLInputObjectType ||
node instanceof GraphQLInterfaceType ||
node instanceof GraphQLUnionType ||
node instanceof GraphQLScalarType) {
observedTypes.add(((GraphQLNamedType) node).getName());
}
return TraversalControl.CONTINUE;
}
}
private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
private final VisibleFieldPredicate visibilityPredicate;
private final Set<GraphQLFieldDefinition> fieldDefinitionsToActuallyRemove = new LinkedHashSet<>();
private final Set<GraphQLInputObjectField> inputObjectFieldsToDelete = new LinkedHashSet<>();
private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate) {
this.visibilityPredicate = visibilityPredicate;
}
@Override
public TraversalControl visitGraphQLObjectType(GraphQLObjectType objectType, TraverserContext<GraphQLSchemaElement> context) {
return visitFieldsContainer(objectType, context);
}
@Override
public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType objectType, TraverserContext<GraphQLSchemaElement> context) {
return visitFieldsContainer(objectType, context);
}
private TraversalControl visitFieldsContainer(GraphQLFieldsContainer fieldsContainer, TraverserContext<GraphQLSchemaElement> context) {
boolean allFieldsDeleted = true;
for (GraphQLFieldDefinition fieldDefinition : fieldsContainer.getFieldDefinitions()) {
VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl(
fieldDefinition, fieldsContainer);
if (!visibilityPredicate.isVisible(environment)) {
fieldDefinitionsToActuallyRemove.add(fieldDefinition);
} else {
allFieldsDeleted = false;
}
}
if (allFieldsDeleted) {
// we are deleting the whole interface type because all fields are supposed to be deleted
return deleteNode(context);
} else {
return TraversalControl.CONTINUE;
}
}
@Override
public TraversalControl visitGraphQLInputObjectType(GraphQLInputObjectType inputObjectType, TraverserContext<GraphQLSchemaElement> context) {
boolean allFieldsDeleted = true;
for (GraphQLInputObjectField inputField : inputObjectType.getFieldDefinitions()) {
VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl(
inputField, inputObjectType);
if (!visibilityPredicate.isVisible(environment)) {
inputObjectFieldsToDelete.add(inputField);
} else {
allFieldsDeleted = false;
}
}
if (allFieldsDeleted) {
// we are deleting the whole input object type because all fields are supposed to be deleted
return deleteNode(context);
} else {
return TraversalControl.CONTINUE;
}
}
@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition definition,
TraverserContext<GraphQLSchemaElement> context) {
if (fieldDefinitionsToActuallyRemove.contains(definition)) {
return deleteNode(context);
} else {
return TraversalControl.CONTINUE;
}
}
@Override
public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField definition,
TraverserContext<GraphQLSchemaElement> context) {
if (inputObjectFieldsToDelete.contains(definition)) {
return deleteNode(context);
} else {
return TraversalControl.CONTINUE;
}
}
}
private static class TypeRemovalVisitor extends GraphQLTypeVisitorStub {
private final Set<String> protectedTypeNames;
private TypeRemovalVisitor(Set<String> protectedTypeNames) {
this.protectedTypeNames = protectedTypeNames;
}
@Override
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext<GraphQLSchemaElement> context) {
if (node instanceof GraphQLNamedType) {
String name = ((GraphQLNamedType) node).getName();
if (isIntrospectionType(name)) {
return TraversalControl.CONTINUE;
}
}
if (node instanceof GraphQLObjectType ||
node instanceof GraphQLEnumType ||
node instanceof GraphQLInputObjectType ||
node instanceof GraphQLInterfaceType ||
node instanceof GraphQLUnionType ||
node instanceof GraphQLScalarType) {
String name = ((GraphQLNamedType) node).getName();
if (!protectedTypeNames.contains(name)) {
return deleteNode(context);
}
}
return TraversalControl.CONTINUE;
}
}
private List<GraphQLSchemaElement> getRootTypes(GraphQLSchema schema) {
return ImmutableList.<GraphQLSchemaElement>builder()
.addAll(getOperationTypes(schema))
.addAll(schema.getDirectives())
.build();
}
private List<GraphQLObjectType> getOperationTypes(GraphQLSchema schema) {
return Stream.of(
schema.getQueryType(),
schema.getSubscriptionType(),
schema.getMutationType()
).filter(Objects::nonNull).collect(Collectors.toList());
}
}