8000 Use type guards (#183) · graphql-python/graphql-core@e4c26df · GitHub
[go: up one dir, main page]

Skip to content

Commit e4c26df

Browse files
committed
Use type guards (#183)
1 parent 12e29fe commit e4c26df

37 files changed

+233
-306
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Changelog = "https://github.com/graphql-python/graphql-core/releases"
4242
[tool.poetry.dependencies]
4343
python = "^3.7"
4444
typing-extensions = [
45-
{ version = "^4.3", python = "<3.8" }
45+
{ version = "^4.4", python = "<3.10" }
4646
]
4747

4848
[tool.poetry.group.test]
@@ -138,7 +138,7 @@ module = [
138138
disallow_untyped_defs = false
139139

140140
[tool.pytest.ini_options]
141-
minversion = "7.1"
141+
minversion = "7.2"
142142
# Only run benchmarks as tests.
143143
# To actually run the benchmarks, use --benchmark-enable on the command line.
144144
# To run the slow tests (fuzzing), add --run-slow on the command line.

src/graphql/execution/collect_fields.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import Any, Dict, List, Set, Union, cast
2+
from typing import Any, Dict, List, Set, Union
33

44
from ..language import (
55
FieldNode,
@@ -9,7 +9,6 @@
99
SelectionSetNode,
1010
)
1111
from ..type import (
12-
GraphQLAbstractType,
1312
GraphQLIncludeDirective,
1413
GraphQLObjectType,
1514
GraphQLSchema,
@@ -166,7 +165,7 @@ def does_fragment_condition_match(
166165
if conditional_type is type_:
167166
return True
168167
if is_abstract_type(conditional_type):
169-
return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
168+
return schema.is_sub_type(conditional_type, type_)
170169
return False
171170

172171

src/graphql/execution/execute.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from typing import TypedDict
2222
except ImportError: # Python < 3.8
2323
from typing_extensions import TypedDict
24+
try:
25+
from typing import TypeGuard
26+
except ImportError: # Python < 3.10
27+
from typing_extensions import TypeGuard
2428

2529
from ..error import GraphQLError, GraphQLFormattedError, located_error
2630
from ..language import (
@@ -39,7 +43,6 @@
3943
GraphQLFieldResolver,
4044
GraphQLLeafType,
4145
GraphQLList,
42-
GraphQLNonNull,
4346
GraphQLObjectType,
4447
GraphQLOutputType,
4548
GraphQLResolveInfo,
@@ -187,7 +190,9 @@ class ExecutionContext:
187190
errors: List[GraphQLError]
188191
middleware_manager: Optional[MiddlewareManager]
189192

190-
is_awaitable = staticmethod(default_is_awaitable)
193+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] = staticmethod(
194+
default_is_awaitable # type: ignore
195+
)
191196

192197
def __init__(
193198
self,
@@ -607,7 +612,7 @@ def complete_value(
607612
# result is null.
608613
if is_non_null_type(return_type):
609614
completed = self.complete_value(
610-
cast(GraphQLNonNull, return_type).of_type,
615+
return_type.of_type,
611616
field_nodes,
612617
info,
613618
path,
@@ -627,25 +632,25 @@ def complete_value(
627632
# If field type is List, complete each item in the list with inner type
628633
if is_list_type(return_type):
629634
return self.complete_list_value(
630-
cast(GraphQLList, return_type), field_nodes, info, path, result
635+
return_type, field_nodes, info, path, result
631636
)
632637

633638
# If field type is a leaf type, Scalar or Enum, serialize to a valid value,
634639
# returning null if serialization is not possible.
635640
if is_leaf_type(return_type):
636-
return self.complete_leaf_value(cast(GraphQLLeafType, return_type), result)
641+
return self.complete_leaf_value(return_type, result)
637642

638643
# If field type is an abstract type, Interface or Union, determine the runtime
639644
# Object type and complete for that type.
640645
if is_abstract_type(return_type):
641646
return self.complete_abstract_value(
642-
cast(GraphQLAbstractType, return_type), field_nodes, info, path, result
647+
return_type, field_nodes, info, path, result
643648
)
644649

645650
# If field type is Object, execute and complete all sub-selections.
646651
if is_object_type(return_type):
647652
return self.complete_object_value(
648-
cast(GraphQLObjectType, return_type), field_nodes, info, path, result
653+
return_type, field_nodes, info, path, result
649654
)
650655

651656
# Not reachable. All possible output types have been considered.
@@ -684,7 +689,6 @@ async def async_iterable_to_list(
684689
"Expected Iterable, but did not find one for field"
685690
f" '{info.parent_type.name}.{info.field_name}'."
686691
)
687-
result = cast(Iterable[Any], result)
688692

689693
# This is specified as a simple map, however we're optimizing the path where
690694
# the list contains no coroutine objects by avoiding creating another coroutine
@@ -876,8 +880,6 @@ def ensure_valid_runtime_type(
876880
field_nodes,
877881
)
878882

879-
runtime_type = cast(GraphQLObjectType, runtime_type)
880-
881883
if not self.schema.is_sub_type(return_type, runtime_type):
882884
raise GraphQLError(
883885
f"Runtime Object type '{runtime_type.name}' is not a possible"

src/graphql/execution/values.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Collection, Dict, List, Optional, Union, cast
1+
from typing import Any, Callable, Collection, Dict, List, Optional, Union
22

33
from ..error import GraphQLError
44
from ..language import (
@@ -21,7 +21,6 @@
2121
from ..type import (
2222
GraphQLDirective,
2323
GraphQLField,
24-
GraphQLInputType,
2524
GraphQLSchema,
2625
is_input_type,
2726
is_non_null_type,
@@ -92,7 +91,6 @@ def coerce_variable_values(
9291
)
9392
continue
9493

95-
var_type = cast(GraphQLInputType, var_type)
9694
if var_name not in inputs:
9795
if var_def_node.default_value:
9896
coerced_values[var_name] = value_from_ast(

src/graphql/language/parser.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,8 @@ def __init__(
201201
no_location: bool = False,
202202
allow_legacy_fragment_variables: bool = False,
203203
):
204-
source = (
205-
cast(Source, source) if is_source(source) else Source(cast(str, source))
206-
)
204+
if not is_source(source):
205+
source = Source(cast(str, source))
207206

208207
self._lexer = Lexer(source)
209208
self._no_location = no_location

src/graphql/language/predicates.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
from .ast import (
24
DefinitionNode,
35
ExecutableDefinitionNode,
@@ -15,6 +17,12 @@
1517
)
1618

1719

20+
try:
21+
from typing import TypeGuard
22+
except ImportError: # Python < 3.10
23+
from typing_extensions import TypeGuard
24+
25+
1826
__all__ = [
1927
"is_definition_node",
2028
"is_executable_definition_node",
@@ -29,27 +37,27 @@
2937
]
3038

3139

32-
def is_definition_node(node: Node) -> bool:
40+
def is_definition_node(node: Node) -> TypeGuard[DefinitionNode]:
3341
"""Check whether the given node represents a definition."""
3442
return isinstance(node, DefinitionNode)
3543

3644

37-
def is_executable_definition_node(node: Node) -> bool:
45+
def is_executable_definition_node(node: Node) -> TypeGuard[ExecutableDefinitionNode]:
3846
"""Check whether the given node represents an executable definition."""
3947
return isinstance(node, ExecutableDefinitionNode)
4048

4149

42-
def is_selection_node(node: Node) -> bool:
50+
def is_selection_node(node: Node) -> TypeGuard[SelectionNode]:
4351
"""Check whether the given node represents a selection."""
4452
return isinstance(node, SelectionNode)
4553

4654

47-
def is_value_node(node: Node) -> bool:
55+
def is_value_node(node: Node) -> TypeGuard[ValueNode]:
4856
"""Check whether the given node represents a value."""
4957
return isinstance(node, ValueNode)
5058

5159

52-
def is_const_value_node(node: Node) -> bool:
60+
def is_const_value_node(node: Node) -> TypeGuard[ValueNode]:
5361
"""Check whether the given node represents a constant value."""
5462
return is_value_node(node) and (
5563
any(is_const_value_node(value) for value in node.values)
@@ -60,26 +68,28 @@ def is_const_value_node(node: Node) -> bool:
6068
)
6169

6270

63-
def is_type_node(node: Node) -> bool:
71+
def is_type_node(node: Node) -> TypeGuard[TypeNode]:
6472
"""Check whether the given node represents a type."""
6573
return isinstance(node, TypeNode)
6674

6775

68-
def is_type_system_definition_node(node: Node) -> bool:
76+
def is_type_system_definition_node(node: Node) -> TypeGuard[TypeSystemDefinitionNode]:
6977
"""Check whether the given node represents a type system definition."""
7078
return isinstance(node, TypeSystemDefinitionNode)
7179

7280

73-
def is_type_definition_node(node: Node) -> bool:
81+
def is_type_definition_node(node: Node) -> TypeGuard[TypeDefinitionNode]:
7482
"""Check whether the given node represents a type definition."""
7583
return isinstance(node, TypeDefinitionNode)
7684

7785

78-
def is_type_system_extension_node(node: Node) -> bool:
86+
def is_type_system_extension_node(
87+
node: Node,
88+
) -> TypeGuard[Union[SchemaExtensionNode, TypeExtensionNode]]:
7989
"""Check whether the given node represents a type system extension."""
8090
return isinstance(node, (SchemaExtensionNode, TypeExtensionNode))
8191

8292

83-
def is_type_extension_node(node: Node) -> bool:
93+
def is_type_extension_node(node: Node) -> TypeGuard[TypeExtensionNode]:
8494
"""Check whether the given node represents a type extension."""
8595
return isinstance(node, TypeExtensionNode)

src/graphql/language/source.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
from .location import SourceLocation
44

55

6+
try:
7+
from typing import TypeGuard
8+
except ImportError: # Python < 3.10
9+
from typing_extensions import TypeGuard
10+
11+
612
__all__ = ["Source", "is_source"]
713

814
DEFAULT_NAME = "GraphQL request"
@@ -66,7 +72,7 @@ def __ne__(self, other: Any) -> bool:
6672
return not self == other
6773

6874

69-
def is_source(source: Any) -> bool:
75+
def is_source(source: Any) -> TypeGuard[Source]:
7076
"""Test if the given value is a Source object.
7177
7278
For internal use only.

src/graphql/pyutils/is_awaitable.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
import inspect
22
from types import CoroutineType, GeneratorType
3-
from typing import Any
3+
from typing import Any, Awaitable
4+
5+
6+
try:
7+
from typing import TypeGuard
8+
except ImportError: # Python < 3.10
9+
from typing_extensions import TypeGuard
410

511

612
__all__ = ["is_awaitable"]
713

814
CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE
915

1016

11-
def is_awaitable(value: Any) -> bool:
17+
def is_awaitable(value: Any) -> TypeGuard[Awaitable]:
1218
"""Return true if object can be passed to an ``await`` expression.
1319
1420
Instead of testing if the object is an instance of abc.Awaitable, it checks
@@ -18,7 +24,7 @@ def is_awaitable(value: Any) -> bool:
1824
# check for coroutine objects
1925
isinstance(value, CoroutineType)
2026
# check for old-style generator based coroutine objects
21-
or isinstance(value, GeneratorType)
27+
or isinstance(value, GeneratorType) # for Python < 3.11
2228
and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE)
2329
# check for other awaitables (e.g. futures)
2430
or hasattr(value, "__await__")

src/graphql/pyutils/is_iterable.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
from typing import Any, ByteString, Collection, Iterable, Mapping, Text, ValuesView
33

44

5+
try:
6+
from typing import TypeGuard
7+
except ImportError: # Python < 3.10
8+
from typing_extensions import TypeGuard
9+
10+
511
__all__ = ["is_collection", "is_iterable"]
612

713
collection_types: Any = [Collection]
@@ -16,14 +22,14 @@
1622
not_iterable_types: Any = (ByteString, Mapping, Text)
1723

1824

19-
def is_collection(value: Any) -> bool:
25+
def is_collection(value: Any) -> TypeGuard[Collection]:
2026
"""Check if value is a collection, but not a string or a mapping."""
2127
return isinstance(value, collection_types) and not isinstance(
2228
value, not_iterable_types
2329
)
2430

2531

26-
def is_iterable(value: Any) -> bool:
32+
def is_iterable(value: Any) -> TypeGuard[Iterable]:
2733
"""Check if value is an iterable, but not a string or a mapping."""
2834
return isinstance(value, iterable_types) and not isinstance(
2935
value, not_iterable_types

0 commit comments

Comments
 (0)
0