8000 refactor: adjust type hintings to satisfy mypy · pytest-dev/pytest-factoryboy@5f92690 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5f92690

Browse files
committed
refactor: adjust type hintings to satisfy mypy
1 parent c63914b commit 5f92690

File tree

6 files changed

+55
-37
lines changed

6 files changed

+55
-37
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ line-length = 120
77
target-version = ['py37', 'py38', 'py39', 'py310']
88

99
[tool.mypy]
10+
exclude = ['docs/']
1011
allow_redefinition = false
1112
check_untyped_defs = true
1213
disallow_untyped_decorators = true
@@ -26,3 +27,7 @@ warn_unreachable = true
2627
warn_no_return = true
2728
pretty = true
2829
show_error_codes = true
30+
31+
[[tool.mypy.overrides]]
32+
module = ["tests.*"]
33+
disallow_untyped_decorators = false

pytest_factoryboy/codegen.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import pathlib
88
import shutil
99
import tempfile
10-
import typing
1110
from dataclasses import field, dataclass
1211
from functools import lru_cache
1312
from types import ModuleType
13+
from typing import Any
14+
from typing_extensions import Literal
1415

1516
import mako.template
1617
from appdirs import AppDirs
@@ -25,8 +26,8 @@
2526
@dataclass
2627
class FixtureDef:
2728
name: str
28-
function_name: typing.Literal["model_fixture", "attr_fixture", "factory_fixture", "subfactory_fixture"]
29-
function_kwargs: dict = field(default_factory=dict)
29+
function_name: Literal["model_fixture", "attr_fixture", "factory_fixture", "subfactory_fixture"]
30+
function_kwargs: dict[str, Any] = field(default_factory=dict)
3031
deps: list[str] = field(default_factory=list)
3132
related: list[str] = field(default_factory=list)
3233

@@ -122,7 +123,9 @@ def make_module(code: str, module_name: str, package_name: str) -> ModuleType:
122123
tmp_module_path.write_text(code)
123124
name = f"{package_name}.{module_name}"
124125
spec = importlib.util.spec_from_file_location(name, tmp_module_path)
126+
assert spec # NOTE: satisfy `mypy`
125127
mod = importlib.util.module_from_spec(spec)
128+
assert spec.loader # NOTE: satisfy `mypy`
126129
spec.loader.exec_module(mod)
127130
return mod
128131

pytest_factor A93C yboy/compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
import pathlib
44

5+
__all__ = ("PostGenerationContext", "path_with_stem")
6+
57
try:
68
from factory.declarations import PostGenerationContext
79
except ImportError: # factory_boy < 3.2.0

pytest_factoryboy/fixture.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414

1515
from .codegen import make_fixture_model_module, FixtureDef
1616
from .compat import PostGenerationContext
17-
from typing import TYPE_CHECKING, overload
18-
from typing_extensions import Protocol
17+
from typing import TYPE_CHECKING, overload, cast
18+
from typing_extensions import Protocol, TypeAlias
1919

2020
if TYPE_CHECKING:
2121
from typing import Any, Callable, TypeVar
22-
from _pytest.fixtures import FixtureRequest
22+
from _pytest.fixtures import SubRequest, FixtureFunction
2323
from factory.builder import BuildStep
2424
from factory.declarations import PostGeneration
2525
from factory.declarations import PostGenerationContext
2626

27-
FactoryType = type[factory.Factory]
27+
FactoryType: TypeAlias = factory.Factory
2828
T = TypeVar("T")
2929
F = TypeVar("F", bound=FactoryType)
3030

@@ -37,9 +37,9 @@ class DeferredFunction:
3737
name: str
3838
factory: FactoryType
3939
is_related: bool
40-
function: Callable[[FixtureRequest], Any]
40+
function: Callable[[SubRequest], Any]
4141

42-
def __call__(self, request: FixtureRequest) -> Any:
42+
def __call__(self, request: SubRequest) -> Any:
4343
return self.function(request)
4444

4545

@@ -51,7 +51,7 @@ def __call__(self, factory_class: F, _name: str | None = None, **kwargs: Any) ->
5151

5252

