ExhaustedDataLoaderDispatchStrategy.java

package graphql.execution.instrumentation.dataloader;

import graphql.Assert;
import graphql.Internal;
import graphql.Profiler;
import graphql.execution.DataLoaderDispatchStrategy;
import graphql.execution.ExecutionContext;
import graphql.execution.ExecutionStrategyParameters;
import graphql.execution.incremental.AlternativeCallContext;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A {@link DataLoaderDispatchStrategy} that dispatches every registered {@link DataLoader} only once the
 * tracked "object running count" of a call stack has dropped to zero, and then keeps re-dispatching for as
 * long as new DataLoader invocations were recorded while the previous dispatch was in flight.
 * <p>
 * The initial (non-deferred) execution uses one call stack; each {@link AlternativeCallContext}
 * (deferred fragment or subscription event) gets its own call stack, kept in a {@link ConcurrentHashMap}.
 * All per-call-stack state lives in a single {@link AtomicInteger} and is updated with compare-and-set loops.
 */
@Internal
@NullMarked
public class ExhaustedDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy {

    private final CallStack initialCallStack;
    private final ExecutionContext executionContext;

    // Not read anywhere in this class at the moment.
    private final Profiler profiler;

    private final Map<AlternativeCallContext, CallStack> alternativeCallContextMap = new ConcurrentHashMap<>();


    /**
     * Per-execution-context bookkeeping, packed into one int so it can be updated atomically:
     * <pre>
     *   bits 2..31 : objectRunningCount   (30 bits)
     *   bit  1     : dataLoaderToDispatch (a DataLoader invocation happened since the last dispatch)
     *   bit  0     : currentlyDispatching (a dispatch loop owns this call stack)
     * </pre>
     */
    private static class CallStack {

        // Bit positions (from right to left)
        static final int currentlyDispatchingShift = 0;
        static final int dataLoaderToDispatchShift = 1;
        static final int objectRunningCountShift = 2;

        // masks (applied after shifting the field down to bit 0)
        static final int booleanMask = 1;
        static final int objectRunningCountMask = (1 << 30) - 1;

        public static int getObjectRunningCount(int state) {
            return (state >> objectRunningCountShift) & objectRunningCountMask;
        }

        public static int setObjectRunningCount(int state, int objectRunningCount) {
            return (state & ~(objectRunningCountMask << objectRunningCountShift)) |
                   (objectRunningCount << objectRunningCountShift);
        }

        public static int setDataLoaderToDispatch(int state, boolean dataLoaderToDispatch) {
            return (state & ~(booleanMask << dataLoaderToDispatchShift)) |
                   ((dataLoaderToDispatch ? 1 : 0) << dataLoaderToDispatchShift);
        }

        public static int setCurrentlyDispatching(int state, boolean currentlyDispatching) {
            return (state & ~(booleanMask << currentlyDispatchingShift)) |
                   ((currentlyDispatching ? 1 : 0) << currentlyDispatchingShift);
        }


        public static boolean getDataLoaderToDispatch(int state) {
            return ((state >> dataLoaderToDispatchShift) & booleanMask) != 0;
        }

        public static boolean getCurrentlyDispatching(int state) {
            return ((state >> currentlyDispatchingShift) & booleanMask) != 0;
        }


        /**
         * Atomically increments the object running count.
         *
         * @return the state after the increment
         */
        public int incrementObjectRunningCount() {
            while (true) {
                int oldState = getState();
                int objectRunningCount = getObjectRunningCount(oldState);
                int newState = setObjectRunningCount(oldState, objectRunningCount + 1);
                if (tryUpdateState(oldState, newState)) {
                    return newState;
                }
            }
        }

        /**
         * Atomically decrements the object running count.
         *
         * @return the state after the decrement
         */
        public int decrementObjectRunningCount() {
            while (true) {
                int oldState = getState();
                int objectRunningCount = getObjectRunningCount(oldState);
                int newState = setObjectRunningCount(oldState, objectRunningCount - 1);
                if (tryUpdateState(oldState, newState)) {
                    return newState;
                }
            }
        }

