PerLevelDataLoaderDispatchStrategy.java

package graphql.execution.instrumentation.dataloader;

import graphql.Assert;
import graphql.GraphQLContext;
import graphql.Internal;
import graphql.Profiler;
import graphql.execution.DataLoaderDispatchStrategy;
import graphql.execution.ExecutionContext;
import graphql.execution.ExecutionStrategyParameters;
import graphql.execution.FieldValueInfo;
import graphql.execution.incremental.AlternativeCallContext;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

/**
 * A {@link DataLoaderDispatchStrategy} that dispatches data loaders level by level.
 * <p>
 * A level is dispatched once all fields on it are known to have been fetched. This is derived from
 * per-level counters of execute object calls and object completion calls (see {@link CallStack}).
 * <p>
 * If {@link DataLoaderDispatchingContextKeys#ENABLE_DATA_LOADER_CHAINING} is set in the {@link GraphQLContext},
 * only the data loader invocations recorded for a level are dispatched. Dispatching is repeated while new
 * invocations keep arriving for that level. Invocations that arrive after a level finished dispatching
 * trigger a delayed dispatch. Otherwise all data loaders of the registry are dispatched for each level.
 * <p>
 * The initial operation uses its own call stack. Every {@link AlternativeCallContext} (deferred fragment or
 * subscription event) gets a separate one.
 */
