8000 Support type inference for defaultdict() by JukkaL · Pull Request #8167 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Support type inference for defaultdict() #8167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Support type inference for defaultdict()
This allows inferring type of `x`, for example:

```
from collections import defaultdict

x = defaultdict(list)  # defaultdict[str, List[int]]
x['foo'].append(1)
```

The implemention is not pretty and we have probably
reached about the maximum reasonable level of special
casing in type inference now.

There is a hack to work around the problem with leaking
type variable types in nested generics calls (I think).
This will break some (likely very rare) use cases.
  • Loading branch information
JukkaL committed Dec 18, 2019
commit f7c2236f34406d779dfa7645136bbb49e513869d
38 changes: 31 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2813,14 +2813,24 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
partial_type = PartialType(None, name)
elif isinstance(init_type, Instance):
fullname = init_type.type.fullname
if (isinstance(lvalue, (NameExpr, MemberExpr)) and
is_ref = isinstance(lvalue, (NameExpr, MemberExpr))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you are at it maybe change this to RefExpr?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

if (is_ref and
(fullname == 'builtins.list' or
fullname == 'builtins.set' or
fullname == 'builtins.dict' or
fullname == 'collections.OrderedDict') and
all(isinstance(t, (NoneType, UninhabitedType))
for t in get_proper_types(init_type.args))):
partial_type = PartialType(init_type.type, name)
elif is_ref and fullname == 'collections.defaultdict':
arg0 = get_proper_type(init_type.args[0])
arg1 = get_proper_type(init_type.args[1])
if (isinstance(arg0, (NoneType, UninhabitedType)) and
isinstance(arg1, Instance) and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be cleaner to move this check to the below helper.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. It's here since we rely on the narrowed down type below, so we'd need a type check anyway.

self.is_valid_ordereddict_partial_value_type(arg1)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is really weird:

  • Why ordereddict and not defaultdict?
  • I don't think it should be called valid, maybe empty?

For example is_empty_defaultdict_value_type().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I'll rename it.

partial_type = PartialType(init_type.type, name, arg1)
else:
return False
else:
return False
else:
Expand All @@ -2829,6 +2839,15 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
self.partial_types[-1].map[name] = lvalue
return True

def is_valid_ordereddict_partial_value_type(self, t: Instance) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a short docstring with an example (also the name needs fixing, see above).

if len(t.args) == 0:
return True
if len(t.args) == 1:
arg = get_proper_type(t.args[0])
if isinstance(arg, (TypeVarType, UninhabitedType)): # TODO: This is too permissive
return True
return False

def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
"""Store inferred variable type.

Expand Down Expand Up @@ -3018,16 +3037,21 @@ def try_infer_partial_type_from_indexed_assignment(
if partial_types is None:
return
typename = type_type.fullname
if typename == 'builtins.dict' or typename == 'collections.OrderedDict':
if (typename == 'builtins.dict'
or typename == 'collections.OrderedDict'
or typename == 'collections.defaultdict'):
# TODO: Don't infer things twice.
key_type = self.expr_checker.accept(lvalue.index)
value_type = self.expr_checker.accept(rvalue)
if (is_valid_inferred_type(key_type) and
is_valid_inferred_type(value_type)):
if not self.current_node_deferred:
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]
is_valid_inferred_type(value_type) and
not self.current_node_deferred and
not (typename == 'collections.defaultdict' and
var.type.value_type is not None and
not is_equivalent(value_type, var.type.value_type))):
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

8000 def visit_expression_stmt(self, s: ExpressionStmt) -> None:
self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True)
Expand Down
125 changes: 90 additions & 35 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,42 +567,91 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
} # type: ClassVar[Dict[str, Dict[str, List[str]]]]

def try_infer_partial_type(self, e: CallExpr) -> None:
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
var = e.callee.expr.node
if var is None and isinstance(e.callee.expr, MemberExpr):
var = self.get_partial_self_var(e.callee.expr)
if not isinstance(var, Var):
"""Try to make partial type precise from a call."""
if not isinstance(e.callee, MemberExpr):
return
callee = e.callee
if isinstance(callee.expr, RefExpr):
# Call a method with a RefExpr callee, such as 'x.method(...)'.
ret = self.get_partial_var(callee.expr)
if ret is None:
return
partial_types = self.chk.find_partial_types(var)
if partial_types is not None and not self.chk.current_node_deferred:
partial_type = var.type
if (partial_type is None or
not isinstance(partial_type, PartialType) or
partial_type.type is None):
# A partial None type -> can't infer anything.
return
typename = partial_type.type.fullname
methodname = e.callee.name
# Sometimes we can infer a full type for a partial List, Dict or Set type.
# TODO: Don't infer argument expression twice.
if (typename in self.item_args and methodname in self.item_args[typename]
and e.arg_kinds == [ARG_POS]):
item_type = self.accept(e.args[0])
if mypy.checker.is_valid_inferred_type(item_type):
var.type = self.chk.named_generic_type(typename, [item_type])
del partial_types[var]
elif (typename in self.container_args
and methodname in self.container_args[typename]
and e.arg_kinds == [ARG_POS]):
arg_type = get_proper_type(self.accept(e.args[0]))
if isinstance(arg_type, Instance):
arg_typename = arg_type.type.fullname
if arg_typename in self.container_args[typename][methodname]:
if all(mypy.checker.is_valid_inferred_type(item_type)
for item_type in arg_type.args):
var.type = self.chk.named_generic_type(typename,
list(arg_type.args))
del partial_types[var]
var, partial_types = ret
typ = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if typ is not None:
var.type = typ
del partial_types[var]
elif isinstance(callee.expr, IndexExpr) and isinstance(callee.expr.base, RefExpr):
# Call 'x[y].method(...)'; may infer type of 'x' if it's a partial defaultdict.
if callee.expr.analyzed is not None:
return # A special form
base = callee.expr.base
index = callee.expr.index
ret = self.get_partial_var(base)
if ret is None:
return
var, partial_types = ret
partial_type = get_partial_instance_type(var.type)
if partial_type is None or partial_type.value_type is None:
return
value_type = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if value_type is not None:
# Infer key type.
key_type = self.accept(index)
if mypy.checker.is_valid_inferred_type(key_type):
# Store inferred partial type.
assert partial_type.type is not None
typename = partial_type.type.fullname
var.type = self.chk.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]:
var = ref.node
if var is None and isinstance(ref, MemberExpr):
var = self.get_partial_self_var(ref)
if not isinstance(var, Var):
return None
partial_types = self.chk.find_partial_types(var)
if partial_types is None:
return None
return var, partial_types

def try_infer_partial_value_type_from_call(
self,
e: CallExpr,
methodname: str,
var: Var) -> Optional[Instance]:
"""Try to make partial type precise from a call such as 'x.append(y)'."""
if self.chk.current_node_deferred:
return None
partial_type = get_partial_instance_type(var.type)
if partial_type is None:
return None
if partial_type.value_type:
typename = partial_type.value_type.type.fullname
else:
assert partial_type.type is not None
typename = partial_type.type.fullname
# Sometimes we can infer a full type for a partial List, Dict or Set type.
# TODO: Don't infer argument expression twice.
if (typename in self.item_args and methodname in self.item_args[typename]
and e.arg_kinds == [ARG_POS]):
item_type = self.accept(e.args[0])
if mypy.checker.is_valid_inferred_type(item_type):
return self.chk.named_generic_type(typename, [item_type])
elif (typename in self.container_args
and methodname in self.container_args[typename]
and e.arg_kinds == [ARG_POS]):
arg_type = get_proper_type(self.accept(e.args[0]))
if isinstance(arg_type, Instance):
arg_typename = arg_type.type.fullname
if arg_typename in self.container_args[typename][methodname]:
if all(mypy.checker.is_valid_inferred_type(item_type)
for item_type in arg_type.args):
return self.chk.named_generic_type(typename,
list(arg_type.args))
return None

def apply_function_plugin(self,
callee: CallableType,
Expand Down Expand Up @@ -4299,3 +4348,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
short_name in nodes.op_methods.values() or
short_name in nodes.reverse_op_methods.values() or
short_name in nodes.unary_op_methods.values())


def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
if t is None or not isinstance(t, PartialType) or t.type is None:
return None
return t
6 changes: 5 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,13 +1763,17 @@ class PartialType(ProperType):
# None for the 'None' partial type; otherwise a generic class
type = None # type: Optional[mypy.nodes.TypeInfo]
var = None # type: mypy.nodes.Var
# For partial DefaultDict[K, V], the type V (K is unknown)
value_type = None # type: Optional[Instance]

def __init__(self,
type: 'Optional[mypy.nodes.TypeInfo]',
var: 'mypy.nodes.Var') -> None:
var: 'mypy.nodes.Var',
value_type: 'Optional[Instance]' = None) -> None:
super().__init__()
self.type = type
self.var = var
self.value_type = value_type

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_partial_type(self)
Expand Down
69 changes: 69 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -2976,3 +2976,72 @@ x: Optional[str]
y = filter(None, [x])
reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]'
[builtins fixtures/list.pyi]

[case testPartialDefaultDict]
from collections import defaultdict
x = defaultdict(int)
x[''] = 1
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a similar test case where value type is generic, e.g.:

x = defaultdict(list)  # No error
x['a'] = [1, 2, 3]

and

x = defaultdict(list)  # Error here
x['a'] = []


y = defaultdict(int) # E: Need type annotation for 'y'

z = defaultdict(int) # E: Need type annotation for 'z'
z[''] = ''
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictInconsistentValueTypes]
from collections import defaultdict
a = defaultdict(int) # E: Need type annotation for 'a'
a[''] = ''
a[''] = 1
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictListValue]
from collections import defaultdict
a = defaultdict(list)
a['x'].append(1)
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

b = defaultdict(lambda: [])
b[1].append('x')
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases]
from collections import defaultdict
class A:
def f(self) -> None:
self.x = defaultdict(list)
self.x['x'].append(1)
reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
self.y = defaultdict(list) # E: Need type annotation for 'y'
s = self
s.y['x'].append(1)

x = {} # E: Need type annotation for 'x' (hint: "x: Dict[<type>, <type>] = ...")
x['x'].append(1)

y = defaultdict(list) # E: Need type annotation for 'y'
y[[]].append(1)
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases2]
from collections import defaultdict

x = defaultdict(lambda: [1]) # E: Need type annotation for 'x'
x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]'

xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx'
xx[1]['z'] = 3
reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]'

y = defaultdict(dict) # E: Need type annotation for 'y'
y['x'][1] = [3]

z = defaultdict(int) # E: Need type annotation for 'z'
z[1].append('')
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class list(Sequence[T]): # needed by some test cases
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __contains__(self, item: object) -> bool: pass
def append(self, item: T) -> None: pass

class tuple(Generic[T]): pass
class function: pass
Expand Down
12 changes: 7 additions & 5 deletions test-data/unit/lib-stub/collections.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Union, Optional, Dict, TypeVar
from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable

def namedtuple(
typename: str,
Expand All @@ -10,8 +10,10 @@ def namedtuple(
defaults: Optional[Iterable[Any]] = ...
) -> Any: ...

K = TypeVar('K')
V = TypeVar('V')
KT = TypeVar('KT')
VT = TypeVar('VT')

class OrderedDict(Dict[K, V]):
def __setitem__(self, k: K, v: V) -> None: ...
class OrderedDict(Dict[KT, VT]): ...

class defaultdict(Dict[KT, VT]):
def __init__(self, default_factory: Optional[Callable[[], VT]]) -> None: ...
0