From 28ce3efce4c0b9d162dd7b5887e52c55d3f0de13 Mon Sep 17 00:00:00 2001 From: dondonz <13839920+dondonz@users.noreply.github.com> Date: Tue, 19 Mar 2024 11:54:32 +1100 Subject: [PATCH] Backport PR 3525 max result nodes --- .../java/graphql/execution/Execution.java | 1 + .../graphql/execution/ExecutionContext.java | 5 + .../graphql/execution/ExecutionStrategy.java | 29 +++- .../java/graphql/execution/FetchedValue.java | 2 +- .../graphql/execution/FieldValueInfo.java | 2 +- .../graphql/execution/ResultNodesInfo.java | 55 +++++++ src/test/groovy/graphql/GraphQLTest.groovy | 141 ++++++++++++++++++ 7 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 src/main/java/graphql/execution/ResultNodesInfo.java diff --git a/src/main/java/graphql/execution/Execution.java b/src/main/java/graphql/execution/Execution.java index 401258dedb..37b7d5094c 100644 --- a/src/main/java/graphql/execution/Execution.java +++ b/src/main/java/graphql/execution/Execution.java @@ -96,6 +96,7 @@ public CompletableFuture execute(Document document, GraphQLSche .executionInput(executionInput) .build(); + executionContext.getGraphQLContext().put(ResultNodesInfo.RESULT_NODES_INFO, executionContext.getResultNodesInfo()); InstrumentationExecutionParameters parameters = new InstrumentationExecutionParameters( executionInput, graphQLSchema, instrumentationState diff --git a/src/main/java/graphql/execution/ExecutionContext.java b/src/main/java/graphql/execution/ExecutionContext.java index 9c1aa2032b..8c547e23c6 100644 --- a/src/main/java/graphql/execution/ExecutionContext.java +++ b/src/main/java/graphql/execution/ExecutionContext.java @@ -57,6 +57,7 @@ public class ExecutionContext { private final ValueUnboxer valueUnboxer; private final ExecutionInput executionInput; private final Supplier queryTree; + private final ResultNodesInfo resultNodesInfo = new ResultNodesInfo(); ExecutionContext(ExecutionContextBuilder builder) { this.graphQLSchema = builder.graphQLSchema; @@ -291,4 +292,8 @@ public ExecutionContext transform(Consumer builderConsu builderConsumer.accept(builder); return builder.build(); } + + public ResultNodesInfo getResultNodesInfo() { + return resultNodesInfo; + } } diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index 9e38a58372..67a5bf4b85 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -61,6 +61,7 @@ import static graphql.execution.FieldValueInfo.CompleteValueType.NULL; import static graphql.execution.FieldValueInfo.CompleteValueType.OBJECT; import static graphql.execution.FieldValueInfo.CompleteValueType.SCALAR; +import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES; import static graphql.execution.instrumentation.SimpleInstrumentationContext.nonNullCtx; import static graphql.schema.DataFetchingEnvironmentImpl.newDataFetchingEnvironment; import static graphql.schema.GraphQLTypeUtil.isEnum; @@ -238,7 +239,23 @@ protected CompletableFuture fetchField(ExecutionContext executionC MergedField field = parameters.getField(); GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType(); GraphQLFieldDefinition fieldDef = getFieldDef(executionContext.getGraphQLSchema(), parentType, field.getSingleField()); - GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry(); + return fetchField(fieldDef, executionContext, parameters); + } + + private CompletableFuture fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters) { + + int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount(); + + Integer maxNodes; + if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) { + if (resultNodesCount > maxNodes) { + executionContext.getResultNodesInfo().maxResultNodesExceeded(); + return CompletableFuture.completedFuture(new FetchedValue(null, null, ImmutableKit.emptyList(), null)); + } + } + + MergedField field = parameters.getField(); + GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType(); // if the DF (like PropertyDataFetcher) does not use the arguments or execution step info then dont build any @@ -273,6 +290,7 @@ protected CompletableFuture fetchField(ExecutionContext executionC .queryDirectives(queryDirectives) .build(); }); + GraphQLCodeRegistry codeRegistry = executionContext.getGraphQLSchema().getCodeRegistry(); DataFetcher dataFetcher = codeRegistry.getDataFetcher(parentType, fieldDef); Instrumentation instrumentation = executionContext.getInstrumentation(); @@ -555,6 +573,15 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext, List fieldValueInfos = new ArrayList<>(size.orElse(1)); int index = 0; for (Object item : iterableValues) { + int resultNodesCount = executionContext.getResultNodesInfo().incrementAndGetResultNodesCount(); + Integer maxNodes; + if ((maxNodes = executionContext.getGraphQLContext().get(MAX_RESULT_NODES)) != null) { + if (resultNodesCount > maxNodes) { + executionContext.getResultNodesInfo().maxResultNodesExceeded(); + return new FieldValueInfo(NULL, completedFuture(ExecutionResult.newExecutionResult().build()), fieldValueInfos); + } + } + ResultPath indexedPath = parameters.getPath().segment(index); ExecutionStepInfo stepInfoForListElement = executionStepInfoFactory.newExecutionStepInfoForListElement(executionStepInfo, index); diff --git a/src/main/java/graphql/execution/FetchedValue.java b/src/main/java/graphql/execution/FetchedValue.java index 28d2ce6da6..8cc520cb28 100644 --- a/src/main/java/graphql/execution/FetchedValue.java +++ b/src/main/java/graphql/execution/FetchedValue.java @@ -19,7 +19,7 @@ public class FetchedValue { private final Object localContext; private final ImmutableList errors; - private FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList errors, Object localContext) { + FetchedValue(Object fetchedValue, Object rawFetchedValue, ImmutableList errors, Object localContext) { this.fetchedValue = fetchedValue; this.rawFetchedValue = rawFetchedValue; this.errors = errors; diff --git a/src/main/java/graphql/execution/FieldValueInfo.java b/src/main/java/graphql/execution/FieldValueInfo.java index 168ffab735..c889e98b34 100644 --- a/src/main/java/graphql/execution/FieldValueInfo.java +++ b/src/main/java/graphql/execution/FieldValueInfo.java @@ -25,7 +25,7 @@ public enum CompleteValueType { private final CompletableFuture fieldValue; private final List fieldValueInfos; - private FieldValueInfo(CompleteValueType completeValueType, CompletableFuture fieldValue, List fieldValueInfos) { + FieldValueInfo(CompleteValueType completeValueType, CompletableFuture fieldValue, List fieldValueInfos) { assertNotNull(fieldValueInfos, () -> "fieldValueInfos can't be null"); this.completeValueType = completeValueType; this.fieldValue = fieldValue; diff --git a/src/main/java/graphql/execution/ResultNodesInfo.java b/src/main/java/graphql/execution/ResultNodesInfo.java new file mode 100644 index 0000000000..afc366f6be --- /dev/null +++ b/src/main/java/graphql/execution/ResultNodesInfo.java @@ -0,0 +1,55 @@ +package graphql.execution; + +import graphql.Internal; +import graphql.PublicApi; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * This class is used to track the number of result nodes that have been created during execution. + * After each execution the GraphQLContext contains a ResultNodeInfo object under the key {@link ResultNodesInfo#RESULT_NODES_INFO} + *

+ * The number of result can be limited (and should be for security reasons) by setting the maximum number of result nodes + * in the GraphQLContext under the key {@link ResultNodesInfo#MAX_RESULT_NODES} to an Integer + *

+ */ +@PublicApi +public class ResultNodesInfo { + + public static final String MAX_RESULT_NODES = "__MAX_RESULT_NODES"; + public static final String RESULT_NODES_INFO = "__RESULT_NODES_INFO"; + + private volatile boolean maxResultNodesExceeded = false; + private final AtomicInteger resultNodesCount = new AtomicInteger(0); + + @Internal + public int incrementAndGetResultNodesCount() { + return resultNodesCount.incrementAndGet(); + } + + @Internal + public void maxResultNodesExceeded() { + this.maxResultNodesExceeded = true; + } + + /** + * The number of result nodes created. + * Note: this can be higher than max result nodes because + * a each node that exceeds the number of max nodes is set to null, + * but still is a result node (with value null) + * + * @return number of result nodes created + */ + public int getResultNodesCount() { + return resultNodesCount.get(); + } + + /** + * If the number of result nodes has exceeded the maximum allowed numbers. + * + * @return true if the number of result nodes has exceeded the maximum allowed numbers + */ + public boolean isMaxResultNodesExceeded() { + return maxResultNodesExceeded; + } +} diff --git a/src/test/groovy/graphql/GraphQLTest.groovy b/src/test/groovy/graphql/GraphQLTest.groovy index ec2523d3f8..d994ad9a0f 100644 --- a/src/test/groovy/graphql/GraphQLTest.groovy +++ b/src/test/groovy/graphql/GraphQLTest.groovy @@ -13,6 +13,7 @@ import graphql.execution.ExecutionId import graphql.execution.ExecutionIdProvider import graphql.execution.ExecutionStrategyParameters import graphql.execution.MissingRootTypeException +import graphql.execution.ResultNodesInfo import graphql.execution.SubscriptionExecutionStrategy import graphql.execution.ValueUnboxer import graphql.execution.instrumentation.ChainedInstrumentation @@ -49,6 +50,7 @@ import static graphql.ExecutionInput.Builder import static graphql.ExecutionInput.newExecutionInput import static graphql.Scalars.GraphQLInt import static graphql.Scalars.GraphQLString +import static graphql.execution.ResultNodesInfo.MAX_RESULT_NODES import static graphql.schema.GraphQLArgument.newArgument import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition import static graphql.schema.GraphQLInputObjectField.newInputObjectField @@ -1440,4 +1442,143 @@ many lines'''] then: !er.errors.isEmpty() } + + def "max result nodes not breached"() { + given: + def sdl = ''' + + type Query { + hello: String + } + ''' + def df = { env -> "world" } as DataFetcher + def fetchers = ["Query": ["hello": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ hello h1: hello h2: hello h3: hello } " + def ei = newExecutionInput(query).build() + ei.getGraphQLContext().put(MAX_RESULT_NODES, 4); + + when: + def er = graphQL.execute(ei) + def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo + then: + !rni.maxResultNodesExceeded + rni.resultNodesCount == 4 + er.data == [hello: "world", h1: "world", h2: "world", h3: "world"] + } + + def "max result nodes breached"() { + given: + def sdl = ''' + + type Query { + hello: String + } + ''' + def df = { env -> "world" } as DataFetcher + def fetchers = ["Query": ["hello": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ hello h1: hello h2: hello h3: hello } " + def ei = newExecutionInput(query).build() + ei.getGraphQLContext().put(MAX_RESULT_NODES, 3); + + when: + def er = graphQL.execute(ei) + def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo + then: + rni.maxResultNodesExceeded + rni.resultNodesCount == 4 + er.data == [hello: "world", h1: "world", h2: "world", h3: null] + } + + def "max result nodes breached with list"() { + given: + def sdl = ''' + + type Query { + hello: [String] + } + ''' + def df = { env -> ["w1", "w2", "w3"] } as DataFetcher + def fetchers = ["Query": ["hello": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ hello}" + def ei = newExecutionInput(query).build() + ei.getGraphQLContext().put(MAX_RESULT_NODES, 3); + + when: + def er = graphQL.execute(ei) + def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo + then: + rni.maxResultNodesExceeded + rni.resultNodesCount == 4 + er.data == [hello: null] + } + + def "max result nodes breached with list 2"() { + given: + def sdl = ''' + + type Query { + hello: [Foo] + } + type Foo { + name: String + } + ''' + def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher + def fetchers = ["Query": ["hello": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ hello {name}}" + def ei = newExecutionInput(query).build() + // we have 7 result nodes overall + ei.getGraphQLContext().put(MAX_RESULT_NODES, 6); + + when: + def er = graphQL.execute(ei) + def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo + then: + rni.resultNodesCount == 7 + rni.maxResultNodesExceeded + er.data == [hello: [[name: "w1"], [name: "w2"], [name: null]]] + } + + def "max result nodes not breached with list"() { + given: + def sdl = ''' + + type Query { + hello: [Foo] + } + type Foo { + name: String + } + ''' + def df = { env -> [[name: "w1"], [name: "w2"], [name: "w3"]] } as DataFetcher + def fetchers = ["Query": ["hello": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ hello {name}}" + def ei = newExecutionInput(query).build() + // we have 7 result nodes overall + ei.getGraphQLContext().put(MAX_RESULT_NODES, 7); + + when: + def er = graphQL.execute(ei) + def rni = ei.getGraphQLContext().get(ResultNodesInfo.RESULT_NODES_INFO) as ResultNodesInfo + then: + !rni.maxResultNodesExceeded + rni.resultNodesCount == 7 + er.data == [hello: [[name: "w1"], [name: "w2"], [name: "w3"]]] + } + }