8000 attrs.evolve: support generics and unions (#15050) · python/mypy@2a4c473 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a4c473

Browse files
authored
attrs.evolve: support generics and unions (#15050)
Fixes `attrs.evolve` signature generation to support the `inst` parameter being - a generic attrs class - a union of attrs classes - a mix of the two In the case of unions, we "meet" the fields of the potential attrs classes, so that the resulting signature is the lower bound. Fixes #15088.
1 parent 0845818 commit 2a4c473

File tree

3 files changed

+171
-27
lines changed

3 files changed

+171
-27
lines changed

mypy/plugins/attrs.py

Lines changed: 90 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterable, List, cast
5+
from collections import defaultdict
6+
from functools import reduce
7+
from typing import Iterable, List, Mapping, cast
68
from typing_extensions import Final, Literal
79

810
import mypy.plugin # To avoid circular imports.
911
from mypy.applytype import apply_generic_arguments
1012
from mypy.checker import TypeChecker
1113
from mypy.errorcodes import LITERAL_REQ
12-
from mypy.expandtype import expand_type
14+
from mypy.expandtype import expand_type, expand_type_by_instance
1315
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
16+
from mypy.meet import meet_types
1417
from mypy.messages import format_type_bare
1518
from mypy.nodes import (
1619
ARG_NAMED,
@@ -67,6 +70,7 @@
6770
Type,
6871
TypeOfAny,
6972
TypeVarType,
73+
UninhabitedType,
7074
UnionType,
7175
get_proper_type,
7276
)
@@ -942,12 +946,82 @@ def _get_attrs_init_type(typ: Instance) -> CallableType | None:
942946
return init_method.type
943947

944948

945-
def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]:
946-
if isinstance(typ, TypeVarType):
947-
typ = get_proper_type(typ.upper_bound)
948-
if not isinstance(typ, Instance):
949-
return None, None
950-
return typ, _get_attrs_init_type(typ)
949+
def _fail_not_attrs_class(ctx: mypy.plugin.FunctionSigContext, t: Type, parent_t: Type) -> None:
950+
t_name = format_type_bare(t, ctx.api.options)
951+
if parent_t is t:
952+
msg = (
953+
f'Argument 1 to "evolve" has a variable type "{t_name}" not bound to an attrs class'
954+
if isinstance(t, TypeVarType)
955+
else f'Argument 1 to "evolve" has incompatible type "{t_name}"; expected an attrs class'
956+
)
957+
else:
958+
pt_name = format_type_bare(parent_t, ctx.api.options)
959+
msg = (
960+
f'Argument 1 to "evolve" has type "{pt_name}" whose item "{t_name}" is not bound to an attrs class'
961+
if isinstance(t, TypeVarType)
962+
else f'Argument 1 to "evolve" has incompatible type "{pt_name}" whose item "{t_name}" is not an attrs class'
963+
)
964+
965+
ctx.api.fail(msg, ctx.context)
966+
967+
968+
def _get_expanded_attr_types(
969+
ctx: mypy.plugin.FunctionSigContext,
970+
typ: ProperType,
971+
display_typ: ProperType,
972+
parent_typ: ProperType,
973+
) -> list[Mapping[str, Type]] | None:
974+
"""
975+
For a given type, determine what attrs classes it can be: for each class, return the field types.
976+
For generic classes, the field types are expanded.
977+
If the type contains Any or a non-attrs type, returns None; in the latter case, also reports an error.
978+
"""
979+
if isinstance(typ, AnyType):
980+
return None
981+
elif isinstance(typ, UnionType):
982+
ret: list[Mapping[str, Type]] | None = []
983+
for item in typ.relevant_items():
984+
item = get_proper_type(item)
985+
item_types = _get_expanded_attr_types(ctx, item, item, parent_typ)
986+
if ret is not None and item_types is not None:
987+
ret += item_types
988+
else:
989+
ret = None # but keep iterating to emit all errors
990+
return ret
991+
elif isinstance(typ, TypeVarType):
992+
return _get_expanded_attr_types(
993+
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
994+
)
995+
elif isinstance(typ, Instance):
996+
init_func = _get_attrs_init_type(typ)
997+
if init_func is None:
998+
_fail_not_attrs_class(ctx, display_typ, parent_typ)
999+
return None
1000+
init_func = expand_type_by_instance(init_func, typ)
1001+
# [1:] to skip the self argument of AttrClass.__init__
1002+
field_names = cast(List[str], init_func.arg_names[1:])
1003+
field_types = init_func.arg_types[1:]
1004+
return [dict(zip(field_names, field_types))]
1005+
else:
1006+
_fail_not_attrs_class(ctx, display_typ, parent_typ)
1007+
return None
1008+
1009+
1010+
def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
1011+
"""
1012+
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
1013+
"""
1014+
field_to_types = defaultdict(list)
1015+
for fields in types:
1016+
for name, typ in fields.items():
1017+
field_to_types[name].append(typ)
1018+
1019+
return {
1020+
name: get_proper_type(reduce(meet_types, f_types))
1021+
if len(f_types) == len(types)
1022+
else UninhabitedType()
1023+
for name, f_types in field_to_types.items()
1024+
}
9511025

9521026

9531027
def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
@@ -971,27 +1045,18 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
9711045
# </hack>
9721046

9731047
inst_type = get_proper_type(inst_type)
974-
if isinstance(inst_type, AnyType):
975-
return ctx.default_signature # evolve(Any, ....) -> Any
9761048
inst_type_str = format_type_bare(inst_type, ctx.api.options)
9771049

978-
attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type)
979-
if attrs_type is None or attrs_init_type is None:
980-
ctx.api.fail(
981-
f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class'
982-
if isinstance(inst_type, TypeVarType)
983-
else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
984-
ctx.context,
985-
)
1050+
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
1051+
if attr_types is None:
9861052
return ctx.default_signature
1053+
fields = _meet_fields(attr_types)
9871054

