FFFF Merge pull request #4308 from graphql-java/test/per-level-dispatch-st… · graphql-java/graphql-java@761df31 · GitHub
[go: up one dir, main page]

Skip to content

Commit 761df31

Browse files
authored
Merge pull request #4308 from graphql-java/test/per-level-dispatch-strategy-coverage
Add unit tests for PerLevelDataLoaderDispatchStrategy coverage
2 parents d96a3f4 + 1bf69f1 commit 761df31

File tree

2 files changed

+195
-6
lines changed

2 files changed

+195
-6
lines changed

src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import graphql.schema.DataFetchingEnvironment;
1414
import org.dataloader.DataLoader;
1515
import org.dataloader.DataLoaderRegistry;
16+
import graphql.VisibleForTesting;
1617
import org.jspecify.annotations.NullMarked;
1718
import org.jspecify.annotations.Nullable;
1819

@@ -30,7 +31,8 @@
3031
@NullMarked
3132
public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy {
3233

33-
private final CallStack initialCallStack;
34+
@VisibleForTesting
35+
final CallStack initialCallStack;
3436
private final ExecutionContext executionContext;
3537
private final boolean enableDataLoaderChaining;
3638

@@ -145,7 +147,8 @@ public void clear() {
145147

146148
}
147149

148-
private static class CallStack {
150+
// package-private for testing
151+
static class CallStack {
149152

150153
/**
151154
* We track three things per level:
@@ -177,8 +180,10 @@ private static class CallStack {
177180
*/
178181

179182
static class StateForLevel {
180-
private final int happenedCompletionFinishedCount;
181-
private final int happenedExecuteObjectCalls;
183+
@VisibleForTesting
184+
final int happenedCompletionFinishedCount;
185+
@VisibleForTesting
186+
final int happenedExecuteObjectCalls;
182187

183188

184189
public StateForLevel() {
@@ -216,7 +221,8 @@ public StateForLevel increaseHappenedExecuteObjectCalls() {
216221

217222
private final Map<Integer, AtomicReference<StateForLevel>> stateForLevelMap = new ConcurrentHashMap<>();
218223

219-
private final Set<Integer> dispatchedLevels = ConcurrentHashMap.newKeySet();
224+
@VisibleForTesting
225+
final Set<Integer> dispatchedLevels = ConcurrentHashMap.newKeySet();
220226

221227
public ChainedDLStack chainedDLStack = new ChainedDLStack();
222228

@@ -439,7 +445,8 @@ private CallStack getCallStack(@Nullable AlternativeCallContext alternativeCallC
439445
}
440446

441447

442-
private boolean markLevelAsDispatchedIfReady(int level, CallStack callStack) {
448+
@VisibleForTesting
449+
boolean markLevelAsDispatchedIfReady(int level, CallStack callStack) {
443450
boolean ready = isLevelReady(level, callStack);
444451
if (ready) {
445452
if (!callStack.dispatchedLevels.add(level)) {
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package graphql.execution.instrumentation.dataloader
2+
3+
import graphql.EngineRunningState
4+
import graphql.ExecutionInput
5+
import graphql.GraphQLContext
6+
import graphql.Profiler
7+
import graphql.Scalars
8+
import graphql.execution.AsyncExecutionStrategy
9+
import graphql.execution.CoercedVariables
10+
import graphql.execution.ExecutionContextBuilder
11+
import graphql.execution.ExecutionId
12+
import graphql.execution.ExecutionStepInfo
13+
import graphql.execution.ExecutionStrategyParameters
14+
import graphql.execution.MergedSelectionSet
15+
import graphql.execution.NonNullableFieldValidator
16+
import graphql.execution.ResultPath
17+
import graphql.execution.ValueUnboxer
18+
import graphql.execution.instrumentation.SimplePerformantInstrumentation
19+
import graphql.schema.DataFetcher
20+
import graphql.schema.DataFetchingEnvironment
21+
import org.dataloader.DataLoaderRegistry
22+
import spock.lang.Specification
23+
24+
import java.util.concurrent.CountDownLatch
25+
import java.util.concurrent.Executors
26+
import java.util.concurrent.TimeUnit
27+
import java.util.function.Supplier
28+
29+
import static graphql.StarWarsSchema.starWarsSchema
30+
31+
/**
32+
* Tests for concurrency-dependent code paths in {@link PerLevelDataLoaderDispatchStrategy}
33+
* that are otherwise non-deterministically covered by integration tests.
34+
*/
35+
class PerLevelDataLoaderDispatchStrategyTest extends Specification {
36+
37+
def executionContext
38+
def strategy
39+
40+
void setup() {
41+
def dummyStrategy = new AsyncExecutionStrategy()
42+
def ei = ExecutionInput.newExecutionInput("{ hero { name } }").build()
43+
def builder = ExecutionContextBuilder.newExecutionContextBuilder()
44+
.instrumentation(SimplePerformantInstrumentation.INSTANCE)
45+
.executionId(ExecutionId.from("test"))
46+
.graphQLSchema(starWarsSchema)
47+
.queryStrategy(dummyStrategy)
48+
.mutationStrategy(dummyStrategy)
49+
.subscriptionStrategy(dummyStrategy)
50+
.coercedVariables(CoercedVariables.emptyVariables())
51+
.graphQLContext(GraphQLContext.newContext().build())
52+
.executionInput(ei)
53+
.root("root")
54+
.dataLoaderRegistry(new DataLoaderRegistry())
55+
.locale(Locale.getDefault())
56+
.valueUnboxer(ValueUnboxer.DEFAULT)
57+
.profiler(Profiler.NO_OP)
58+
.engineRunningState(new EngineRunningState(ei, Profiler.NO_OP))
59+
executionContext = builder.build()
60+
strategy = new PerLevelDataLoaderDispatchStrategy(executionContext)
61+
}
62+
63+
private ExecutionStrategyParameters paramsAtLevel(int level) {
64+
def path = ResultPath.rootPath()
65+
for (int i = 0; i < level; i++) {
66+
path = path.segment("f" + i)
67+
}
68+
return ExecutionStrategyParameters.newParameters()
69+
.executionStepInfo(ExecutionStepInfo.newExecutionStepInfo()
70+
.type(Scalars.GraphQLString)
71+
.path(path)
72+
.build())
73+
.fields(MergedSelectionSet.newMergedSelectionSet().build())
74+
.nonNullFieldValidator(new NonNullableFieldValidator(executionContext))
75+
.path(path)
76+
.build()
77+
}
78+
79+
def "markLevelAsDispatchedIfReady returns false when level already dispatched"() {
80+
given:
81+
def callStack = strategy.initialCallStack
82+
def dispatchedLevels = callStack.dispatchedLevels
83+
84+
and: "set up level 0 via executionStrategy and dispatch level 1 via fieldFetched"
85+
def rootParams = paramsAtLevel(0)
86+
strategy.executionStrategy(executionContext, rootParams, 1)
87+
def level1Params = paramsAtLevel(1)
88+
strategy.fieldFetched(executionContext, level1Params,
89+
{ env -> null } as DataFetcher,
90+
"value",
91+
{ -> null } as Supplier<DataFetchingEnvironment>)
92+
93+
and: "make isLevelReady(2) return true by matching completionFinished to executeObjectCalls at level 0"
94+
def state0 = callStack.get(0)
95+
callStack.tryUpdateLevel(0, state0, state0.increaseHappenedCompletionFinishedCount())
96+
97+
expect:
98+
dispatchedLevels.contains(1)
99+
100+
when: "first dispatch of level 2"
101+
def firstResult = strategy.markLevelAsDispatchedIfReady(2, callStack)
102+
103+
then:
104+
firstResult
105+
dispatchedLevels.contains(2)
106+
107+
when: "second dispatch of level 2 (simulates another thread arriving late)"
108+
def secondResult = strategy.markLevelAsDispatchedIfReady(2, callStack)
109+
110+
then:
111+
!secondResult
112+
}
113+
114+
def "concurrent onCompletionFinished races to dispatch same level"() {
115+
given:
116+
def rootParams = paramsAtLevel(0)
117+
strategy.executionStrategy(executionContext, rootParams, 1)
118+
119+
and: "increment executeObjectCalls at level 0 from 1 to 2"
120+
def level0Params = paramsAtLevel(0)
121+
strategy.executeObject(executionContext, level0Params, 1)
122+
123+
and: "dispatch level 1 via fieldFetched"
124+
def level1Params = paramsAtLevel(1)
125+
strategy.fieldFetched(executionContext, level1Params,
126+
{ env -> null } as DataFetcher,
127+
"value",
128+
{ -> null } as Supplier<DataFetchingEnvironment>)
129+
130+
when: "two threads concurrently complete level 0"
131+
def startLatch = new CountDownLatch(1)
132+
def executor = Executors.newFixedThreadPool(2)
133+
134+
def task = {
135+
startLatch.await()
136+
strategy.executeObjectOnFieldValuesInfo(Collections.emptyList(), level0Params)
137+
} as Runnable
138+
139+
executor.submit(task)
140+
executor.submit(task)
141+
startLatch.countDown()
142+
executor.shutdown()
143+
executor.awaitTermination(5, TimeUnit.SECONDS)
144+
145+
then: "level 2 is dispatched exactly once (regardless of which thread won)"
146+
strategy.initialCallStack.dispatchedLevels.contains(2)
147+
}
148+
149+
def "executeObjectOnFieldValuesException calls onCompletionFinished"() {
150+
given:
151+
def rootParams = paramsAtLevel(0)
152+
strategy.executionStrategy(executionContext, rootParams, 1)
153+
154+
and: "dispatch level 1 via fieldFetched"
155+
def level1Params = paramsAtLevel(1)
156+
strategy.fieldFetched(executionContext, level1Params,
157+
{ env -> null } as DataFetcher,
158+
"value",
159+
{ -> null } as Supplier<DataFetchingEnvironment>)
160+
161+
when:
162+
def level2Params = paramsAtLevel(2)
163+
strategy.executeObjectOnFieldValuesException(
164+
new RuntimeException("test error"), level2Params)
165+
166+
then:
167+
strategy.initialCallStack.get(2).happenedCompletionFinishedCount > 0
168+
}
169+
170+
def "executionStrategyOnFieldValuesException calls onCompletionFinished"() {
171+
given:
172+
def rootParams = paramsAtLevel(0)
173+
strategy.executionStrategy(executionContext, rootParams, 1)
174+
175+
when:
176+
strategy.executionStrategyOnFieldValuesException(
177+
new RuntimeException("test error"), rootParams)
178+
179+
then:
180+
strategy.initialCallStack.get(0).happenedCompletionFinishedCount > 0
181+
}
182+
}

0 commit comments

Comments
 (0)
0