|
1 |
| -__all__ = ["Intersection", "get_type_hints"] |
2 |
| - |
3 |
| -from typing import Any, get_args, get_origin |
4 |
| -from typing import get_type_hints as get_type_hints_old |
5 |
| - |
6 |
| -from basedtyping import Intersection |
7 |
| - |
8 |
| - |
9 |
| -def get_type_hints( |
10 |
| - obj: Any, |
11 |
| - globalns: Any | None = None, |
12 |
| - localns: Any | None = None, |
13 |
| - include_extras: bool = False, |
14 |
| -) -> dict[str, Any]: |
15 |
| - if get_origin(obj) == Intersection: |
16 |
| - args = get_args(obj) |
17 |
| - new_type_hints = {} |
18 |
| - for arg in args: |
19 |
| - new_type_hints.update( |
20 |
| - get_type_hints( |
21 |
| - arg, |
22 |
| - globalns=globalns, |
23 |
| - localns=localns, |
24 |
| - include_extras=include_extras, |
25 |
| - ) |
26 |
| - ) |
27 |
| - return new_type_hints |
28 |
| - else: |
29 |
| - return get_type_hints_old( |
30 |
| - obj, globalns=globalns, localns=localns, include_extras=include_extras |
31 |
| - ) |
| 1 | +__all__ = ["Intersection"] |
| 2 | +""" |
| 3 | +The idea of this is to simulate the return type of an intersection of two classes. |
| 4 | +This currently only works for direct methods or attributes of the class. |
| 5 | +""" |
| 6 | + |
| 7 | +from inspect import Signature, signature |
| 8 | +from typing import Any |
| 9 | + |
| 10 | + |
| 11 | +def signatures_compatible(s1: Signature, s2: Signature) -> bool: |
| 12 | + # TODO: Test for non overlapping signatures |
| 13 | + return s1 == s2 |
| 14 | + |
| 15 | + |
| 16 | +excluded_methods = ["__class__", "__init_subclass__", "__subclasshook__", "__new__"] |
| 17 | +get_attribute_excludes = ["__intersects__", "_test_lsp"] |
| 18 | + |
| 19 | + |
| 20 | +class Intersection: |
| 21 | + __intersects__: set[type[object]] |
| 22 | + |
| 23 | + def __init__(self, *intersects: type[object]) -> None: |
| 24 | + self.__intersects__ = set(reversed(intersects)) |
| 25 | + self._test_lsp() |
| 26 | + |
| 27 | + def __class_getitem__(cls, key): |
| 28 | + return cls(*key) |
| 29 | + |
| 30 | + def _test_lsp(self): |
| 31 | + intersected_attrs: dict[str, type] = {} |
| 32 | + signatures: dict[str, Signature] = {} |
| 33 | + for i in self.__intersects__: |
| 34 | + # Resolve basic annotations, ensuring no clashes |
| 35 | + for annotation_name, annotation_type in i.__annotations__.items(): |
| 36 | + if annotation_name in intersected_attrs: |
| 37 | + if annotation_type != intersected_attrs[annotation_name]: |
| 38 | + raise TypeError("LSP Violation") |
| 39 | + else: |
| 40 | + intersected_attrs[annotation_name] = annotation_type |
| 41 | + |
| 42 | + for method_name in dir(i): |
| 43 | + method = getattr(i, method_name) |
| 44 | + if callable(method) and method_name not in excluded_methods: |
| 45 | + sig = signature(method) |
| 46 | + if method_name in signatures: |
| 47 | + if not signatures_compatible(sig, signatures[method_name]): |
| 48 | + raise TypeError( |
| 49 | + f"Signatures {sig} and {signatures[method_name]} not compatible" |
| 50 | + ) |
| 51 | + else: |
| 52 | + signatures[method_name] = sig |
| 53 | + |
| 54 | + def __repr__(self) -> str: |
| 55 | + attrs = list(i.__name__ for i in self.__intersects__) |
| 56 | + return " & ".join(attrs) |
| 57 | + |
| 58 | + def __getattribute__(self, name: str): |
| 59 | + if name in get_attribute_excludes: |
| 60 | + return super().__getattribute__(name) |
| 61 | + for i in self.__intersects__: |
| 62 | + if hasattr(i, name) and callable(getattr(i, name)): |
| 63 | + return signature(getattr(i, name)) |
| 64 | + elif name in i.__annotations__: |
| 65 | + return i.__annotations__[name] |
| 66 | + |
| 67 | + if Any in self.__intersects__: |
| 68 | + return Any |
| 69 | + raise AttributeError(f"Attribute not found on type {self}") |
0 commit comments