988-
# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
989-
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
990-
# We want to generate a signature for evolve that looks like this:
991-
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
992-
return attrs_init_type.copy_modified(
993-
arg_names=["inst"] + attrs_init_type.arg_names[1:],
994-
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
1055+
return CallableType(
1056+
arg_names=["inst", *fields.keys()],
1057+
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * len(fields),
1058+
arg_types=[inst_type, *fields.values()],
9951059
ret_type=inst_type,
1060+
fallback=ctx.default_signature.fallback,
9961061
name=f"{ctx.default_signature.name} of {inst_type_str}",
9971062
)

test-data/unit/check-attr.test

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,81 @@ reveal_type(ret) # N: Revealed type is "Any"
19701970

19711971
[typing fixtures/typing-medium.pyi]
19721972

1973+
[case testEvolveGeneric]
1974+
import attrs
1975+
from typing import Generic, TypeVar
1976+
1977+
T = TypeVar('T')
1978+
1979+
@attrs.define
1980+
class A(Generic[T]):
1981+
x: T
1982+
1983+
1984+
a = A(x=42)
1985+
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
1986+
a2 = attrs.evolve(a, x=42)
1987+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
1988+
a2 = attrs.evolve(a, x='42') # E: Argument "x" to "evolve" of "A[int]" has incompatible type "str"; expected "int"
1989+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
1990+
1991+
[builtins fixtures/attr.pyi]
1992+
1993+
[case testEvolveUnion]
1994+
# flags: --python-version 3.10
1995+
from typing import Generic, TypeVar
1996+
import attrs
1997+
1998+
T = TypeVar('T')
1999+
2000+
2001+
@attrs.define
2002+
class A(Generic[T]):
2003+
x: T # exercises meet(T=int, int) = int
2004+
y: bool # exercises meet(bool, int) = bool
2005+
z: str # exercises meet(str, bytes) = <nothing>
2006+
w: dict # exercises meet(dict, <nothing>) = <nothing>
2007+
2008+
2009+
@attrs.define
2010+
class B:
2011+
x: int
2012+
y: bool
2013+
z: bytes
2014+
2015+
2016+
a_or_b: A[int] | B
2017+
a2 = attrs.evolve(a_or_b, x=42, y=True)
2018+
a2 = attrs.evolve(a_or_b, x=42, y=True, z='42') # E: Argument "z" to "evolve" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
2019+
a2 = attrs.evolve(a_or_b, x=42, y=True, w={}) # E: Argument "w" to "evolve" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>
2020+
2021+
[builtins fixtures/attr.pyi]
2022+
2023+
[case testEvolveUnionOfTypeVar]
2024+
# flags: --python-version 3.10
2025+
import attrs
2026+
from typing import TypeVar
2027+
2028+
@attrs.define
2029+
class A:
2030+
x: int
2031+
y: int
2032+
z: str
2033+
w: dict
2034+
2035+
2036+
class B:
2037+
pass
2038+
2039+
TA = TypeVar('TA', bound=A)
2040+
TB = TypeVar('TB', bound=B)
2041+
2042+
def f(b_or_t: TA | TB | int) -> None:
2043+
a2 = attrs.evolve(b_or_t) # E: Argument 1 to "evolve" has type "Union[TA, TB, int]" whose item "TB" is not bound to an attrs class # E: Argument 1 to "evolve" has incompatible type "Union[TA, TB, int]" whose item "int" is not an attrs class
2044+
2045+
2046+
[builtins fixtures/attr.pyi]
2047+
19732048
[case testEvolveTypeVarBound]
19742049
import attrs
19752050
from typing import TypeVar
@@ -1997,11 +2072,12 @@ f(B(x=42))
19972072

19982073
[case testEvolveTypeVarBoundNonAttrs]
19992074
import attrs
2000-
from typing import TypeVar
2075+
from typing import Union, TypeVar
20012076

20022077
TInt = TypeVar('TInt', bound=int)
20032078
TAny = TypeVar('TAny')
20042079
TNone = TypeVar('TNone', bound=None)
2080+
TUnion = TypeVar('TUnion', bound=Union[str, int])
20052081

20062082
def f(t: TInt) -> None:
20072083
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class
@@ -2012,6 +2088,9 @@ def g(t: TAny) -> None:
20122088
def h(t: TNone) -> None:
20132089
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class
20142090

2091+
def x(t: TUnion) -> None:
2092+
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "str" is not an attrs class # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "int" is not an attrs class
2093+
20152094
[builtins fixtures/attr.pyi]
20162095

20172096
[case testEvolveTypeVarConstrained]

test-data/unit/fixtures/attr.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ class object:
99
class type: pass
1010
class bytes: pass
1111
class function: pass
12-
class bool: pass
1312
class float: pass
1413
class int:
1514
@overload
1615
def __init__(self, x: Union[str, bytes, int] = ...) -> None: ...
1716
@overload
1817
def __init__(self, x: Union[str, bytes], base: int) -> None: ...
18+
class bool(int): pass
1919
class complex:
2020
@overload
2121
def __init__(self, real: float = ..., im: float = ...) -> None: ...

0 commit comments

Comments
 (0)
0