@Internal
@NullMarked
public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy {

    private final CallStack initialCallStack;
    private final ExecutionContext executionContext;
    private final boolean enableDataLoaderChaining;


    private final Profiler profiler;

    private final Map<AlternativeCallContext, CallStack> alternativeCallContextMap = new ConcurrentHashMap<>();

    /**
     * Tracks, per level, the data loader invocations that still need to be dispatched when data loader
     * chaining is enabled. It also holds flags describing the dispatch progress of each level.
     * <p>
     * Updates are lock free. Each level holds an {@link AtomicReference} to an immutable {@link StateForLevel},
     * which is replaced via compare-and-set.
     */
    private static class ChainedDLStack {

        private final Map<Integer, AtomicReference<@Nullable StateForLevel>> stateMapPerLevel = new ConcurrentHashMap<>();

        /**
         * Immutable snapshot of one level.
         * <p>
         * Each state points to the previous one via {@link #prev}. All invocations linked together this way
         * are the ones relevant for the next dispatch.
         */
        private static class StateForLevel {
            final @Nullable DataLoaderInvocation dataLoaderInvocation;
            final boolean dispatchingStarted;
            final boolean dispatchingFinished;
            final boolean currentlyDelayedDispatching;
            final @Nullable StateForLevel prev;

            public StateForLevel(@Nullable DataLoaderInvocation dataLoaderInvocation,
                                 boolean dispatchingStarted,
                                 boolean dispatchingFinished,
                                 boolean currentlyDelayedDispatching,
                                 @Nullable StateForLevel prev) {
                this.dataLoaderInvocation = dataLoaderInvocation;
                this.dispatchingStarted = dispatchingStarted;
                this.dispatchingFinished = dispatchingFinished;
                this.currentlyDelayedDispatching = currentlyDelayedDispatching;
                this.prev = prev;
            }
        }


        /**
         * Atomically takes the pending invocations of {@code level} and records that a dispatch is starting.
         * <p>
         * The current state is replaced by one without invocations. Unless {@code chained} is set, this marks
         * {@code dispatchingStarted} (normal dispatch) or {@code currentlyDelayedDispatching} (delayed dispatch).
         * If there are no pending invocations, a normal dispatch is marked as finished and a delayed dispatch
         * is marked as no longer in progress.
         *
         * @param level                   the level to dispatch
         * @param normalDispatchOrDelayed true for a normal level dispatch, false for a delayed one
         * @param chained                 true for a follow-up dispatch issued after a previous batch completed
         *
         * @return the replaced state, whose {@code prev} chain holds the invocations to dispatch,
         * or {@code null} if the level had no state yet
         */
        public @Nullable StateForLevel aboutToStartDispatching(int level, boolean normalDispatchOrDelayed, boolean chained) {
            AtomicReference<@Nullable StateForLevel> currentStateRef = stateMapPerLevel.computeIfAbsent(level, __ -> new AtomicReference<>());
            while (true) {
                @Nullable StateForLevel currentState = currentStateRef.get();


                boolean dispatchingStarted = false;
                boolean dispatchingFinished = false;
                boolean currentlyDelayedDispatching = false;

                if (currentState != null) {
                    dispatchingStarted = currentState.dispatchingStarted;
                    dispatchingFinished = currentState.dispatchingFinished;
                    currentlyDelayedDispatching = currentState.currentlyDelayedDispatching;

                }

                if (!chained) {
                    if (normalDispatchOrDelayed) {
                        dispatchingStarted = true;
                    } else {
                        currentlyDelayedDispatching = true;
                    }
                }

                // nothing left to dispatch: this dispatch round (normal or delayed) is over
                if (currentState == null || currentState.dataLoaderInvocation == null) {
                    if (normalDispatchOrDelayed) {
                        dispatchingFinished = true;
                    } else {
                        currentlyDelayedDispatching = false;
                    }
                }

                StateForLevel newState = new StateForLevel(null, dispatchingStarted, dispatchingFinished, currentlyDelayedDispatching, null);

                if (currentStateRef.compareAndSet(currentState, newState)) {
                    return currentState;
                }
            }
        }


        /**
         * Pushes {@code dataLoaderInvocation} onto the pending chain of its level.
         *
         * @param dataLoaderInvocation the invocation to record
         *
         * @return true if the caller must start a new delayed dispatch. This is the case when normal
         * dispatching of the level is finished and no delayed dispatch is currently in progress.
         */
        public boolean newDataLoaderInvocation(DataLoaderInvocation dataLoaderInvocation) {
            int level = dataLoaderInvocation.level;
            AtomicReference<@Nullable StateForLevel> currentStateRef = stateMapPerLevel.computeIfAbsent(level, __ -> new AtomicReference<>());
            while (true) {
                @Nullable StateForLevel currentState = currentStateRef.get();

                boolean dispatchingStarted = false;
                boolean dispatchingFinished = false;
                boolean currentlyDelayedDispatching = false;

                if (currentState != null) {
                    dispatchingStarted = currentState.dispatchingStarted;
                    dispatchingFinished = currentState.dispatchingFinished;
                    currentlyDelayedDispatching = currentState.currentlyDelayedDispatching;

                }

                // we need to start a new delayed dispatching if
                // the normal dispatching is finished and there is no currently delayed dispatching for this level
                boolean newDelayedInvocation = dispatchingFinished && !currentlyDelayedDispatching;
                if (newDelayedInvocation) {
                    currentlyDelayedDispatching = true;
                }

                StateForLevel newState = new StateForLevel(dataLoaderInvocation, dispatchingStarted, dispatchingFinished, currentlyDelayedDispatching, currentState);

                if (currentStateRef.compareAndSet(currentState, newState)) {
                    return newDelayedInvocation;
                }
            }
        }

        /**
         * Drops the state of all levels.
         */
        public void clear() {
            stateMapPerLevel.clear();
        }

    }

    /**
     * Tracks the dispatch progress of one execution: the initial operation, a deferred fragment or a
     * subscription event.
     * <p>
     * We track three things per level:
     * - the number of execute object calls
     * - the number of object completion calls
     * - if the level is already dispatched
     * <p>
     * The number of execute object calls is the number of times the execution
     * of a field with sub selection (meaning it is an object) started.
     * <p>
     * For each execute object call there will be one matching object completion call,
     * indicating that all fields in the sub selection have been fetched AND completed.
     * Completion implies the fetched value is "resolved" (CompletableFuture is completed if it was a CF).
     * It also implies the engine has processed it and made any needed subsequent execute object calls (if the
     * result was non-null and of type Object or [Object] (or [[Object]] etc.)).
     * <p>
     * Together we know that a level is ready for dispatch if:
     * - the parent was dispatched
     * - the #executeObject == #completionFinished in the grandparent level.
     * <p>
     * The second condition implies that all execute object calls in the parent level happened,
     * which in turn implies that all field fetches in the current level have happened.
     * <p>
     * For the first level we track only if all expected fetched field calls have happened.
     * <p>
     * The whole algorithm is implemented lock free and relies purely on CAS operations to handle concurrency.
     */
    private static class CallStack {

        /**
         * Immutable pair of counters for one level; updated by replacing the whole instance via CAS.
         */
        static class StateForLevel {
            private final int happenedCompletionFinishedCount;
            private final int happenedExecuteObjectCalls;


            public StateForLevel() {
                this.happenedCompletionFinishedCount = 0;
                this.happenedExecuteObjectCalls = 0;
            }

            public StateForLevel(int happenedCompletionFinishedCount, int happenedExecuteObjectCalls) {
                this.happenedCompletionFinishedCount = happenedCompletionFinishedCount;
                this.happenedExecuteObjectCalls = happenedExecuteObjectCalls;
            }

            public StateForLevel(StateForLevel other) {
                this.happenedCompletionFinishedCount = other.happenedCompletionFinishedCount;
                this.happenedExecuteObjectCalls = other.happenedExecuteObjectCalls;
            }

            public StateForLevel copy() {
                return new StateForLevel(this);
            }

            public StateForLevel increaseHappenedCompletionFinishedCount() {
                return new StateForLevel(happenedCompletionFinishedCount + 1, happenedExecuteObjectCalls);
            }

            public StateForLevel increaseHappenedExecuteObjectCalls() {
                return new StateForLevel(happenedCompletionFinishedCount, happenedExecuteObjectCalls + 1);
            }

        }

        private volatile int expectedFirstLevelFetchCount;
        private final AtomicInteger happenedFirstLevelFetchCount = new AtomicInteger();


        private final Map<Integer, AtomicReference<StateForLevel>> stateForLevelMap = new ConcurrentHashMap<>();

        private final Set<Integer> dispatchedLevels = ConcurrentHashMap.newKeySet();

        public final ChainedDLStack chainedDLStack = new ChainedDLStack();

        private final AtomicInteger deferredFragmentRootFieldsCompleted = new AtomicInteger();

        public CallStack() {
        }


        /**
         * Returns the current state of {@code level}, creating an empty state on first access.
         */
        public StateForLevel get(int level) {
            AtomicReference<StateForLevel> dataPerLevelAtomicReference = stateForLevelMap.computeIfAbsent(level, __ -> new AtomicReference<>(new StateForLevel()));
            return Assert.assertNotNull(dataPerLevelAtomicReference.get());
        }

        /**
         * Replaces the state of {@code level} with {@code newData} if it is still {@code oldData}.
         * The level must already exist, i.e. {@link #get(int)} must have been called for it.
         *
         * @return true if the CAS succeeded
         */
        public boolean tryUpdateLevel(int level, StateForLevel oldData, StateForLevel newData) {
            AtomicReference<StateForLevel> dataPerLevelAtomicReference = Assert.assertNotNull(stateForLevelMap.get(level));
            return dataPerLevelAtomicReference.compareAndSet(oldData, newData);
        }

        /**
         * Atomically increments the execute object call count of {@code level} (CAS retry loop).
         */
        public void recordExecuteObjectCall(int level) {
            while (true) {
                StateForLevel currentState = get(level);
                if (tryUpdateLevel(level, currentState, currentState.increaseHappenedExecuteObjectCalls())) {
                    return;
                }
            }
        }

        /**
         * Atomically increments the completion finished count of {@code level} (CAS retry loop).
         */
        public void recordCompletionFinished(int level) {
            while (true) {
                StateForLevel currentState = get(level);
                if (tryUpdateLevel(level, currentState, currentState.increaseHappenedCompletionFinishedCount())) {
                    return;
                }
            }
        }


        /**
         * Resets all tracked state, including the chained data loader state.
         */
        public void clear() {
            dispatchedLevels.clear();
            stateForLevelMap.clear();
            expectedFirstLevelFetchCount = 0;
            happenedFirstLevelFetchCount.set(0);
            deferredFragmentRootFieldsCompleted.set(0);
            chainedDLStack.clear();
        }
    }

    public PerLevelDataLoaderDispatchStrategy(ExecutionContext executionContext) {
        this.initialCallStack = new CallStack();
        this.executionContext = executionContext;

        GraphQLContext graphQLContext = executionContext.getGraphQLContext();

        this.enableDataLoaderChaining = graphQLContext.getBoolean(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, false);
        this.profiler = executionContext.getProfiler();
    }


    @Override
    public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
        Assert.assertTrue(parameters.getExecutionStepInfo().getPath().isRootPath());
        // no concurrency access happening
        CallStack.StateForLevel currentState = initialCallStack.get(0);
        // the root counts as one execute object call on level 0
        initialCallStack.tryUpdateLevel(0, currentState, new CallStack.StateForLevel(0, 1));
        initialCallStack.expectedFirstLevelFetchCount = fieldCount;
    }

    @Override
    public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        callStack.clear();
        // re-prime the same call stack that was just cleared: the root counts as one execute object call
        CallStack.StateForLevel currentState = callStack.get(0);
        callStack.tryUpdateLevel(0, currentState, new CallStack.StateForLevel(0, 1));
        // field count is always 1 for serial execution
        callStack.expectedFirstLevelFetchCount = 1;
    }

    @Override
    public void executionStrategyOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        onCompletionFinished(0, callStack);

    }

    @Override
    public void executionStrategyOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        onCompletionFinished(0, callStack);
    }


    @Override
    public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
        CallStack callStack = getCallStack(parameters);
        int curLevel = parameters.getPath().getLevel();
        callStack.recordExecuteObjectCall(curLevel);
    }

    @Override
    public void executeObjectOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
        int curLevel = parameters.getPath().getLevel();
        CallStack callStack = getCallStack(parameters);
        onCompletionFinished(curLevel, callStack);
    }

    @Override
    public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        int curLevel = parameters.getPath().getLevel();
        onCompletionFinished(curLevel, callStack);
    }


    /**
     * Records one object completion on {@code level}. It then dispatches every following level
     * (starting at {@code level + 2}) that became ready because of it.
     */
    private void onCompletionFinished(int level, CallStack callStack) {
        callStack.recordCompletionFinished(level);

        // due to synchronous DataFetcher the completion calls on higher levels
        // can happen before the completion calls on lower level
        // this means sometimes a lower level completion means multiple levels are ready
        // hence this loop here until a level is not ready or already dispatched
        int currentLevel = level + 2;
        while (!callStack.dispatchedLevels.contains(currentLevel)
               && markLevelAsDispatchedIfReady(currentLevel, callStack)) {
            dispatch(currentLevel, callStack);
            currentLevel++;
        }
    }


    /**
     * Counts fetches on the first level of the call stack (level 1, or the start level of a deferred
     * fragment). It dispatches that level once all expected fields have been fetched.
     */
    @Override
    public void fieldFetched(ExecutionContext executionContext,
                             ExecutionStrategyParameters executionStrategyParameters,
                             DataFetcher<?> dataFetcher,
                             Object fetchedValue,
                             Supplier<DataFetchingEnvironment> dataFetchingEnvironment) {
        CallStack callStack = getCallStack(executionStrategyParameters);
        int level = executionStrategyParameters.getPath().getLevel();
        @Nullable AlternativeCallContext deferredCallContext = executionStrategyParameters.getDeferredCallContext();
        if (level == 1 || (deferredCallContext != null && level == deferredCallContext.getStartLevel())) {
            int happenedFirstLevelFetchCount = callStack.happenedFirstLevelFetchCount.incrementAndGet();
            if (happenedFirstLevelFetchCount == callStack.expectedFirstLevelFetchCount) {
                callStack.dispatchedLevels.add(level);
                dispatch(level, callStack);
            }
        }
    }


    @Override
    public void newSubscriptionExecution(AlternativeCallContext alternativeCallContext) {
        CallStack callStack = new CallStack();
        alternativeCallContextMap.put(alternativeCallContext, callStack);

    }

    @Override
    public void subscriptionEventCompletionDone(AlternativeCallContext alternativeCallContext) {
        CallStack callStack = getCallStack(alternativeCallContext);
        // this means the single root field is completed (it was never "fetched" because it is
        // the event payload) and we can mark level 1 (root fields) as dispatched and level 0 as completed
        callStack.dispatchedLevels.add(1);
        callStack.recordExecuteObjectCall(0);
        onCompletionFinished(0, callStack);
    }

    /**
     * Counts completed root fields of a deferred fragment. Once all of them are completed, the level above
     * the fragment's start level is treated as having finished its (single) completion.
     */
    @Override
    public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable throwable, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        int deferredFragmentRootFieldsCompleted = callStack.deferredFragmentRootFieldsCompleted.incrementAndGet();
        AlternativeCallContext deferredCallContext = Assert.assertNotNull(parameters.getDeferredCallContext());
        if (deferredFragmentRootFieldsCompleted == deferredCallContext.getFields()) {
            onCompletionFinished(deferredCallContext.getStartLevel() - 1, callStack);
        }

    }


    private CallStack getCallStack(ExecutionStrategyParameters parameters) {
        return getCallStack(parameters.getDeferredCallContext());
    }

    /**
     * Returns the call stack of the initial operation for a {@code null} context. Otherwise it returns the
     * (possibly lazily created) call stack of the given alternative call context.
     */
    private CallStack getCallStack(@Nullable AlternativeCallContext alternativeCallContext) {
        if (alternativeCallContext == null) {
            return this.initialCallStack;
        } else {
            return alternativeCallContextMap.computeIfAbsent(alternativeCallContext, k -> {
                /*
                  This is only for handling deferred cases. Subscription cases will also get a new callStack, but
                  it is explicitly created in `newSubscriptionExecution`.
                  The reason we are doing this lazily is, because we don't have explicit startDeferred callback.
                 */
                CallStack callStack = new CallStack();
                // on which level the fields are
                int startLevel = k.getStartLevel();
                // how many fields are deferred on this level
                int fields = k.getFields();
                if (startLevel > 1) {
                    // parent level is considered dispatched and all fields completed
                    callStack.dispatchedLevels.add(startLevel - 1);
                    CallStack.StateForLevel stateForLevel = callStack.get(startLevel - 2);
                    CallStack.StateForLevel newStateForLevel = stateForLevel.increaseHappenedExecuteObjectCalls().increaseHappenedCompletionFinishedCount();
                    callStack.tryUpdateLevel(startLevel - 2, stateForLevel, newStateForLevel);
                }
                // the parent will have one completion therefore we set the expectation to 1
                CallStack.StateForLevel stateForLevel = callStack.get(startLevel - 1);
                callStack.tryUpdateLevel(startLevel - 1, stateForLevel, stateForLevel.increaseHappenedExecuteObjectCalls());

                // for the current level we set the fetch expectations
                callStack.expectedFirstLevelFetchCount = fields;
                return callStack;
            });
        }
    }


    /**
     * Marks {@code level} as dispatched if it is ready.
     *
     * @return true only for the single caller that is responsible for dispatching the level
     */
    private boolean markLevelAsDispatchedIfReady(int level, CallStack callStack) {
        // Set.add returns false when another thread marked the level first; that thread will dispatch it
        return isLevelReady(level, callStack) && callStack.dispatchedLevels.add(level);
    }


    private boolean isLevelReady(int level, CallStack callStack) {
        Assert.assertTrue(level > 1);
        // we expect that parent has been dispatched and that all parents fields are completed
        // all parent fields completed means all grandparent completion-finished calls must have happened.
        // Read the grandparent state once so both counters come from the same consistent snapshot;
        // reading it twice could compare counters from two different CAS states.
        CallStack.StateForLevel grandParentState = callStack.get(level - 2);
        int happenedExecuteObjectCalls = grandParentState.happenedExecuteObjectCalls;
        return callStack.dispatchedLevels.contains(level - 1) &&
               happenedExecuteObjectCalls > 0 &&
               happenedExecuteObjectCalls == grandParentState.happenedCompletionFinishedCount;
    }

    void dispatch(int level, CallStack callStack) {
        if (!enableDataLoaderChaining) {
            profiler.oldStrategyDispatchingAll(level);
            DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry();
            dispatchAll(dataLoaderRegistry, level);
            return;
        }
        dispatchDLCFImpl(level, callStack, true, false);
    }

    /**
     * Dispatches every data loader in the registry, reporting non-empty batches to the profiler.
     */
    private void dispatchAll(DataLoaderRegistry dataLoaderRegistry, int level) {
        for (DataLoader<?, ?> dataLoader : dataLoaderRegistry.getDataLoaders()) {
            dataLoader.dispatch().whenComplete((objects, throwable) -> {
                if (objects != null && !objects.isEmpty()) {
                    Assert.assertNotNull(dataLoader.getName());
                    profiler.batchLoadedOldStrategy(dataLoader.getName(), level, objects.size());
                }
            });
        }
    }

    /**
     * Dispatches the data loaders with pending invocations on {@code level}.
     * <p>
     * Once all resulting batches have completed, it re-runs itself as a chained dispatch. This picks up
     * invocations that were recorded in the meantime and stops when none are left.
     */
    private void dispatchDLCFImpl(int level, CallStack callStack, boolean normalOrDelayed, boolean chained) {

        ChainedDLStack.StateForLevel stateForLevel = callStack.chainedDLStack.aboutToStartDispatching(level, normalOrDelayed, chained);
        if (stateForLevel == null || stateForLevel.dataLoaderInvocation == null) {
            return;
        }

        List<CompletableFuture> allDispatchedCFs = new ArrayList<>();
        // walk the linked chain of pending invocations
        while (stateForLevel != null && stateForLevel.dataLoaderInvocation != null) {
            final DataLoaderInvocation invocation = stateForLevel.dataLoaderInvocation;
            CompletableFuture<List> dispatch = invocation.dataLoader.dispatch();
            allDispatchedCFs.add(dispatch);
            dispatch.whenComplete((objects, throwable) -> {
                if (objects != null && !objects.isEmpty()) {
                    profiler.batchLoadedNewStrategy(invocation.name, level, objects.size(), !normalOrDelayed, chained);
                }
            });
            stateForLevel = stateForLevel.prev;
        }
        CompletableFuture.allOf(allDispatchedCFs.toArray(new CompletableFuture[0]))
                .whenComplete((unused, throwable) -> dispatchDLCFImpl(level, callStack, normalOrDelayed, true));

    }


    /**
     * Records a data loader invocation when data loader chaining is enabled. If the level has already
     * finished its normal dispatch and no delayed dispatch is running, a delayed dispatch is started.
     */
    public void newDataLoaderInvocation(String resultPath,
                                        int level,
                                        DataLoader dataLoader,
                                        String dataLoaderName,
                                        Object key,
                                        @Nullable AlternativeCallContext alternativeCallContext) {
        if (!enableDataLoaderChaining) {
            return;
        }
        DataLoaderInvocation dataLoaderInvocation = new DataLoaderInvocation(resultPath, level, dataLoader, dataLoaderName, key);
        CallStack callStack = getCallStack(alternativeCallContext);
        boolean newDelayedInvocation = callStack.chainedDLStack.newDataLoaderInvocation(dataLoaderInvocation);
        if (newDelayedInvocation) {
            dispatchDLCFImpl(level, callStack, false, false);
        }
    }

    /**
     * A single data loader invocation.
     */
    private static class DataLoaderInvocation {
        final String resultPath;
        final int level;
        final DataLoader dataLoader;
        final String name;
        final Object key;

        public DataLoaderInvocation(String resultPath, int level, DataLoader dataLoader, String name, Object key) {
            this.resultPath = resultPath;
            this.level = level;
            this.dataLoader = dataLoader;
            this.name = name;
            this.key = key;
        }

        @Override
        public String toString() {
            return "ResultPathWithDataLoader{" +
                   "resultPath='" + resultPath + '\'' +
                   ", level=" + level +
                   ", key=" + key +
                   ", name='" + name + '\'' +
                   '}';
        }
    }

}