Support type inference for defaultdict() (#8167) · python/mypy@c1cd529 · GitHub
[go: up one dir, main page]

Skip to content

Commit c1cd529

Browse files
authored
Support type inference for defaultdict() (#8167)
This allows inferring the type of `x`, for example:

```
from collections import defaultdict

x = defaultdict(list)  # defaultdict[str, List[int]]
x['foo'].append(1)
```

The implementation is not pretty and we have probably reached about the maximum reasonable level of special casing in type inference now. There is a hack to work around the problem with leaking type variable types in nested generic calls (I think). This will break some (likely very rare) use cases.
1 parent 5f16416 commit c1cd529

File tree

8 files changed

+250
-51
lines changed

8 files changed

+250
-51
lines changed

mypy/checker.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,14 +2813,26 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
28132813
partial_type = PartialType(None, name)
28142814
elif isinstance(init_type, Instance):
28152815
fullname = init_type.type.fullname
2816-
if (isinstance(lvalue, (NameExpr, MemberExpr)) and
2816+
is_ref = isinstance(lvalue, RefExpr)
2817+
if (is_ref and
28172818
(fullname == 'builtins.list' or
28182819
fullname == 'builtins.set' or
28192820
fullname == 'builtins.dict' or
28202821
fullname == 'collections.OrderedDict') and
28212822
all(isinstance(t, (NoneType, UninhabitedType))
28222823
for t in get_proper_types(init_type.args))):
28232824
partial_type = PartialType(init_type.type, name)
2825+
elif is_ref and fullname == 'collections.defaultdict':
2826+
arg0 = get_proper_type(init_type.args[0])
2827+
arg1 = get_proper_type(init_type.args[1])
2828+
if (isinstance(arg0, (NoneType, UninhabitedType)) and
2829+
isinstance(arg1, Instance) and
2830+
self.is_valid_defaultdict_partial_value_type(arg1)):
2831+
# Erase type argument, if one exists (this fills in Anys)
2832+
arg1 = self.named_type(arg1.type.fullname)
2833+
partial_type = PartialType(init_type.type, name, arg1)
2834+
else:
2835+
return False
28242836
else:
28252837
return False
28262838
else:
@@ -2829,6 +2841,28 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
28292841
self.partial_types[-1].map[name] = lvalue
28302842
return True
28312843

2844+
def is_valid_defaultdict_partial_value_type(self, t: Instance) -> bool:
2845+
"""Check if t can be used as the basis for a partial defaultdict value type.
2846+
2847+
Examples:
2848+
2849+
* t is 'int' --> True
2850+
* t is 'list[<nothing>]' --> True
2851+
* t is 'dict[...]' --> False (only generic types with a single type
2852+
argument supported)
2853+
"""
2854+
if len(t.args) == 0:
2855+
return True
2856+
if len(t.args) == 1:
2857+
arg = get_proper_type(t.args[0])
2858+
# TODO: This is too permissive -- we only allow TypeVarType since
2859+
# they leak in cases like defaultdict(list) due to a bug.
2860+
# This can result in incorrect types being inferred, but only
2861+
# in rare cases.
2862+
if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)):
2863+
return True
2864+
return False
2865+
28322866
def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
28332867
"""Store inferred variable type.
28342868
@@ -3018,16 +3052,21 @@ def try_infer_partial_type_from_indexed_assignment(
30183052
if partial_types is None:
30193053
return
30203054
typename = type_type.fullname
3021-
if typename == 'builtins.dict' or typename == 'collections.OrderedDict':
3055+
if (typename == 'builtins.dict'
3056+
or typename == 'collections.OrderedDict'
3057+
or typename == 'collections.defaultdict'):
30223058
# TODO: Don't infer things twice.
30233059
key_type = self.expr_checker.accept(lvalue.index)
30243060
value_type = self.expr_checker.accept(rvalue)
30253061
if (is_valid_inferred_type(key_type) and
3026-
is_valid_inferred_type(value_type)):
3027-
if not self.current_node_deferred:
3028-
var.type = self.named_generic_type(typename,
3029-
[key_type, value_type])
3030-
del partial_types[var]
3062+
is_valid_inferred_type(value_type) and
3063+
not self.current_node_deferred and
3064+
not (typename == 'collections.defaultdict' and
3065+
var.type.value_type is not None and
3066+
not is_equivalent(value_type, var.type.value_type))):
3067+
var.type = self.named_generic_type(typename,
3068+
[key_type, value_type])
3069+
del partial_types[var]
30313070

30323071
def visit_expression_stmt(self, s: ExpressionStmt) -> None:
30333072
self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True)

mypy/checkexpr.py

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -567,42 +567,91 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
567567
} # type: ClassVar[Dict[str, Dict[str, List[str]]]]
568568

569569
def try_infer_partial_type(self, e: CallExpr) -> None:
570-
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
571-
var = e.callee.expr.node
572-
if var is None and isinstance(e.callee.expr, MemberExpr):
573-
var = self.get_partial_self_var(e.callee.expr)
574-
if not isinstance(var, Var):
570+
"""Try to make partial type precise from a call."""
571+
if not isinstance(e.callee, MemberExpr):
572+
return
573+
callee = e.callee
574+
if isinstance(callee.expr, RefExpr):
575+
# Call a method with a RefExpr callee, such as 'x.method(...)'.
576+
ret = self.get_partial_var(callee.expr)
577+
if ret is None:
575578
return
576-
partial_types = self.chk.find_partial_types(var)
577-
if partial_types is not None and not self.chk.current_node_deferred:
578-
partial_type = var.type
579-
if (partial_type is None or
580-
not isinstance(partial_type, PartialType) or
581-
partial_type.type is None):
582-
# A partial None type -> can't infer anything.
583-
return
584-
typename = partial_type.type.fullname
585-
methodname = e.callee.name
586-
# Sometimes we can infer a full type for a partial List, Dict or Set type.
587-
# TODO: Don't infer argument expression twice.
588-
if (typename in self.item_args and methodname in self.item_args[typename]
589-
and e.arg_kinds == [ARG_POS]):
590-
item_type = self.accept(e.args[0])
591-
if mypy.checker.is_valid_inferred_type(item_type):
592-
var.type = self.chk.named_generic_type(typename, [item_type])
593-
del partial_types[var]
594-
elif (typename in self.container_args
595-
and methodname in self.container_args[typename]
596-
and e.arg_kinds == [ARG_POS]):
597-
arg_type = get_proper_type(self.accept(e.args[0]))
598-
if isinstance(arg_type, Instance):
599-
arg_typename = arg_type.type.fullname
600-
if arg_typename in self.container_args[typename][methodname]:
601-
if all(mypy.checker.is_valid_inferred_type(item_type)
602-
for item_type in arg_type.args):
603-
var.type = self.chk.named_generic_type(typename,
604-
list(arg_type.args))
605-
del partial_types[var]
579+
var, partial_types = ret
580+
typ = self.try_infer_partial_value_type_from_call(e, callee.name, var)
581+
if typ is not None:
582+
var.type = typ
583+
del partial_types[var]
584+
elif isinstance(callee.expr, IndexExpr) and isinstance(callee.expr.base, RefExpr):
585+
# Call 'x[y].method(...)'; may infer type of 'x' if it's a partial defaultdict.
586+
if callee.expr.analyzed is not None:
587+
return # A special form
588+
base = callee.expr.base
589+
index = callee.expr.index
590+
ret = self.get_partial_var(base)
591+
if ret is None:
592+
return
593+
var, partial_types = ret
594+
partial_type = get_partial_instance_type(var.type)
595+
if partial_type is None or partial_type.value_type is None:
596+
return
597+
value_type = self.try_infer_partial_value_type_from_call(e, callee.name, var)
598+
if value_type is not None:
599+
# Infer key type.
600+
key_type = self.accept(index)
601+
if mypy.checker.is_valid_inferred_type(key_type):
602+
# Store inferred partial type.
603+
assert partial_type.type is not None
604+
typename = partial_type.type.fullname
605+
var.type = self.chk.named_generic_type(typename,
606+
[key_type, value_type])
607+
del partial_types[var]
608+
609+
def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]:
610+
var = ref.node
611+
if var is None and isinstance(ref, MemberExpr):
612+
var = self.get_partial_self_var(ref)
613+
if not isinstance(var, Var):
614+
return None
615+
partial_types = self.chk.find_partial_types(var)
616+
if partial_types is None:
617+
return None
618+
return var, partial_types
619+
620+
def try_infer_partial_value_type_from_call(
621+
self,
622+
e: CallExpr,
623+
methodname: str,
624+
var: Var) -> Optional[Instance]:
625+
"""Try to make partial type precise from a call such as 'x.append(y)'."""
626+
if self.chk.current_node_deferred:
627+
return None
628+
partial_type = get_partial_instance_type(var.type)
629+
if partial_type is None:
630+
return None
631+
if partial_type.value_type:
632+
typename = partial_type.value_type.type.fullname
633+
else:
634+
assert partial_type.type is not None
635+
typename = partial_type.type.fullname
636+
# Sometimes we can infer a full type for a partial List, Dict or Set type.
637+
# TODO: Don't infer argument expression twice.
638+
if (typename in self.item_args and methodname in self.item_args[typename]
639+
and e.arg_kinds == [ARG_POS]):
640+
item_type = self.accept(e.args[0])
641+
if mypy.checker.is_valid_inferred_type(item_type):
642+
return self.chk.named_generic_type(typename, [item_type])
643+
elif (typename in self.container_args
644+
and methodname in self.container_args[typename]
645+
and e.arg_kinds == [ARG_POS]):
646+
arg_type = get_proper_type(self.accept(e.args[0]))
647+
if isinstance(arg_type, Instance):
648+
arg_typename = arg_type.type.fullname
649+
if arg_typename in self.container_args[typename][methodname]:
650+
if all(mypy.checker.is_valid_inferred_type(item_type)
651+
for item_type in arg_type.args):
652+
return self.chk.named_generic_type(typename,
653+
list(arg_type.args))
654+
return None
606655

607656
def apply_function_plugin(self,
608657
callee: CallableType,
@@ -4299,3 +4348,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
42994348
short_name in nodes.op_methods.values() or
43004349
short_name in nodes.reverse_op_methods.values() or
43014350
short_name in nodes.unary_op_methods.values())
4351+
4352+
4353+
def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
4354+
if t is None or not isinstance(t, PartialType) or t.type is None:
4355+
return None
4356+
return t

mypy/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1763,13 +1763,18 @@ class PartialType(ProperType):
17631763
# None for the 'None' partial type; otherwise a generic class
17641764
type = None # type: Optional[mypy.nodes.TypeInfo]
17651765
var = None # type: mypy.nodes.Var
1766+
# For partial defaultdict[K, V], the type V (K is unknown). If V is generic,
1767+
# the type argument is Any and will be replaced later.
1768+
value_type = None # type: Optional[Instance]
17661769

17671770
def __init__(self,
17681771
type: 'Optional[mypy.nodes.TypeInfo]',
1769-
var: 'mypy.nodes.Var') -> None:
1772+
var: 'mypy.nodes.Var',
1773+
value_type: 'Optional[Instance]' = None) -> None:
17701774
super().__init__()
17711775
self.type = type
17721776
self.var = var
1777+
self.value_type = value_type
17731778

17741779
def accept(self, visitor: 'TypeVisitor[T]') -> T:
17751780
return visitor.visit_partial_type(self)

mypyc/test-data/fixtures/ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class ellipsis: pass
2626
# Primitive types are special in generated code.
2727

2828
class int:
29+
@overload
30+
def __init__(self) -> None: pass
31+
@overload
2932
def __init__(self, x: object, base: int = 10) -> None: pass
3033
def __add__(self, n: int) -> int: pass
3134
def __sub__(self, n: int) -> int: pass

test-data/unit/check-inference.test

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,3 +2976,97 @@ x: Optional[str]
29762976
y = filter(None, [x])
29772977
reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]'
29782978
[builtins fixtures/list.pyi]
2979+
2980+
[case testPartialDefaultDict]
2981+
from collections import defaultdict
2982+
x = defaultdict(int)
2983+
x[''] = 1
2984+
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
2985+
2986+
y = defaultdict(int) # E: Need type annotation for 'y'
2987+
2988+
z = defaultdict(int) # E: Need type annotation for 'z'
2989+
z[''] = ''
2990+
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
2991+
[builtins fixtures/dict.pyi]
2992+
2993+
[case testPartialDefaultDictInconsistentValueTypes]
2994+
from collections import defaultdict
2995+
a = defaultdict(int) # E: Need type annotation for 'a'
2996+
a[''] = ''
2997+
a[''] = 1
2998+
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
2999+
[builtins fixtures/dict.pyi]
3000+
3001+
[case testPartialDefaultDictListValue]
3002+
# flags: --no-strict-optional
3003+
from collections import defaultdict
3004+
a = defaultdict(list)
3005+
a['x'].append(1)
3006+
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
3007+
3008+
b = defaultdict(lambda: [])
3009+
b[1].append('x')
3010+
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
3011+
[builtins fixtures/dict.pyi]
3012+
3013+
[case testPartialDefaultDictListValueStrictOptional]
3014+
# flags: --strict-optional
3015+
from collections import defaultdict
3016+
a = defaultdict(list)
3017+
a['x'].append(1)
3018+
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
3019+
3020+
b = defaultdict(lambda: [])
3021+
b[1].append('x')
3022+
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
3023+
[builtins fixtures/dict.pyi]
3024+
3025+
[case testPartialDefaultDictSpecialCases]
3026+
from collections import defaultdict
3027+
class A:
3028+
def f(self) -> None:
3029+
self.x = defaultdict(list)
3030+
self.x['x'].append(1)
3031+
reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
3032+
self.y = defaultdict(list) # E: Need type annotation for 'y'
3033+
s = self
3034+
s.y['x'].append(1)
3035+
3036+
x = {} # E: Need type annotation for 'x' (hint: "x: Dict[<type>, <type>] = ...")
3037+
x['x'].append(1)
3038+
3039+
y = defaultdict(list) # E: Need type annotation for 'y'
3040+
y[[]].append(1)
3041+
[builtins fixtures/dict.pyi]
3042+
3043+
[case testPartialDefaultDictSpecialCases2]
3044+
from collections import defaultdict
3045+
3046+
x = defaultdict(lambda: [1]) # E: Need type annotation for 'x'
3047+
x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
3048+
reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]'
3049+
3050+
xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx'
3051+
xx[1]['z'] = 3
3052+
reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]'
3053+
3054+
y = defaultdict(dict) # E: Need type annotation for 'y'
3055+
y['x'][1] = [3]
3056+
3057+
z = defaultdict(int) # E: Need type annotation for 'z'
3058+
z[1].append('')
3059+
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
3060+
[builtins fixtures/dict.pyi]
3061+
3062+
[case testPartialDefaultDictSpecialCase3]
3063+
from collections import defaultdict
3064+
3065+
x = defaultdict(list)
3066+
x['a'] = [1, 2, 3]
3067+
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int*]]'
3068+
3069+
y = defaultdict(list) # E: Need type annotation for 'y'
3070+
y['a'] = []
3071+
reveal_type(y) # N: Revealed type is 'collections.defaultdict[Any, Any]'
3072+
[builtins fixtures/dict.pyi]

test-data/unit/fixtures/dict.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class list(Sequence[T]): # needed by some test cases
4242
def __iter__(self) -> Iterator[T]: pass
4343
def __mul__(self, x: int) -> list[T]: pass
4444
def __contains__(self, item: object) -> bool: pass
45+
def append(self, item: T) -> None: pass
4546

4647
class tuple(Generic[T]): pass
4748
class function: pass
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Iterable, Union, Optional, Dict, TypeVar
1+
from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable
22

33
def namedtuple(
44
typename: str,
@@ -10,8 +10,10 @@ def namedtuple(
1010
defaults: Optional[Iterable[Any]] = ...
1111
) -> Any: ...
1212

13-
K = TypeVar('K')
14-
V = TypeVar('V')
13+
KT = TypeVar('KT')
14+
VT = TypeVar('VT')
1515

16-
class OrderedDict(Dict[K, V]):
17-
def __setitem__(self, k: K, v: V) -> None: ...
16+
class OrderedDict(Dict[KT, VT]): ...
17+
18+
class defaultdict(Dict[KT, VT]):
19+
def __init__(self, default_factory: Optional[Callable[[], VT]]) -> None: ...

test-data/unit/python2eval.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,11 @@ if MYPY:
420420
x = b'abc'
421421
[out]
422422

423-
[case testNestedGenericFailedInference]
423+
[case testDefaultDictInference]
424424
from collections import defaultdict
425425
def foo() -> None:
426-
x = defaultdict(list) # type: ignore
426+
x = defaultdict(list)
427427
x['lol'].append(10)
428428
reveal_type(x)
429429
[out]
430-
_testNestedGenericFailedInference.py:5: note: Revealed type is 'collections.defaultdict[Any, builtins.list[Any]]'
430+
_testDefaultDictInference.py:5: note: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

0 commit comments

Comments
 (0)
0