AstSorter.java

package graphql.language;

import graphql.PublicApi;
import graphql.schema.idl.TypeInfo;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

import static graphql.util.TreeTransformerUtil.changeNode;
import static java.util.Comparator.naturalOrder;
import static java.util.Comparator.nullsLast;

/**
 * Sorts AST nodes into a canonical order.
 *
 * <p>See {@link #sort(Node)} for the ordering that is applied.
 */
@PublicApi
public class AstSorter {

    /**
     * This will sort nodes in specific orders and then alphabetically.
     *
     * <p>The order of the definitions in a document is:
     * <ul>
     * <li>Query operation definitions</li>
     * <li>Mutation operation definitions</li>
     * <li>Subscription operation definitions</li>
     * <li>Fragment definitions</li>
     * <li>Directive definitions</li>
     * <li>Schema definitions</li>
     * <li>Object Type definitions</li>
     * <li>Interface Type definitions</li>
     * <li>Union Type definitions</li>
     * <li>Enum Type definitions</li>
     * <li>Scalar Type definitions</li>
     * <li>Input Object Type definitions</li>
     * </ul>
     *
     * <p>Within those groupings the definitions are sorted alphabetically by name.  All arguments and
     * directives on the visited elements are sorted alphabetically by name.  Selections inside a selection
     * set are ordered as fields, then fragment spreads, then inline fragments, each group alphabetically.
     *
     * @param nodeToBeSorted the node to be sorted
     * @param <T>            of type {@link graphql.language.Node}
     *
     * @return a new sorted node (because {@link graphql.language.Node}s are immutable)
     */
    public <T extends Node> T sort(T nodeToBeSorted) {

        // Every visit method follows the same recipe: build a copy of the node whose child
        // lists are sorted, then tell the traversal to continue with that replacement node.
        NodeVisitorStub visitor = new NodeVisitorStub() {

            @Override
            public TraversalControl visitDocument(Document node, TraverserContext<Node> context) {
                Document changedNode = node.transform(builder -> {
                    List<Definition> definitions = sort(node.getDefinitions(), comparingDefinitions());
                    builder.definitions(definitions);
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitOperationDefinition(OperationDefinition node, TraverserContext<Node> context) {
                OperationDefinition changedNode = node.transform(builder -> {
                    builder.variableDefinitions(sort(node.getVariableDefinitions(), comparing(VariableDefinition::getName)));
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.selectionSet(sortSelectionSet(node.getSelectionSet()));
                });
                return changeNode(context, changedNode);
            }


            @Override
            public TraversalControl visitField(Field node, TraverserContext<Node> context) {
                Field changedNode = node.transform(builder -> {
                    builder.arguments(sort(node.getArguments(), comparing(Argument::getName)));
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.selectionSet(sortSelectionSet(node.getSelectionSet()));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext<Node> context) {
                FragmentDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.selectionSet(sortSelectionSet(node.getSelectionSet()));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitInlineFragment(InlineFragment node, TraverserContext<Node> context) {
                InlineFragment changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.selectionSet(sortSelectionSet(node.getSelectionSet()));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitFragmentSpread(FragmentSpread node, TraverserContext<Node> context) {
                FragmentSpread changedNode = node.transform(builder -> {
                    List<Directive> directives = sort(node.getDirectives(), comparing(Directive::getName));
                    builder.directives(directives);
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitDirective(Directive node, TraverserContext<Node> context) {
                Directive changedNode = node.transform(builder -> {
                    List<Argument> arguments = sort(node.getArguments(), comparing(Argument::getName));
                    builder.arguments(arguments);
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitObjectValue(ObjectValue node, TraverserContext<Node> context) {
                ObjectValue changedNode = node.transform(builder -> {
                    List<ObjectField> objectFields = sort(node.getObjectFields(), comparing(ObjectField::getName));
                    builder.objectFields(objectFields);
                });
                return changeNode(context, changedNode);
            }

            // SDL classes here

            @Override
            public TraversalControl visitSchemaDefinition(SchemaDefinition node, TraverserContext<Node> context) {
                SchemaDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.operationTypeDefinitions(sort(node.getOperationTypeDefinitions(), comparing(OperationTypeDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitEnumTypeDefinition(EnumTypeDefinition node, TraverserContext<Node> context) {
                EnumTypeDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.enumValueDefinitions(sort(node.getEnumValueDefinitions(), comparing(EnumValueDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitScalarTypeDefinition(ScalarTypeDefinition node, TraverserContext<Node> context) {
                ScalarTypeDefinition changedNode = node.transform(builder -> {
                    List<Directive> directives = sort(node.getDirectives(), comparing(Directive::getName));
                    builder.directives(directives);
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitInputObjectTypeDefinition(InputObjectTypeDefinition node, TraverserContext<Node> context) {
                InputObjectTypeDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.inputValueDefinitions(sort(node.getInputValueDefinitions(), comparing(InputValueDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitObjectTypeDefinition(ObjectTypeDefinition node, TraverserContext<Node> context) {
                ObjectTypeDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.implementz(sort(node.getImplements(), comparingTypes()));
                    builder.fieldDefinitions(sort(node.getFieldDefinitions(), comparing(FieldDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitInterfaceTypeDefinition(InterfaceTypeDefinition node, TraverserContext<Node> context) {
                InterfaceTypeDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.implementz(sort(node.getImplements(), comparingTypes()));
                    builder.definitions(sort(node.getFieldDefinitions(), comparing(FieldDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitUnionTypeDefinition(UnionTypeDefinition node, TraverserContext<Node> context) {
                UnionTypeDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.memberTypes(sort(node.getMemberTypes(), comparingTypes()));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitFieldDefinition(FieldDefinition node, TraverserContext<Node> context) {
                FieldDefinition changedNode = node.transform(builder -> {
                    builder.directives(sort(node.getDirectives(), comparing(Directive::getName)));
                    builder.inputValueDefinitions(sort(node.getInputValueDefinitions(), comparing(InputValueDefinition::getName)));
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitInputValueDefinition(InputValueDefinition node, TraverserContext<Node> context) {
                InputValueDefinition changedNode = node.transform(builder -> {
                    List<Directive> directives = sort(node.getDirectives(), comparing(Directive::getName));
                    builder.directives(directives);
                });
                return changeNode(context, changedNode);
            }

            @Override
            public TraversalControl visitDirectiveDefinition(DirectiveDefinition node, TraverserContext<Node> context) {
                DirectiveDefinition changedNode = node.transform(builder -> {
                    builder.inputValueDefinitions(sort(node.getInputValueDefinitions(), comparing(InputValueDefinition::getName)));
                    builder.directiveLocations(sort(node.getDirectiveLocations(), comparing(DirectiveLocation::getName)));
                });
                return changeNode(context, changedNode);
            }
        };

        AstTransformer astTransformer = new AstTransformer();
        // NOTE(review): the cast relies on the transformer handing back a node of the same
        // type it was given; the compiler cannot verify that, hence the narrow suppression.
        @SuppressWarnings("unchecked")
        T sortedNode = (T) astTransformer.transform(nodeToBeSorted, visitor);
        return sortedNode;
    }


    /**
     * Orders types by the name reported by {@code TypeInfo.typeInfo(type).getName()}.
     * Presumably that is the underlying named type with any list / non-null wrappers
     * removed - confirm against {@code TypeInfo}.
     *
     * @return a null-safe comparator of types by name
     */
    private Comparator<Type> comparingTypes() {
        return comparing(type -> TypeInfo.typeInfo(type).getName());
    }

    /**
     * Orders selections first by kind (fields, then fragment spreads, then inline fragments,
     * then anything else) and within a kind by name.  An inline fragment is named after its
     * type condition, or the empty string when it has none.
     *
     * @return the comparator used to sort the contents of a selection set
     */
    private Comparator<Selection> comparingSelections() {
        Function<Selection, String> byName = s -> {
            if (s instanceof FragmentSpread) {
                return ((FragmentSpread) s).getName();
            }
            if (s instanceof Field) {
                return ((Field) s).getName();
            }
            if (s instanceof InlineFragment) {
                TypeName typeCondition = ((InlineFragment) s).getTypeCondition();
                return typeCondition == null ? "" : typeCondition.getName();
            }
            return "";
        };
        Function<Selection, Integer> byType = s -> {
            if (s instanceof Field) {
                return 1;
            }
            if (s instanceof FragmentSpread) {
                return 2;
            }
            if (s instanceof InlineFragment) {
                return 3;
            }
            return 4;
        };
        return comparing(byType).thenComparing(comparing(byName));
    }

    /**
     * Orders the definitions of a document first by a numeric rank per kind of definition
     * (the grouping described on {@link #sort(Node)}) and within a rank by name.
     *
     * <p>Only the relative size of the rank values matters, not the values themselves.
     *
     * @return the comparator used to sort the definitions of a document
     */
    private Comparator<Definition> comparingDefinitions() {
        Function<Definition, String> byName = d -> {
            if (d instanceof OperationDefinition) {
                // anonymous operations are treated as having an empty name
                String name = ((OperationDefinition) d).getName();
                return name == null ? "" : name;
            }
            if (d instanceof FragmentDefinition) {
                return ((FragmentDefinition) d).getName();
            }
            if (d instanceof DirectiveDefinition) {
                return ((DirectiveDefinition) d).getName();
            }
            if (d instanceof TypeDefinition) {
                return ((TypeDefinition) d).getName();
            }
            return "";
        };
        Function<Definition, Integer> byType = d -> {
            if (d instanceof OperationDefinition) {
                OperationDefinition.Operation operation = ((OperationDefinition) d).getOperation();
                // an operation without an explicit operation type is ranked with the queries
                if (OperationDefinition.Operation.QUERY == operation || operation == null) {
                    return 101;
                }
                if (OperationDefinition.Operation.MUTATION == operation) {
                    return 102;
                }
                if (OperationDefinition.Operation.SUBSCRIPTION == operation) {
                    // 103 is unused; the gap does not affect the ordering
                    return 104;
                }
                // any other operation kind sorts ahead of queries
                return 100;
            }
            if (d instanceof FragmentDefinition) {
                return 200;
            }
            // SDL
            if (d instanceof DirectiveDefinition) {
                return 300;
            }
            if (d instanceof SchemaDefinition) {
                return 400;
            }
            if (d instanceof TypeDefinition) {
                if (d instanceof ObjectTypeDefinition) {
                    return 501;
                }
                if (d instanceof InterfaceTypeDefinition) {
                    return 502;
                }
                if (d instanceof UnionTypeDefinition) {
                    return 503;
                }
                if (d instanceof EnumTypeDefinition) {
                    return 504;
                }
                if (d instanceof ScalarTypeDefinition) {
                    return 505;
                }
                if (d instanceof InputObjectTypeDefinition) {
                    return 506;
                }
                // any other type definition sorts ahead of the known type kinds
                return 500;
            }
            // anything unrecognised sorts ahead of every known group
            return -1;
        };
        // Wrap byName in the null-safe helper, exactly as comparingSelections() does: a bare
        // thenComparing(byName) uses plain natural ordering and would throw a
        // NullPointerException for a definition whose name is null.
        return comparing(byType).thenComparing(comparing(byName));
    }

    /**
     * Produces a copy of the selection set whose selections are ordered by
     * {@link #comparingSelections()}.
     *
     * @param selectionSet the selection set to sort, which may be null
     *
     * @return the sorted copy, or null if there was no selection set
     */
    private SelectionSet sortSelectionSet(SelectionSet selectionSet) {
        if (selectionSet == null) {
            return null;
        }
        List<Selection> selections = sort(selectionSet.getSelections(), comparingSelections());
        return selectionSet.transform(builder -> builder.selections(selections));
    }

    /**
     * Returns a sorted copy of the list; the list passed in is never modified.  The sort is
     * stable, so items that compare as equal keep their original relative order.
     *
     * @param items     the items to sort
     * @param comparing the ordering to apply
     * @param <T>       the type of item
     *
     * @return a new list holding the items in sorted order
     */
    private <T> List<T> sort(List<T> items, Comparator<T> comparing) {
        List<T> sorted = new ArrayList<>(items);
        sorted.sort(comparing);
        return sorted;
    }

    /**
     * Builds a comparator on an extracted key that tolerates null keys: items whose key is
     * null are placed after all items with a non-null key.
     *
     * @param keyExtractor extracts the value to compare by
     * @param <T>          the type of item being compared
     * @param <U>          the type of the comparable key
     *
     * @return a null-safe comparator on the extracted key
     */
    private <T, U extends Comparable<? super U>> Comparator<T> comparing(
            Function<? super T, ? extends U> keyExtractor) {
        return Comparator.comparing(keyExtractor, nullsLast(naturalOrder()));
    }

}