8000 Fix subtyping between ParamSpecs (#15892) · python/mypy@b3d0937 · GitHub
[go: up one dir, main page]

Skip to content

Commit b3d0937

Browse files
authored
Fix subtyping between ParamSpecs (#15892)
Fixes #14169 Fixes #14168 Two sings here: * Actually check prefix when we should * `strict_concatenate` check should be off by default (IIUC it is not mandated by the PEP)
1 parent 76c16a4 commit b3d0937

File tree

5 files changed

+94
-16
lines changed

5 files changed

+94
-16
lines changed

mypy/expandtype.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,6 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
383383
t = t.expand_param_spec(repl)
384384
return t.copy_modified(
385385
arg_types=self.expand_types(t.arg_types),
386-
arg_kinds=t.arg_kinds,
387-
arg_names=t.arg_names,
388386
ret_type=t.ret_type.accept(self),
389387
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
390388
)
@@ -402,6 +400,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
402400
arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
403401
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
404402
ret_type=t.ret_type.accept(self),
403+
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
405404
)
406405

407406
var_arg = t.var_arg()

mypy/messages.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,9 +2116,11 @@ def report_protocol_problems(
21162116
return
21172117

21182118
# Report member type conflicts
2119-
conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj)
2119+
conflict_types = get_conflict_protocol_types(
2120+
subtype, supertype, class_obj=class_obj, options=self.options
2121+
)
21202122
if conflict_types and (
2121-
not is_subtype(subtype, erase_type(supertype))
2123+
not is_subtype(subtype, erase_type(supertype), options=self.options)
21222124
or not subtype.type.defn.type_vars
21232125
or not supertype.type.defn.type_vars
21242126
):
@@ -2780,7 +2782,11 @@ def [T <: int] f(self, x: int, y: T) -> None
27802782
slash = True
27812783

27822784
# If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list
2783-
if isinstance(tp.definition, FuncDef) and hasattr(tp.definition, "arguments"):
2785+
if (
2786+
isinstance(tp.definition, FuncDef)
2787+
and hasattr(tp.definition, "arguments")
2788+
and not tp.from_concatenate
2789+
):
27842790
definition_arg_names = [arg.variable.name for arg in tp.definition.arguments]
27852791
if (
27862792
len(definition_arg_names) > len(tp.arg_names)
@@ -2857,7 +2863,7 @@ def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str
28572863

28582864

28592865
def get_conflict_protocol_types(
2860-
left: Instance, right: Instance, class_obj: bool = False
2866+
left: Instance, right: Instance, class_obj: bool = False, options: Options | None = None
28612867
) -> list[tuple[str, Type, Type]]:
28622868
"""Find members that are defined in 'left' but have incompatible types.
28632869
Return them as a list of ('member', 'got', 'expected').
@@ -2872,9 +2878,9 @@ def get_conflict_protocol_types(
28722878
subtype = mypy.typeops.get_protocol_member(left, member, class_obj)
28732879
if not subtype:
28742880
continue
2875-
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
2881+
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True, options=options)
28762882
if IS_SETTABLE in get_member_flags(member, right):
2877-
is_compat = is_compat and is_subtype(supertype, subtype)
2883+
is_compat = is_compat and is_subtype(supertype, subtype, options=options)
28782884
if not is_compat:
28792885
conflicts.append((member, subtype, supertype))
28802886
return conflicts

mypy/subtypes.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def check_mixed(
600600
type_state.record_negative_subtype_cache_entry(self._subtype_kind, left, right)
601601
return nominal
602602
if right.type.is_protocol and is_protocol_implementation(
603-
left, right, proper_subtype=self.proper_subtype
603+
left, right, proper_subtype=self.proper_subtype, options=self.options
604604
):
605605
return True
606606
# We record negative cache entry here, and not in the protocol check like we do for
@@ -647,7 +647,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool:
647647
and right.id == left.id
648648
and right.flavor == left.flavor
649649
):
650-
return True
650+
return self._is_subtype(left.prefix, right.prefix)
651651
if isinstance(right, Parameters) and are_trivial_parameters(right):
652652
return True
653653
return self._is_subtype(left.upper_bound, self.right)
@@ -696,7 +696,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
696696
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
697697
strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate)
698698
if self.options
699-
else True,
699+
else False,
700700
)
701701
elif isinstance(right, Overloaded):
702702
return all(self._is_subtype(left, item) for item in right.items)
@@ -863,7 +863,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
863863
strict_concat = (
864864
(self.options.extra_checks or self.options.strict_concatenate)
865865
if self.options
866-
else True
866+
else False
867867
)
868868
if left_index not in matched_overloads and (
869869
is_callable_compatible(
@@ -1003,6 +1003,7 @@ def is_protocol_implementation(
10031003
proper_subtype: bool = False,
10041004
class_obj: bool = False,
10051005
skip: list[str] | None = None,
1006+
options: Options | None = None,
10061007
) -> bool:
10071008
"""Check whether 'left' implements the protocol 'right'.
10081009
@@ -1068,7 +1069,9 @@ def f(self) -> A: ...
10681069
# Nominal check currently ignores arg names
10691070
# NOTE: If we ever change this, be sure to also change the call to
10701071
# SubtypeVisitor.build_subtype_kind(...) down below.
1071-
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=ignore_names)
1072+
is_compat = is_subtype(
1073+
subtype, supertype, ignore_pos_arg_names=ignore_names, options=options
1074+
)
10721075
else:
10731076
is_compat = is_proper_subtype(subtype, supertype)
10741077
if not is_compat:
@@ -1080,7 +1083,7 @@ def f(self) -> A: ...
10801083
superflags = get_member_flags(member, right)
10811084
if IS_SETTABLE in superflags:
10821085
# Check opposite direction for settable attributes.
1083-
if not is_subtype(supertype, subtype):
1086+
if not is_subtype(supertype, subtype, options=options):
10841087
return False
10851088
if not class_obj:
10861089
if IS_SETTABLE not in superflags:
@@ -1479,7 +1482,7 @@ def are_parameters_compatible(
14791482
ignore_pos_arg_names: bool = False,
14801483
check_args_covariantly: bool = False,
14811484
allow_partial_overlap: bool = False,
1482-
strict_concatenate_check: bool = True,
1485+
strict_concatenate_check: bool = False,
14831486
) -> bool:
14841487
"""Helper function for is_callable_compatible, used for Parameter compatibility"""
14851488
if right.is_ellipsis_args:

test-data/unit/check-overloading.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6483,7 +6483,7 @@ P = ParamSpec("P")
64836483
R = TypeVar("R")
64846484

64856485
@overload
6486-
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ...
6486+
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
64876487
@overload
64886488
def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ...
64896489
def func(x: Callable[..., R]) -> Callable[..., R]: ...

test-data/unit/check-parameter-specification.test

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,3 +1576,73 @@ def test() -> None: ...
15761576
# TODO: avoid this error, although it may be non-trivial.
15771577
apply(apply, test) # E: Argument 2 to "apply" has incompatible type "Callable[[], None]"; expected "Callable[P, T]"
15781578
[builtins fixtures/paramspec.pyi]
1579+
1580+
[case testParamSpecPrefixSubtypingGenericInvalid]
1581+
from typing import Generic
1582+
from typing_extensions import ParamSpec, Concatenate
1583+
1584+
P = ParamSpec("P")
1585+
1586+
class A(Generic[P]):
1587+
def foo(self, *args: P.args, **kwargs: P.kwargs):
1588+
...
1589+
1590+
def bar(b: A[P]) -> A[Concatenate[int, P]]:
1591+
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
1592+
[builtins fixtures/paramspec.pyi]
1593+
1594+
[case testParamSpecPrefixSubtypingProtocolInvalid]
1595+
from typing import Protocol
1596+
from typing_extensions import ParamSpec, Concatenate
1597+
1598+
P = ParamSpec("P")
1599+
1600+
class A(Protocol[P]):
1601+
def foo(self, *args: P.args, **kwargs: P.kwargs):
1602+
...
1603+
1604+
def bar(b: A[P]) -> A[Concatenate[int, P]]:
1605+
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
1606+
[builtins fixtures/paramspec.pyi]
1607+
1608+
[case testParamSpecPrefixSubtypingValidNonStrict]
1609+
from typing import Protocol
1610+
from typing_extensions import ParamSpec, Concatenate
1611+
1612+
P = ParamSpec("P")
1613+
1614+
class A(Protocol[P]):
1615+
def foo(self, a: int, *args: P.args, **kwargs: P.kwargs):
1616+
...
1617+
1618+
class B(Protocol[P]):
1619+
def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs):
1620+
...
1621+
1622+
def bar(b: B[P]) -> A[Concatenate[int, P]]:
1623+
return b
1624+
[builtins fixtures/paramspec.pyi]
1625+
1626+
[case testParamSpecPrefixSubtypingInvalidStrict]
1627+
# flags: --extra-checks
1628+
from typing import Protocol
1629+
from typing_extensions import ParamSpec, Concatenate
1630+
1631+
P = ParamSpec("P")
1632+
1633+
class A(Protocol[P]):
1634+
def foo(self, a: int, *args: P.args, **kwargs: P.kwargs):
1635+
...
1636+
1637+
class B(Protocol[P]):
1638+
def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs):
1639+
...
1640+
1641+
def bar(b: B[P]) -> A[Concatenate[int, P]]:
1642+
return b # E: Incompatible return value type (got "B[P]", expected "A[[int, **P]]") \
1643+
# N: Following member(s) of "B[P]" have conflicts: \
1644+
# N: Expected: \
1645+
# N: def foo(self, a: int, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
1646+
# N: Got: \
1647+
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
1648+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)
0