8000 Fix missing type store for overloads (#16803) · python/mypy@5bf7742 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5bf7742

Browse files
authored
Fix missing type store for overloads (#16803)
Add missing call to store inferred types if an overload match is found early. All other code paths already do that. ### Some background on the issue this fixes I recently saw an interesting pattern in `aiohttp` to type values in an `dict[str, Any]` by subclassing dict. ```py T = TypeVar("T") U = TypeVar("U") class Key(Generic[T]): ... class CustomDict(dict[Key[Any] | str, Any]): @overload # type: ignore[override] def get(self, __key: Key[T]) -> T | None: ... @overload def get(self, __key: Key[T], __default: U) -> T | U: ... @overload def get(self, __key: str) -> Any | None: ... @overload def get(self, __key: str, __default: Any) -> Any: ... def get(self, __key: Key[Any] | str, __default: Any = None) -> Any: """Forward to super implementation.""" return super().get(__key, __default) # overloads for __getitem__, setdefault, pop # ... @overload # type: ignore[override] def __setitem__(self, key: Key[T], value: T) -> None: ... @overload def __setitem__(self, key: str, value: Any) -> None: ... def __setitem__(self, key: Key[Any] | str, value: Any) -> None: """Forward to super implementation.""" return super().__setitem__(key, value) ``` With the exception that these overloads aren't technically compatible with the supertype, they do the job. ```py d = CustomDict() key = Key[int]() other_key = "other" assert_type(d.get(key), int | None) assert_type(d.get("other"), Any | None) ``` The issue exists for the `__setitem__` case. Without this PR the following would create an issue. Here `var` would be inferred as `dict[Never, Never]`, even though it should be `dict[Any, Any]` which is the case for non-subclassed dicts. ```py def a2(d: CustomDict) -> None: if (var := d.get("arg")) is None: var = d["arg"] = {} reveal_type(var) ```
1 parent 55247c4 commit 5bf7742

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

mypy/checkexpr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,6 +2825,7 @@ def infer_overload_return_type(
28252825
# Return early if possible; otherwise record info, so we can
28262826
# check for ambiguity due to 'Any' below.
28272827
if not args_contain_any:
2828+
self.chk.store_types(m)
28282829
return ret_type, infer_type
28292830
p_infer_type = get_proper_type(infer_type)
28302831
if isinstance(p_infer_type, CallableType):

test-data/unit/check-generics.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,30 @@ if int():
14801480
b = f(b)
14811481
[builtins fixtures/list.pyi]
14821482

1483+
[case testGenericDictWithOverload]
1484+
from typing import Dict, Generic, TypeVar, Any, overload
1485+
T = TypeVar("T")
1486+
1487+
class Key(Generic[T]): ...
1488+
class CustomDict(dict):
1489+
@overload # type: ignore[override]
1490+
def __setitem__(self, key: Key[T], value: T) -> None: ...
1491+
@overload
1492+
def __setitem__(self, key: str, value: Any) -> None: ...
1493+
def __setitem__(self, key, value):
1494+
return super().__setitem__(key, value)
1495+
1496+
def a1(d: Dict[str, Any]) -> None:
1497+
if (var := d.get("arg")) is None:
1498+
var = d["arg"] = {}
1499+
reveal_type(var) # N: Revealed type is "builtins.dict[Any, Any]"
1500+
1501+
def a2(d: CustomDict) -> None:
1502+
if (var := d.get("arg")) is None:
1503+
var = d["arg"] = {}
1504+
reveal_type(var) # N: Revealed type is "builtins.dict[Any, Any]"
1505+
[builtins fixtures/dict.pyi]
1506+
14831507

14841508
-- Type variable scoping
14851509
-- ---------------------

test-data/unit/typexport-basic.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,27 @@ LambdaExpr(10) : def (x: builtins.int) -> builtins.int
12361236
LambdaExpr(12) : def (y: builtins.str) -> builtins.str
12371237
LambdaExpr(13) : def (x: builtins.str) -> builtins.str
12381238

1239+
[case testExportOverloadArgTypeDict]
1240+
## DictExpr
1241+
from typing import TypeVar, Generic, Any, overload, Dict
1242+
T = TypeVar("T")
1243+
class Key(Generic[T]): ...
1244+
@overload
1245+
def f(x: Key[T], y: T) -> T: ...
1246+
@overload
1247+
def f(x: int, y: Any) -> Any: ...
1248+
def f(x, y): ...
1249+
d: Dict = {}
1250+
d.get(
1251+
"", {})
1252+
f(
1253+
2, {})
1254+
[builtins fixtures/dict.pyi]
1255+
[out]
1256+
DictExpr(10) : builtins.dict[Any, Any]
1257+
DictExpr(12) : builtins.dict[Any, Any]
1258+
DictExpr(14) : builtins.dict[Any, Any]
1259+
12391260
-- TODO
12401261
--
12411262
-- test expressions

0 commit comments

Comments
 (0)
0