-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Support type inference for defaultdict() #8167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
This allows inferring type of `x`, for example: ``` from collections import defaultdict x = defaultdict(list) # defaultdict[str, List[int]] x['foo'].append(1) ``` The implemention is not pretty and we have probably reached about the maximum reasonable level of special casing in type inference now. There is a hack to work around the problem with leaking type variable types in nested generics calls (I think). This will break some (likely very rare) use cases.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2813,14 +2813,24 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool | |
partial_type = PartialType(None, name) | ||
elif isinstance(init_type, Instance): | ||
fullname = init_type.type.fullname | ||
if (isinstance(lvalue, (NameExpr, MemberExpr)) and | ||
is_ref = isinstance(lvalue, (NameExpr, MemberExpr)) | ||
if (is_ref and | ||
|
||
fullname == 'builtins.set' or | ||
fullname == 'builtins.dict' or | ||
fullname == 'collections.OrderedDict') and | ||
all(isinstance(t, (NoneType, UninhabitedType)) | ||
for t in get_proper_types(init_type.args))): | ||
partial_type = PartialType(init_type.type, name) | ||
elif is_ref and fullname == 'collections.defaultdict': | ||
arg0 = get_proper_type(init_type.args[0]) | ||
arg1 = get_proper_type(init_type.args[1]) | ||
if (isinstance(arg0, (NoneType, UninhabitedType)) and | ||
isinstance(arg1, Instance) and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be cleaner to move this check to the below helper. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. It's here since we rely on the narrowed down type below, so we'd need a type check anyway. |
||
self.is_valid_ordereddict_partial_value_type(arg1)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This name is really weird:
For example There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, I'll rename it. |
||
partial_type = PartialType(init_type.type, name, arg1) | ||
else: | ||
return False | ||
else: | ||
return False | ||
else: | ||
|
@@ -2829,6 +2839,15 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool | |
self.partial_types[-1].map[name] = lvalue | ||
return True | ||
|
||
def is_valid_ordereddict_partial_value_type(self, t: Instance) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a short docstring with an example (also the name needs fixing, see above). |
||
if len(t.args) == 0: | ||
return True | ||
if len(t.args) == 1: | ||
arg = get_proper_type(t.args[0]) | ||
if isinstance(arg, (TypeVarType, UninhabitedType)): # TODO: This is too permissive | ||
return True | ||
return False | ||
|
||
def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None: | ||
"""Store inferred variable type. | ||
|
||
|
@@ -3018,16 +3037,21 @@ def try_infer_partial_type_from_indexed_assignment( | |
if partial_types is None: | ||
return | ||
typename = type_type.fullname | ||
if typename == 'builtins.dict' or typename == 'collections.OrderedDict': | ||
if (typename == 'builtins.dict' | ||
or typename == 'collections.OrderedDict' | ||
or typename == 'collections.defaultdict'): | ||
# TODO: Don't infer things twice. | ||
key_type = self.expr_checker.accept(lvalue.index) | ||
value_type = self.expr_checker.accept(rvalue) | ||
if (is_valid_inferred_type(key_type) and | ||
is_valid_inferred_type(value_type)): | ||
if not self.current_node_deferred: | ||
var.type = self.named_generic_type(typename, | ||
[key_type, value_type]) | ||
del partial_types[var] | ||
is_valid_inferred_type(value_type) and | ||
not self.current_node_deferred and | ||
not (typename == 'collections.defaultdict' and | ||
var.type.value_type is not None and | ||
not is_equivalent(value_type, var.type.value_type))): | ||
var.type = self.named_generic_type(typename, | ||
[key_type, value_type]) | ||
del partial_types[var] | ||
|
||
8000 | def visit_expression_stmt(self, s: ExpressionStmt) -> None: | |
self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2976,3 +2976,72 @@ x: Optional[str] | |
y = filter(None, [x]) | ||
reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]' | ||
[builtins fixtures/list.pyi] | ||
|
||
[case testPartialDefaultDict] | ||
from collections import defaultdict | ||
x = defaultdict(int) | ||
x[''] = 1 | ||
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a similar test case where value type is generic, e.g.: x = defaultdict(list) # No error
x['a'] = [1, 2, 3] and x = defaultdict(list) # Error here
x['a'] = [] |
||
|
||
y = defaultdict(int) # E: Need type annotation for 'y' | ||
|
||
z = defaultdict(int) # E: Need type annotation for 'z' | ||
z[''] = '' | ||
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]' | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testPartialDefaultDictInconsistentValueTypes] | ||
from collections import defaultdict | ||
a = defaultdict(int) # E: Need type annotation for 'a' | ||
a[''] = '' | ||
a[''] = 1 | ||
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]' | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testPartialDefaultDictListValue] | ||
from collections import defaultdict | ||
a = defaultdict(list) | ||
a['x'].append(1) | ||
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]' | ||
|
||
b = defaultdict(lambda: []) | ||
b[1].append('x') | ||
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]' | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testPartialDefaultDictSpecialCases] | ||
from collections import defaultdict | ||
class A: | ||
def f(self) -> None: | ||
self.x = defaultdict(list) | ||
self.x['x'].append(1) | ||
reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]' | ||
self.y = defaultdict(list) # E: Need type annotation for 'y' | ||
s = self | ||
s.y['x'].append(1) | ||
|
||
x = {} # E: Need type annotation for 'x' (hint: "x: Dict[<type>, <type>] = ...") | ||
x['x'].append(1) | ||
|
||
y = defaultdict(list) # E: Need type annotation for 'y' | ||
y[[]].append(1) | ||
[builtins fixtures/dict.pyi] | ||
|
||
[case testPartialDefaultDictSpecialCases2] | ||
from collections import defaultdict | ||
|
||
x = defaultdict(lambda: [1]) # E: Need type annotation for 'x' | ||
x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int" | ||
reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]' | ||
|
||
xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx' | ||
xx[1]['z'] = 3 | ||
reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]' | ||
|
||
y = defaultdict(dict) # E: Need type annotation for 'y' | ||
y['x'][1] = [3] | ||
|
||
z = defaultdict(int) # E: Need type annotation for 'z' | ||
z[1].append('') | ||
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]' | ||
[builtins fixtures/dict.pyi] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While you are at it maybe change this to
RefExpr
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.