PerLevelDataLoaderDispatchStrategy.java
package graphql.execution.instrumentation.dataloader;
import graphql.Assert;
import graphql.GraphQLContext;
import graphql.Internal;
import graphql.Profiler;
import graphql.execution.DataLoaderDispatchStrategy;
import graphql.execution.ExecutionContext;
import graphql.execution.ExecutionStrategyParameters;
import graphql.execution.FieldValueInfo;
import graphql.execution.incremental.AlternativeCallContext;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.util.LockKit;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
/**
 * A {@link DataLoaderDispatchStrategy} that dispatches data loaders once a whole level of the query
 * is known to have been fetched.
 * <p>
 * For every level it counts expected versus happened execute-object calls, fetches and completed
 * sub selections (see {@link CallStack}). When a level becomes ready it is dispatched:
 * <ul>
 *     <li>with data loader chaining disabled (the default, see
 *     {@link DataLoaderDispatchingContextKeys#ENABLE_DATA_LOADER_CHAINING}) every data loader in the
 *     {@link DataLoaderRegistry} is dispatched;</li>
 *     <li>with chaining enabled only the data loaders invoked on that level are dispatched, repeatedly,
 *     until no new invocations for that level arrive. Invocations that arrive after the level has
 *     finished dispatching start a "delayed" dispatch for the level.</li>
 * </ul>
 * Work associated with an {@link AlternativeCallContext} (for example deferred fragments or
 * subscription events) is tracked in a separate {@link CallStack} per context.
 */