5353
@overload
54-
def register(
54+
def register( # type: ignore[misc]
5555
factory_class: None = None,
5656
_name: str | None = None,
5757
**kwargs: Any,
@@ -177,7 +177,7 @@ def register(
177177
return factory_class
178178

179179

180-
def inject_into_caller(name: str, function: Callable, locals_: dict[str, Any]) -> None:
180+
def inject_into_caller(name: str, function: Callable[..., Any], locals_: dict[str, Any]) -> None:
181181
"""Inject a function into the caller's locals, making sure that the function will work also within classes."""
182182
# We need to check if the caller frame is a class, since in that case the first argument is the class itself.
183183
# In that case, we can apply the staticmethod() decorator to the injected function, so that the first param
@@ -191,7 +191,7 @@ def inject_into_caller(name: str, function: Callable, locals_: dict[str, Any]) -
191191
# Therefore, we can just check for __qualname__ to figure out if we are in a class, and apply the @staticmethod.
192192
is_class_or_function = "__qualname__" in locals_
193193
if is_class_or_function:
194-
function = staticmethod(function)
194+
function = staticmethod(function) # type: ignore[assignment]
195195

196196
locals_[name] = function
197197

@@ -238,20 +238,21 @@ def is_dep(value: Any) -> bool:
238238
]
239239

240240

241-
def evaluate(request: FixtureRequest, value: LazyFixture | Any) -> Any:
241+
def evaluate(request: SubRequest, value: LazyFixture | Any) -> Any:
242242
"""Evaluate the declaration (lazy fixtures, etc)."""
243243
return value.evaluate(request) if isinstance(value, LazyFixture) else value
244244

245245

246-
def model_fixture(request: FixtureRequest, factory_name: str) -> Any:
246+
def model_fixture(request: SubRequest, factory_name: str) -> Any:
247247
"""Model fixture implementation."""
248248
factoryboy_request = request.getfixturevalue("factoryboy_request")
249249

250250
# Try to evaluate as much post-generation dependencies as possible
251251
factoryboy_request.evaluate(request)
252252

253+
fixture_name = str(request.fixturename)
253254
factory_class: FactoryType = request.getfixturevalue(factory_name)
254-
prefix = "".join((request.fixturename, SEPARATOR))
255+
prefix = "".join((fixture_name, SEPARATOR))
255256

256257
# Create model fixture instance
257258

@@ -279,7 +280,7 @@ class Factory(factory_class):
279280

280281
# Cache the instance value on pytest level so that the fixture can be resolved before the return
281282
request._fixturedef.cached_result = (instance, 0, None)
282-
request._fixture_defs[request.fixturename] = request._fixturedef
283+
request._fixture_defs[fixture_name] = request._fixturedef
283284

284285
# Defer post-generation declarations
285286
deferred: list[DeferredFunction] = []
@@ -289,7 +290,7 @@ class Factory(factory_class):
289290
decl = factory_class._meta.post_declarations.declarations[attr]
290291

291292
if isinstance(decl, factory.RelatedFactory):
292-
deferred.append(make_deferred_related(factory_class, request.fixturename, attr))
293+
deferred.append(make_deferred_related(factory_class, fixture_name, attr))
293294
else:
294295
argname = "".join((prefix, attr))
295296
extra = {}
@@ -309,7 +310,7 @@ class Factory(factory_class):
309310
extra=extra,
310311
)
311312
deferred.append(
312-
make_deferred_postgen(step, factory_class, request.fixturename, instance, attr, decl, postgen_context)
313+
make_deferred_postgen(step, factory_class, fixture_name, instance, attr, decl, postgen_context)
313314
)
314315
factoryboy_request.defer(deferred)
315316

@@ -329,7 +330,7 @@ def make_deferred_related(factory: FactoryType, fixture: str, attr: str) -> Defe
329330
"""
330331
name = SEPARATOR.join((fixture, attr))
331332

332-
def deferred_impl(request: FixtureRequest) -> Any:
333+
def deferred_impl(request: SubRequest) -> Any:
333334
return request.getfixturevalue(name)
334335

335336
return DeferredFunction(
@@ -362,7 +363,7 @@ def make_deferred_postgen(
362363
"""
363364
name = SEPARATOR.join((fixture, attr))
364365

365-
def deferred_impl(request: FixtureRequest) -> Any:
366+
def deferred_impl(request: SubRequest) -> Any:
366367
return declaration.call(instance, step, context)
367368

368369
return DeferredFunction(
@@ -373,17 +374,17 @@ def deferred_impl(request: FixtureRequest) -> Any:
373374
)
374375

375376

376-
def factory_fixture(request: FixtureRequest, factory_class: F) -> F:
377+
def factory_fixture(request: SubRequest, factory_class: F) -> F:
377378
"""Factory fixture implementation."""
378379
return factory_class
379380

380381

381-
def attr_fixture(request: FixtureRequest, value: T) -> T:
382+
def attr_fixture(request: SubRequest, value: T) -> T:
382383
"""Attribute fixture implementation."""
383384
return value
384385

385386

386-
def subfactory_fixture(request: FixtureRequest, factory_class: FactoryType) -> Any:
387+
def subfactory_fixture(request: SubRequest, factory_class: FactoryType) -> Any:
387388
"""SubFactory/RelatedFactory fixture implementation."""
388389
fixture = inflection.underscore(factory_class._meta.model.__name__)
389390
return request.getfixturevalue(fixture)
@@ -397,7 +398,7 @@ def get_caller_locals(depth: int = 2) -> dict[str, Any]:
397398
class LazyFixture:
398399
"""Lazy fixture."""
399400

400-
def __init__(self, fixture: Callable | str) -> None:
401+
def __init__(self, fixture: FixtureFunction | str) -> None:
401402
"""Lazy pytest fixture wrapper.
402403
403404
:param fixture: Fixture name or callable with dependencies.
@@ -409,7 +410,7 @@ def __init__(self, fixture: Callable | str) -> None:
409410
else:
410411
self.args = [self.fixture]
411412

412-
def evaluate(self, request: FixtureRequest) -> Any:
413+
def evaluate(self, request: SubRequest) -> Any:
413414
"""Evaluate the lazy fixture.
414415
415416
:param request: pytest request object.

pytest_factoryboy/plugin.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
if TYPE_CHECKING:
1010
from typing import Any
1111
from factory import Factory
12-
from _pytest.fixtures import FixtureRequest
12+
from _pytest.fixtures import FixtureRequest, SubRequest
1313
from _pytest.config import PytestPluginManager
1414
from _pytest.python import Metafunc
1515
from _pytest.nodes import Item
@@ -29,7 +29,7 @@ def __init__(self) -> None:
2929
self.deferred: list[list[DeferredFunction]] = []
3030
self.results: dict[str, dict[str, Any]] = defaultdict(dict)
3131
self.model_factories: dict[str, type[Factory]] = {}
32-
self.in_progress: set = set()
32+
self.in_progress: set[DeferredFunction] = set()
3333

3434
def defer(self, functions: list[DeferredFunction]) -> None:
3535
"""Defer post-generation declaration execution until the end of the test setup.
@@ -39,7 +39,7 @@ def defer(self, functions: list[DeferredFunction]) -> None:
3939
"""
4040
self.deferred.append(functions)
4141

42-
def get_deps(self, request: FixtureRequest, fixture: str, deps: set[str] | None = None) -> set[str]:
42+
def get_deps(self, request: SubRequest, fixture: str, deps: set[str] | None = None) -> set[str]:
4343
request = request.getfixturevalue("request")
4444

4545
if deps is None:
@@ -54,15 +54,15 @@ def get_deps(self, request: FixtureRequest, fixture: str, deps: set[str] | None
5454
deps.update(self.get_deps(request, argname, deps))
5555
return deps
5656

57-
def get_current_deps(self, request: FixtureRequest) -> set[str]:
57+
def get_current_deps(self, request: FixtureRequest | SubRequest) -> set[str]:
5858
deps = set()
5959
while hasattr(request, "_parent_request"):
6060
if request.fixturename and request.fixturename not in getattr(request, "_fixturedefs", {}):
6161
deps.add(request.fixturename)
62-
request = request._parent_request
62+
request = request._parent_request # type: ignore[union-attr]
6363
return deps
6464

65-
def execute(self, request: FixtureRequest, function: DeferredFunction, deferred: list[DeferredFunction]) -> None:
65+
def execute(self, request: SubRequest, function: DeferredFunction, deferred: list[DeferredFunction]) -> None:
6666
"""Execute deferred function and store the result."""
6767
if function in self.in_progress:
6868
raise CycleDetected()
@@ -79,15 +79,15 @@ def execute(self, request: FixtureRequest, function: DeferredFunction, deferred:
7979
deferred.remove(function)
8080
self.in_progress.remove(function)
8181

82-
def after_postgeneration(self, request: FixtureRequest) -> None:
82+
def after_postgeneration(self, request: SubRequest) -> None:
8383
"""Call _after_postgeneration hooks."""
8484
for model in list(self.results.keys()):
8585
results = self.results.pop(model)
8686
obj = request.getfixturevalue(model)
8787
factory = self.model_factories[model]
8888
factory._after_postgeneration(obj, create=True, results=results)
8989

90-
def evaluate(self, request: FixtureRequest) -> None:
90+
def evaluate(self, request: SubRequest) -> None:
9191
"""Finalize, run deferred post-generation actions, etc."""
9292
while self.deferred:
9393
try:
@@ -114,7 +114,7 @@ def pytest_runtest_call(item: Item) -> None:
114114
"""Before the test item is called."""
115115
# TODO: We should instead do an `if isinstance(item, Function)`.
116116
try:
117-
request = item._request
117+
request = item._request # type: ignore[attr-defined]
118118
except AttributeError:
119119
# pytest-pep8 plugin passes Pep8Item here during tests.
120120
return

tests/test_postgen_dependencies.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test post-generation dependencies."""
22
from __future__ import annotations
33

4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55

66
import factory
77
import pytest
@@ -19,6 +19,12 @@ class Foo:
1919
value: int
2020
expected: int
2121

22+
bar: Bar | None = None
23+
24+
# NOTE: following attributes are used internally only for assertions
25+
_create: bool | None = None
26+
_postgeneration_results: dict[str, Any] = field(default_factory=dict)
27+
2228

2329
@dataclass
2430
class Bar:
@@ -59,7 +65,7 @@ def set1(foo: Foo, create: bool, value: Any, **kwargs: Any) -> str:
5965

6066
@classmethod
6167
def _after_postgeneration(cls, obj: Foo, create: bool, results: dict[str, Any] | None = None) -> None:
62-
obj._postgeneration_results = results
68+
obj._postgeneration_results = results or {}
6369
obj._create = create
6470

6571

@@ -111,8 +117,9 @@ def test_after_postgeneration(foo: Foo):
111117
assert len(foo._postgeneration_results) == 2
112118

113119

120+
@dataclass
114121
class Ordered:
115-
value = None
122+
value: str | None = None
116123

117124

118125
@register

0 commit comments

Comments
 (0)
0