        // for debugging
        public static String printState(int state) {
            return "objectRunningCount: " + getObjectRunningCount(state) +
                   ",dataLoaderToDispatch: " + getDataLoaderToDispatch(state) +
                   ",currentlyDispatching: " + getCurrentlyDispatching(state);
        }

        private final AtomicInteger state = new AtomicInteger();

        public int getState() {
            return state.get();
        }

        public boolean tryUpdateState(int oldState, int newState) {
            return state.compareAndSet(oldState, newState);
        }

        // number of root fields of a deferred fragment that have finished fetching
        private final AtomicInteger deferredFragmentRootFieldsCompleted = new AtomicInteger();

        public CallStack() {
        }


        /**
         * Resets the deferred-field counter and the whole packed state (count, flags) to zero.
         */
        public void clear() {
            deferredFragmentRootFieldsCompleted.set(0);
            state.set(0);
        }
    }

    public ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext) {
        this.initialCallStack = new CallStack();
        this.executionContext = executionContext;

        this.profiler = executionContext.getProfiler();
    }


    @Override
    public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
        Assert.assertTrue(parameters.getExecutionStepInfo().getPath().isRootPath());
        initialCallStack.incrementObjectRunningCount();
    }

    @Override
    public void finishedFetching(ExecutionContext executionContext, ExecutionStrategyParameters newParameters) {
        CallStack callStack = getCallStack(newParameters);
        decrementObjectRunningAndMaybeDispatch(callStack);
    }

    @Override
    public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        callStack.clear();
        callStack.incrementObjectRunningCount();
    }

    @Override
    public void newSubscriptionExecution(AlternativeCallContext alternativeCallContext) {
        CallStack callStack = new CallStack();
        alternativeCallContextMap.put(alternativeCallContext, callStack);
        callStack.incrementObjectRunningCount();
    }

    @Override
    public void subscriptionEventCompletionDone(AlternativeCallContext alternativeCallContext) {
        CallStack callStack = getCallStack(alternativeCallContext);
        decrementObjectRunningAndMaybeDispatch(callStack);
    }

    /**
     * Counts a fetched root field of a deferred fragment; once all of the fragment's root fields are fetched,
     * releases the running-count unit that was taken when the fragment's call stack was created lazily.
     */
    @Override
    public void deferFieldFetched(ExecutionStrategyParameters parameters) {
        // validate before touching any counter so a missing context cannot bump the initial call stack
        AlternativeCallContext deferredCallContext = Assert.assertNotNull(parameters.getDeferredCallContext());
        CallStack callStack = getCallStack(deferredCallContext);
        int deferredFragmentRootFieldsCompleted = callStack.deferredFragmentRootFieldsCompleted.incrementAndGet();
        if (deferredFragmentRootFieldsCompleted == deferredCallContext.getFields()) {
            decrementObjectRunningAndMaybeDispatch(callStack);
        }
    }

    @Override
    public void startComplete(ExecutionStrategyParameters parameters) {
        getCallStack(parameters).incrementObjectRunningCount();
    }

    @Override
    public void stopComplete(ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        decrementObjectRunningAndMaybeDispatch(callStack);
    }

    private CallStack getCallStack(ExecutionStrategyParameters parameters) {
        return getCallStack(parameters.getDeferredCallContext());
    }

    private CallStack getCallStack(@Nullable AlternativeCallContext alternativeCallContext) {
        if (alternativeCallContext == null) {
            return this.initialCallStack;
        } else {
            return alternativeCallContextMap.computeIfAbsent(alternativeCallContext, k -> {
                /*
                  This is only for handling deferred cases. Subscription cases will also get a new callStack, but
                  it is explicitly created in `newSubscriptionExecution`.
                  The reason we are doing this lazily is because we don't have an explicit startDeferred callback.
                 */
                CallStack callStack = new CallStack();
                callStack.incrementObjectRunningCount();
                return callStack;
            });
        }
    }


    /**
     * Decrements the running count and, if that exhausted all running work while a dispatch is pending
     * and none is in progress, tries to start a dispatch.
     */
    private void decrementObjectRunningAndMaybeDispatch(CallStack callStack) {
        int newState = callStack.decrementObjectRunningCount();
        // cheap pre-check on the snapshot; tryStartDispatch re-validates atomically
        if (CallStack.getObjectRunningCount(newState) == 0 && CallStack.getDataLoaderToDispatch(newState) && !CallStack.getCurrentlyDispatching(newState)) {
            tryStartDispatch(callStack);
        }
    }

    /**
     * Records that a DataLoader was invoked and tries to start a dispatch if nothing is running.
     * If a dispatch was already pending, the code that made it pending (or the active dispatch loop,
     * or the decrement that brings the count to zero) is responsible for dispatching it.
     */
    private void newDataLoaderInvocationMaybeDispatch(CallStack callStack) {
        while (true) {
            int oldState = callStack.getState();
            if (CallStack.getDataLoaderToDispatch(oldState)) {
                return;
            }
            int newState = CallStack.setDataLoaderToDispatch(oldState, true);
            if (callStack.tryUpdateState(oldState, newState)) {
                if (CallStack.getObjectRunningCount(newState) == 0 && !CallStack.getCurrentlyDispatching(newState)) {
                    tryStartDispatch(callStack);
                }
                return;
            }
        }
    }

    /**
     * Starts a new dispatch loop if, atomically, nothing is running, a dispatch is pending and no dispatch loop
     * currently owns this call stack. Claiming ownership and consuming the pending flag happen in the same CAS,
     * so at most one caller can start a loop; losers simply return.
     */
    private void tryStartDispatch(CallStack callStack) {
        while (true) {
            int oldState = callStack.getState();
            if (CallStack.getObjectRunningCount(oldState) != 0
                    || !CallStack.getDataLoaderToDispatch(oldState)
                    || CallStack.getCurrentlyDispatching(oldState)) {
                return;
            }
            int newState = CallStack.setCurrentlyDispatching(oldState, true);
            newState = CallStack.setDataLoaderToDispatch(newState, false);
            if (callStack.tryUpdateState(oldState, newState)) {
                dispatchAllAndContinue(callStack);
                return;
            }
        }
    }

    /**
     * Continuation of an owned dispatch loop, run after the previous round of dispatches completed.
     * If new invocations were recorded meanwhile, consumes them and dispatches again; otherwise releases
     * ownership by clearing the "currently dispatching" flag.
     */
    private void continueDispatch(CallStack callStack) {
        while (true) {
            int oldState = callStack.getState();
            if (CallStack.getDataLoaderToDispatch(oldState)) {
                int newState = CallStack.setCurrentlyDispatching(oldState, true);
                newState = CallStack.setDataLoaderToDispatch(newState, false);
                if (callStack.tryUpdateState(oldState, newState)) {
                    dispatchAllAndContinue(callStack);
                    return;
                }
            } else {
                int newState = CallStack.setCurrentlyDispatching(oldState, false);
                if (callStack.tryUpdateState(oldState, newState)) {
                    return;
                }
            }
        }
    }

    /**
     * Dispatches every DataLoader in the registry and schedules {@link #continueDispatch} once all of those
     * dispatch futures have completed. A failed dispatch does not stop the loop: the throwable is ignored here.
     * If the futures are already complete, the continuation may run synchronously on the current thread.
     */
    private void dispatchAllAndContinue(CallStack callStack) {
        DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry();
        List<DataLoader<?, ?>> dataLoaders = dataLoaderRegistry.getDataLoaders();
        List<CompletableFuture<? extends List<?>>> allDispatchedCFs = new ArrayList<>(dataLoaders.size());
        for (DataLoader<?, ?> dataLoader : dataLoaders) {
            allDispatchedCFs.add(dataLoader.dispatch());
        }
        CompletableFuture.allOf(allDispatchedCFs.toArray(new CompletableFuture[0]))
                .whenComplete((unused, throwable) -> continueDispatch(callStack));
    }


    /**
     * Entry point used when a DataLoader is invoked within the given alternative call context
     * ({@code null} meaning the initial execution).
     */
    public void newDataLoaderInvocation(@Nullable AlternativeCallContext alternativeCallContext) {
        CallStack callStack = getCallStack(alternativeCallContext);
        newDataLoaderInvocationMaybeDispatch(callStack);
    }


}