@Internal
@NullMarked
public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy {

    private final CallStack initialCallStack;
    private final ExecutionContext executionContext;
    private final boolean enableDataLoaderChaining;
    private final Profiler profiler;

    // one call stack per alternative call context (created lazily in getCallStack)
    private final Map<AlternativeCallContext, CallStack> deferredCallStackMap = new ConcurrentHashMap<>();

    /**
     * The per-execution bookkeeping used to decide when a level is ready to be dispatched.
     * <p>
     * Counter updates and readiness checks are performed while holding {@link #lock}.
     */
    private static class CallStack {

        private final LockKit.ReentrantLock lock = new LockKit.ReentrantLock();

        /**
         * A general overview of the tracked data:
         * There are three aspects tracked per level:
         * - number of execute object calls (executeObject)
         * - number of fetches
         * - number of sub selections finished fetching
         * <p/>
         * The level for an execute object call is the level of the field in the query: for
         * { a {b {c}}} the level of a is 1, b is 2 and c is not an object
         * <p/>
         * For fetches the level is the level of the field fetched
         * <p/>
         * For sub selections finished it is the level of the fields inside the sub selection:
         * {a1 { b c} a2 } the level of {a1 a2} is 1, the level of {b c} is 2
         * <p/>
         * <p/>
         * A finished subselection means we can predict the number of execute object calls in the same level as the subselection:
         * { a {x} b {y} }
         * If a is a list of 3 objects and b is a list of 2 objects we expect 3 + 2 = 5 execute object calls on the level 1 to be happening
         * <p/>
         * An executed object call again means we can predict the number of fetches in the next level:
         * Execute Object a with { a {f1 f2 f3} } means we expect 3 fetches on level 2.
         * <p/>
         * This means we know a level is ready to be dispatched if:
         * - all subselections done in the parent level
         * - all execute objects calls in the parent level are done
         * - all expected fetches happened in the current level
         */
        private final LevelMap expectedFetchCountPerLevel = new LevelMap();
        private final LevelMap fetchCountPerLevel = new LevelMap();

        // an object call means a sub selection of a field of type object/interface/union
        // the number of fields for sub selections increases the expected fetch count for this level
        private final LevelMap expectedExecuteObjectCallsPerLevel = new LevelMap();
        private final LevelMap happenedExecuteObjectCallsPerLevel = new LevelMap();

        // this means one sub selection has been fully fetched
        // and the expected execute objects calls for the next level have been calculated
        private final LevelMap happenedOnFieldValueCallsPerLevel = new LevelMap();

        private final Set<Integer> dispatchedLevels = ConcurrentHashMap.newKeySet();

        // every level up to and including this one has been found ready
        private int highestReadyLevel;

        /**
         * Data for chained dispatching.
         * A result path is used to identify a DataFetcher.
         */
        private final List<DataLoaderInvocation> allDataLoaderInvocations = Collections.synchronizedList(new ArrayList<>());
        private final Map<Integer, Set<DataLoaderInvocation>> levelToDataLoaderInvocation = new ConcurrentHashMap<>();
        private final Set<Integer> dispatchingStartedPerLevel = ConcurrentHashMap.newKeySet();
        private final Set<Integer> dispatchingFinishedPerLevel = ConcurrentHashMap.newKeySet();
        private final Set<Integer> currentlyDelayedDispatchingLevels = ConcurrentHashMap.newKeySet();

        // the root FieldValueInfos of a deferred fragment / subscription event collected so far
        private final List<FieldValueInfo> deferredFragmentRootFieldsFetched = new ArrayList<>();

        public CallStack() {
            // in the first level there is only one sub selection,
            // so we only expect one execute object call (which is actually an executionStrategy call)
            expectedExecuteObjectCallsPerLevel.set(0, 1);
        }

        public void addDataLoaderInvocationForLevel(int level, DataLoaderInvocation dataLoaderInvocation) {
            levelToDataLoaderInvocation.computeIfAbsent(level, k -> new LinkedHashSet<>()).add(dataLoaderInvocation);
        }

        void increaseExpectedFetchCount(int level, int count) {
            expectedFetchCountPerLevel.increment(level, count);
        }

        void clearExpectedFetchCount() {
            expectedFetchCountPerLevel.clear();
        }

        void increaseFetchCount(int level) {
            fetchCountPerLevel.increment(level, 1);
        }

        void clearFetchCount() {
            fetchCountPerLevel.clear();
        }

        void increaseExpectedExecuteObjectCalls(int level, int count) {
            expectedExecuteObjectCallsPerLevel.increment(level, count);
        }

        void clearExpectedObjectCalls() {
            expectedExecuteObjectCallsPerLevel.clear();
        }

        void increaseHappenedExecuteObjectCalls(int level) {
            happenedExecuteObjectCallsPerLevel.increment(level, 1);
        }

        void clearHappenedExecuteObjectCalls() {
            happenedExecuteObjectCallsPerLevel.clear();
        }

        void increaseHappenedOnFieldValueCalls(int level) {
            happenedOnFieldValueCallsPerLevel.increment(level, 1);
        }

        void clearHappenedOnFieldValueCalls() {
            happenedOnFieldValueCallsPerLevel.clear();
        }

        boolean allExecuteObjectCallsHappened(int level) {
            return happenedExecuteObjectCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level);
        }

        boolean allSubSelectionsFetchingHappened(int subSelectionLevel) {
            return happenedOnFieldValueCallsPerLevel.get(subSelectionLevel) == expectedExecuteObjectCallsPerLevel.get(subSelectionLevel - 1);
        }

        boolean allFetchesHappened(int level) {
            return fetchCountPerLevel.get(level) == expectedFetchCountPerLevel.get(level);
        }

        void clearDispatchLevels() {
            dispatchedLevels.clear();
        }

        @Override
        public String toString() {
            return "CallStack{" +
                    "expectedFetchCountPerLevel=" + expectedFetchCountPerLevel +
                    ", fetchCountPerLevel=" + fetchCountPerLevel +
                    ", expectedExecuteObjectCallsPerLevel=" + expectedExecuteObjectCallsPerLevel +
                    ", happenedExecuteObjectCallsPerLevel=" + happenedExecuteObjectCallsPerLevel +
                    ", happenedOnFieldValueCallsPerLevel=" + happenedOnFieldValueCallsPerLevel +
                    ", dispatchedLevels=" + dispatchedLevels +
                    '}';
        }

        /**
         * Records that {@code level} has been dispatched; dispatching the same level twice is a bug.
         */
        public void setDispatchedLevel(int level) {
            if (!dispatchedLevels.add(level)) {
                Assert.assertShouldNeverHappen("level " + level + " already dispatched");
            }
        }
    }

    public PerLevelDataLoaderDispatchStrategy(ExecutionContext executionContext) {
        this.initialCallStack = new CallStack();
        this.executionContext = executionContext;

        GraphQLContext graphQLContext = executionContext.getGraphQLContext();
        this.enableDataLoaderChaining = graphQLContext.getBoolean(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, false);
        this.profiler = executionContext.getProfiler();
    }

    @Override
    public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
        Assert.assertTrue(parameters.getExecutionStepInfo().getPath().isRootPath());
        increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(0, fieldCount, initialCallStack);
    }

    /**
     * Called for each top level field of a serial execution: the call stack is reset and exactly one
     * fetch is expected on level 1.
     */
    @Override
    public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        resetCallStack(callStack);
        increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(0, 1, callStack);
    }

    @Override
    public void executionStrategyOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        // the root fields are the root sub selection on level 1
        onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1, callStack);
    }

    @Override
    public void executionStrategyOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        callStack.lock.runLocked(() ->
                callStack.increaseHappenedOnFieldValueCalls(1)
        );
    }

    private CallStack getCallStack(ExecutionStrategyParameters parameters) {
        return getCallStack(parameters.getDeferredCallContext());
    }

    /**
     * Returns the initial call stack when {@code alternativeCallContext} is null, otherwise the
     * (lazily created) call stack for that context.
     */
    private CallStack getCallStack(@Nullable AlternativeCallContext alternativeCallContext) {
        if (alternativeCallContext == null) {
            return this.initialCallStack;
        } else {
            return deferredCallStackMap.computeIfAbsent(alternativeCallContext, k -> {
                CallStack callStack = new CallStack();
                int startLevel = alternativeCallContext.getStartLevel();
                int fields = alternativeCallContext.getFields();
                callStack.lock.runLocked(() -> {
                    // we make sure that startLevel-1 is considered done
                    callStack.expectedExecuteObjectCallsPerLevel.set(0, 0); // set to 1 in the constructor of CallStack
                    callStack.expectedExecuteObjectCallsPerLevel.set(startLevel - 1, 1);
                    callStack.happenedExecuteObjectCallsPerLevel.set(startLevel - 1, 1);
                    callStack.highestReadyLevel = startLevel - 1;
                    callStack.increaseExpectedFetchCount(startLevel, fields);
                });
                return callStack;
            });
        }
    }

    @Override
    public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) {
        CallStack callStack = getCallStack(parameters);
        int curLevel = parameters.getPath().getLevel();
        increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, fieldCount, callStack);
    }

    @Override
    public void executeObjectOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
        // the level of the sub selection that is fully fetched is one level more than parameters level
        int curLevel = parameters.getPath().getLevel() + 1;
        CallStack callStack = getCallStack(parameters);
        onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, curLevel, callStack);
    }

    @Override
    public void newSubscriptionExecution(FieldValueInfo fieldValueInfo, AlternativeCallContext alternativeCallContext) {
        CallStack callStack = getCallStack(alternativeCallContext);
        // mutate the counters and the (non thread safe) list under the lock, like every other caller does
        callStack.lock.runLocked(() -> {
            callStack.increaseFetchCount(1);
            callStack.deferredFragmentRootFieldsFetched.add(fieldValueInfo);
        });
        onFieldValuesInfoDispatchIfNeeded(callStack.deferredFragmentRootFieldsFetched, 1, callStack);
    }

    /**
     * Collects the root field values of a deferred fragment; once all of them are known the fragment's
     * root sub selection is treated as fetched.
     */
    @Override
    public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable throwable, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        boolean ready = callStack.lock.callLocked(() -> {
            callStack.deferredFragmentRootFieldsFetched.add(fieldValueInfo);
            Assert.assertNotNull(parameters.getDeferredCallContext());
            return callStack.deferredFragmentRootFieldsFetched.size() == parameters.getDeferredCallContext().getFields();
        });
        if (ready) {
            int curLevel = parameters.getPath().getLevel();
            onFieldValuesInfoDispatchIfNeeded(callStack.deferredFragmentRootFieldsFetched, curLevel, callStack);
        }
    }

    @Override
    public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) {
        CallStack callStack = getCallStack(parameters);
        // the level of the sub selection that is errored is one level more than parameters level
        int curLevel = parameters.getPath().getLevel() + 1;
        callStack.lock.runLocked(() ->
                callStack.increaseHappenedOnFieldValueCalls(curLevel)
        );
    }

    private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel,
                                                                            int fieldCount,
                                                                            CallStack callStack) {
        callStack.lock.runLocked(() -> {
            callStack.increaseHappenedExecuteObjectCalls(curLevel);
            callStack.increaseExpectedFetchCount(curLevel + 1, fieldCount);
        });
    }

    /**
     * Resets all tracking state of {@code callStack} so it can be reused (used for every field of a
     * serial execution).
     */
    private void resetCallStack(CallStack callStack) {
        callStack.lock.runLocked(() -> {
            callStack.clearDispatchLevels();
            callStack.clearExpectedObjectCalls();
            callStack.clearExpectedFetchCount();
            callStack.clearFetchCount();
            callStack.clearHappenedExecuteObjectCalls();
            callStack.clearHappenedOnFieldValueCalls();
            callStack.expectedExecuteObjectCallsPerLevel.set(0, 1);
            // the chained-dispatching state must be reset too: stale "started"/"finished" markers from the
            // previous field would make every new invocation bypass per-level batching
            callStack.dispatchingStartedPerLevel.clear();
            callStack.dispatchingFinishedPerLevel.clear();
            callStack.currentlyDelayedDispatchingLevels.clear();
            callStack.allDataLoaderInvocations.clear();
            callStack.levelToDataLoaderInvocation.clear();
            callStack.highestReadyLevel = 0;
        });
    }

    private void onFieldValuesInfoDispatchIfNeeded(List<FieldValueInfo> fieldValueInfoList,
                                                   int subSelectionLevel,
                                                   CallStack callStack) {
        Integer dispatchLevel = callStack.lock.callLocked(() ->
                handleSubSelectionFetched(fieldValueInfoList, subSelectionLevel, callStack)
        );
        // the handle on field values check for the next level if it is ready
        if (dispatchLevel != null) {
            dispatch(dispatchLevel, callStack);
        }
    }

    //
    // thread safety: called with callStack.lock
    //
    private @Nullable Integer handleSubSelectionFetched(List<FieldValueInfo> fieldValueInfos, int subSelectionLevel, CallStack callStack) {
        callStack.increaseHappenedOnFieldValueCalls(subSelectionLevel);
        int expectedOnObjectCalls = getObjectCountForList(fieldValueInfos);
        // we expect on the level of the current sub selection #expectedOnObjectCalls execute object calls
        callStack.increaseExpectedExecuteObjectCalls(subSelectionLevel, expectedOnObjectCalls);

        // maybe the object calls happened already (because the DataFetcher return directly values synchronously)
        // therefore we check the next levels if they are ready
        // this means we could skip some level because the higher level is also already ready,
        // which means there is nothing to dispatch on these levels: if x and x+1 is ready, it means there are no
        // data loaders used on x
        //
        // if data loader chaining is disabled (the old algo) the level we dispatch is not really relevant as
        // we dispatch the whole registry anyway
        return getHighestReadyLevel(subSelectionLevel + 1, callStack);
    }

    /**
     * the amount of (non nullable) objects that will require an execute object call
     */
    private int getObjectCountForList(List<FieldValueInfo> fieldValueInfos) {
        int result = 0;
        for (FieldValueInfo fieldValueInfo : fieldValueInfos) {
            if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) {
                result += 1;
            } else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) {
                result += getObjectCountForList(fieldValueInfo.getFieldValueInfos());
            }
        }
        return result;
    }

    @Override
    public void fieldFetched(ExecutionContext executionContext,
                             ExecutionStrategyParameters executionStrategyParameters,
                             DataFetcher<?> dataFetcher,
                             Object fetchedValue,
                             Supplier<DataFetchingEnvironment> dataFetchingEnvironment) {
        CallStack callStack = getCallStack(executionStrategyParameters);
        int level = executionStrategyParameters.getPath().getLevel();
        boolean dispatchNeeded = callStack.lock.callLocked(() -> {
            callStack.increaseFetchCount(level);
            return dispatchIfNeeded(level, callStack);
        });
        if (dispatchNeeded) {
            dispatch(level, callStack);
        }
    }

    //
    // thread safety : called with callStack.lock
    //
    private boolean dispatchIfNeeded(int level, CallStack callStack) {
        boolean ready = checkLevelBeingReady(level, callStack);
        if (ready) {
            callStack.setDispatchedLevel(level);
            return true;
        }
        return false;
    }

    //
    // thread safety: called with callStack.lock
    //
    // Advances highestReadyLevel as far as possible and returns it, but only if it is >= startFrom.
    //
    private @Nullable Integer getHighestReadyLevel(int startFrom, CallStack callStack) {
        int curLevel = callStack.highestReadyLevel;
        while (true) {
            if (!checkLevelImpl(curLevel + 1, callStack)) {
                callStack.highestReadyLevel = curLevel;
                return curLevel >= startFrom ? curLevel : null;
            }
            curLevel++;
        }
    }

    //
    // thread safety: called with callStack.lock
    //
    private boolean checkLevelBeingReady(int level, CallStack callStack) {
        Assert.assertTrue(level > 0);
        if (level <= callStack.highestReadyLevel) {
            return true;
        }

        for (int i = callStack.highestReadyLevel + 1; i <= level; i++) {
            if (!checkLevelImpl(i, callStack)) {
                return false;
            }
        }
        callStack.highestReadyLevel = level;
        return true;
    }

    private boolean checkLevelImpl(int level, CallStack callStack) {
        // a level with zero expectations can't be ready
        if (callStack.expectedFetchCountPerLevel.get(level) == 0) {
            return false;
        }

        // first we make sure that the expected fetch count is correct
        // by verifying that the parent level all execute object + sub selection were fetched
        if (!callStack.allExecuteObjectCallsHappened(level - 1)) {
            return false;
        }
        if (level > 1 && !callStack.allSubSelectionsFetchingHappened(level - 1)) {
            return false;
        }
        // the main check: all fetches must have happened
        if (!callStack.allFetchesHappened(level)) {
            return false;
        }
        return true;
    }

    /**
     * Dispatches a level that has been found ready: the whole registry without chaining, otherwise only
     * the data loaders invoked on this level.
     */
    void dispatch(int level, CallStack callStack) {
        if (!enableDataLoaderChaining) {
            profiler.oldStrategyDispatchingAll(level);
            DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry();
            dispatchAll(dataLoaderRegistry, level);
            return;
        }
        // Checking for invocations and marking the level as started must be one atomic step:
        // an invocation registered in between would still see the level as "not started" (so it is only
        // queued), and marking the level finished afterwards would mean it never gets dispatched.
        boolean hasInvocations = callStack.lock.callLocked(() -> {
            callStack.dispatchingStartedPerLevel.add(level);
            if (callStack.levelToDataLoaderInvocation.containsKey(level)) {
                return true;
            }
            callStack.dispatchingFinishedPerLevel.add(level);
            return false;
        });
        if (hasInvocations) {
            dispatchDLCFImpl(level, callStack, false, false);
        }
    }

    private void dispatchAll(DataLoaderRegistry dataLoaderRegistry, int level) {
        for (DataLoader<?, ?> dataLoader : dataLoaderRegistry.getDataLoaders()) {
            dataLoader.dispatch().whenComplete((objects, throwable) -> {
                if (objects != null && objects.size() > 0) {
                    Assert.assertNotNull(dataLoader.getName());
                    profiler.batchLoadedOldStrategy(dataLoader.getName(), level, objects.size());
                }
            });
        }
    }

    /**
     * Dispatches all pending invocations of {@code level} and, once those dispatches complete, runs again
     * to pick up invocations made in the meantime (chained loads). When nothing is pending any more the
     * level is marked finished ({@code delayed == false}) or its delayed dispatching ends
     * ({@code delayed == true}).
     */
    private void dispatchDLCFImpl(Integer level, CallStack callStack, boolean delayed, boolean chained) {
        List<DataLoaderInvocation> relevantDataLoaderInvocations = callStack.lock.callLocked(() -> {
            // single pass partition (removeAll would be O(n * m)); every access to
            // allDataLoaderInvocations happens under callStack.lock, so iterating it here is safe
            List<DataLoaderInvocation> result = new ArrayList<>();
            List<DataLoaderInvocation> remaining = new ArrayList<>();
            for (DataLoaderInvocation dataLoaderInvocation : callStack.allDataLoaderInvocations) {
                if (dataLoaderInvocation.level == level) {
                    result.add(dataLoaderInvocation);
                } else {
                    remaining.add(dataLoaderInvocation);
                }
            }
            if (!result.isEmpty()) {
                callStack.allDataLoaderInvocations.clear();
                callStack.allDataLoaderInvocations.addAll(remaining);
                return result;
            }
            if (delayed) {
                callStack.currentlyDelayedDispatchingLevels.remove(level);
            } else {
                callStack.dispatchingFinishedPerLevel.add(level);
            }
            return result;
        });
        if (relevantDataLoaderInvocations.isEmpty()) {
            return;
        }
        List<CompletableFuture> allDispatchedCFs = new ArrayList<>();
        for (DataLoaderInvocation dataLoaderInvocation : relevantDataLoaderInvocations) {
            CompletableFuture<List> dispatch = dataLoaderInvocation.dataLoader.dispatch();
            allDispatchedCFs.add(dispatch);
            dispatch.whenComplete((objects, throwable) -> {
                if (objects != null && objects.size() > 0) {
                    profiler.batchLoadedNewStrategy(dataLoaderInvocation.name, level, objects.size(), delayed, chained);
                }
            });
        }
        CompletableFuture.allOf(allDispatchedCFs.toArray(new CompletableFuture[0]))
                .whenComplete((unused, throwable) -> {
                            dispatchDLCFImpl(level, callStack, delayed, true);
                        }
                );
    }

    /**
     * Registers a data loader invocation (only relevant with chaining enabled). If the level already
     * finished its normal dispatching and no delayed dispatching is running for it, a delayed dispatch
     * is started.
     */
    public void newDataLoaderInvocation(String resultPath,
                                        int level,
                                        DataLoader dataLoader,
                                        String dataLoaderName,
                                        Object key,
                                        @Nullable AlternativeCallContext alternativeCallContext) {
        if (!enableDataLoaderChaining) {
            return;
        }
        DataLoaderInvocation dataLoaderInvocation = new DataLoaderInvocation(resultPath, level, dataLoader, dataLoaderName, key);
        CallStack callStack = getCallStack(alternativeCallContext);
        boolean startNewDelayedDispatching = callStack.lock.callLocked(() -> {
            callStack.allDataLoaderInvocations.add(dataLoaderInvocation);

            boolean started = callStack.dispatchingStartedPerLevel.contains(level);
            if (!started) {
                callStack.addDataLoaderInvocationForLevel(level, dataLoaderInvocation);
            }
            boolean finished = callStack.dispatchingFinishedPerLevel.contains(level);
            // we need to start a new delayed dispatching if
            // the normal dispatching is finished and there is no currently delayed dispatching for this level
            boolean newDelayedInvocation = finished && !callStack.currentlyDelayedDispatchingLevels.contains(level);
            if (newDelayedInvocation) {
                callStack.currentlyDelayedDispatchingLevels.add(level);
            }
            return newDelayedInvocation;
        });
        if (startNewDelayedDispatching) {
            dispatchDLCFImpl(level, callStack, true, false);
        }
    }

    /**
     * A single data loader invocation.
     */
    private static class DataLoaderInvocation {
        final String resultPath;
        final int level;
        final DataLoader dataLoader;
        final String name;
        final Object key;

        public DataLoaderInvocation(String resultPath, int level, DataLoader dataLoader, String name, Object key) {
            this.resultPath = resultPath;
            this.level = level;
            this.dataLoader = dataLoader;
            this.name = name;
            this.key = key;
        }

        @Override
        public String toString() {
            return "ResultPathWithDataLoader{" +
                    "resultPath='" + resultPath + '\'' +
                    ", level=" + level +
                    ", key=" + key +
                    ", name='" + name + '\'' +
                    '}';
        }
    }
}