10000 Add implementation · kiq7/python-dependency-injector@1d28e62 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d28e62

Browse files
committed
Add implementation
1 parent c787ac2 commit 1d28e62

File tree

1 file changed

+96
-25
lines changed

1 file changed

+96
-25
lines changed

src/dependency_injector/wiring.py

Lines changed: 96 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
TypeVar,
2121
Type,
2222
Union,
23+
Set,
2324
cast,
2425
)
2526

@@ -82,22 +83,53 @@ class GenericMeta(type):
8283
Container = Any
8384

8485

85-
class Registry:
86+
class PatchedRegistry:
8687

8788
def __init__(self):
88-
self._storage = set()
89+
self._callables: Set[Callable[..., Any]] = set()
90+
self._attributes: Set[PatchedAttribute] = set()
8991

90-
def add(self, patched: Callable[..., Any]) -> None:
91-
self._storage.add(patched)
92+
def add_callable(self, patched: Callable[..., Any]) -> None:
93+
self._callables.add(patched)
9294

93-
def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
94-
for patched in self._storage:
95+
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
96+
for patched in self._callables:
9597
if patched.__module__ != module.__name__:
9698
continue
9799
yield patched
98100

101+
def add_attribute(self, patched: 'PatchedAttribute'):
102+
self._attributes.add(patched)
99103

100-
_patched_registry = Registry()
104+
def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']:
105+
for attribute in self._attributes:
106+
if not attribute.is_in_module(module):
107+
continue
108+
yield attribute
109+
110+
def clear_module_attributes(self, module: ModuleType):
111+
for attribute in self._attributes.copy():
112+
if not attribute.is_in_module(module):
113+
continue
114+
self._attributes.remove(attribute)
115+
116+
117+
class PatchedAttribute:
118+
119+
def __init__(self, member: Any, name: str, marker: '_Marker'):
120+
self.member = member
121+
self.name = name
122+
self.marker = marker
123+
124+
@property
125+
def module_name(self) -> str:
126+
if isinstance(self.member, ModuleType):
127+
return self.member.__name__
128+
else:
129+
return self.member.__module__
130+
131+
def is_in_module(self, module: ModuleType) -> bool:
132+
return self.module_name == module.__name__
101133

102134

103135
class ProvidersMap:
@@ -278,9 +310,6 @@ def _is_starlette_request_cls(self, instance: object) -> bool:
278310
and issubclass(instance, starlette.requests.Request)
279311

280312

281-
inspect_filter = InspectFilter()
282-
283-
284313
def wire( # noqa: C901
285314
container: Container,
286315
*,
@@ -301,16 +330,23 @@ def wire( # noqa: C901
301330
providers_map = ProvidersMap(container)
302331

303332
for module in modules:
304-
for name, member in inspect.getmembers(module):
305-
if inspect_filter.is_excluded(member):
333+
for member_name, member in inspect.getmembers(module):
334+
if _inspect_filter.is_excluded(member):
306335
continue
307-
if inspect.isfunction(member):
308-
_patch_fn(module, name, member, providers_map)
309-
elif inspect.isclass(member):
310-
for method_name, method in inspect.getmembers(member, _is_method):
311-
_patch_method(member, method_name, method, providers_map)
312336

313-
for patched in _patched_registry.get_from_module(module):
337+
if _is_marker(member):
338+
_patch_attribute(module, member_name, member, providers_map)
339+
elif inspect.isfunction(member):
340+
_patch_fn(module, member_name, member, providers_map)
341+
elif inspect.isclass(member):
342+
cls = member
343+
for cls_member_name, cls_member in inspect.getmembers(cls):
344+
if _is_marker(cls_member):
345+
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
346+
elif _is_method(cls_member):
347+
_patch_method(cls, cls_member_name, cls_member, providers_map)
348+
349+
for patched in _patched_registry.get_callables_from_module(module):
314350
_bind_injections(patched, providers_map)
315351

316352

@@ -335,15 +371,19 @@ def unwire(
335371
for method_name, method in inspect.getmembers(member, inspect.isfunction):
336372
_unpatch(member, method_name, method)
337373

338-
for patched in _patched_registry.get_from_module(module):
374+
for patched in _patched_registry.get_callables_from_module(module):
339375
_unbind_injections(patched)
340376

377+
for patched_attribute in _patched_registry.get_attributes_from_module(module):
378+
_unpatch_attribute(patched_attribute)
379+
_patched_registry.clear_module_attributes(module)
380+
341381

342382
def inject(fn: F) -> F:
343383
"""Decorate callable with injecting decorator."""
344384
reference_injections, reference_closing = _fetch_reference_injections(fn)
345385
patched = _get_patched(fn, reference_injections, reference_closing)
346-
_patched_registry.add(patched)
386+
_patched_registry.add_callable(patched)
347387
return cast(F, patched)
348388

349389

@@ -358,7 +398,7 @@ def _patch_fn(
358398
if not reference_injections:
359399
return
360400
fn = _get_patched(fn, reference_injections, reference_closing)
361-
_patched_registry.add(fn)
401+
_patched_registry.add_callable(fn)
362402

363403
_bind_injections(fn, providers_map)
364404

@@ -384,7 +424,7 @@ def _patch_method(
384424
if not reference_injections:
385425
return
386426
fn = _get_patched(fn, reference_injections, reference_closing)
387-
_patched_registry.add(fn)
427+
_patched_registry.add_callable(fn)
388428

389429
_bind_injections(fn, providers_map)
390430

@@ -411,6 +451,31 @@ def _unpatch(
411451
_unbind_injections(fn)
412452

413453

454+
def _patch_attribute(
455+
member: Any,
456+
name: str,
457+
marker: '_Marker',
458+
providers_map: ProvidersMap,
459+
) -> None:
460+
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
461+
if provider is None:
462+
return
463+
464+
_patched_registry.add_attribute(PatchedAttribute(member, name, marker))
465+
466+
if isinstance(marker, Provide):
467+
instance = provider()
468+
setattr(member, name, instance)
469+
elif isinstance(marker, Provider):
470+
setattr(member, name, provider)
471+
else:
472+
raise Exception(f'Unknown type of marker {marker}')
473+
474+
475+
def _unpatch_attribute(patched: PatchedAttribute) -> None:
476+
setattr(patched.member, patched.name, patched.marker)
477+
478+
414479
def _fetch_reference_injections(
415480
fn: Callable[..., Any],
416481
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -484,6 +549,10 @@ def _is_method(member):
484549
return inspect.ismethod(member) or inspect.isfunction(member)
485550

486551

552+
def _is_marker(member):
553+
return isinstance(member, Provide) or isinstance(member, Provider)
554+
555+
487556
def _get_patched(fn, reference_injections, reference_closing):
488557
if inspect.iscoroutinefunction(fn):
489558
patched = _get_async_patched(fn)
@@ -825,9 +894,6 @@ def uninstall(self):
825894
importlib.invalidate_caches()
826895

827896

828-
_loader = AutoLoader()
829-
830-
831897
def register_loader_containers(*containers: Container) -> None:
832898
"""Register containers in auto-wiring module loader."""
833899
_loader.register_containers(*containers)
@@ -851,3 +917,8 @@ def uninstall_loader() -> None:
851917
def is_loader_installed() -> bool:
852918
"""Check if auto-wiring module loader hook is installed."""
853919
return _loader.installed
920+
921+
922+
_patched_registry = PatchedRegistry()
923+
_inspect_filter = InspectFilter()
924+
_loader = AutoLoader()

0 commit comments

Comments
 (0)
0