-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Allow assignments to multiple targets from union types #4067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
b89fc28
13cd547
d550e4a
f4b57eb
edd70cc
9a68ae6
012be52
1a34540
036af84
a4b734b
950b8f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from typing import Dict, List, Set, Iterator, Union, Optional, cast | ||
from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, DefaultDict, cast | ||
from contextlib import contextmanager | ||
from collections import defaultdict | ||
|
||
from mypy.types import Type, AnyType, PartialType, UnionType, TypeOfAny | ||
from mypy.subtypes import is_subtype | ||
|
@@ -57,6 +58,7 @@ class A: | |
reveal_type(lst[0].a) # str | ||
``` | ||
""" | ||
type_assignments = None # type: Optional[DefaultDict[Expression, List[Tuple[Type, Type]]]] | ||
|
||
def __init__(self) -> None: | ||
# The stack of frames currently used. These map | ||
|
@@ -210,10 +212,20 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame: | |
|
||
return result | ||
|
||
@contextmanager | ||
def accumulate_type_assignments(self) -> Iterator[DefaultDict[Expression, | ||
List[Tuple[Type, Type]]]]: | ||
self.type_assignments = defaultdict(list) | ||
yield self.type_assignments | ||
self.type_assignments = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about things like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you are totally right. Will push a commit in a second. Also testing this uncovered another flaw in my implementation: I need to iterate over nested lvalues in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
|
||
def assign_type(self, expr: Expression, | ||
type: Type, | ||
< 8000 td class="blob-code blob-code-context js-file-line"> declared_type: Optional[Type], | ||
restrict_any: bool = False) -> None: | ||
if self.type_assignments is not None: | ||
self.type_assignments[expr].append((type, declared_type)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comment describing the purpose of this. |
||
return | ||
if not isinstance(expr, BindableTypes): | ||
return None | ||
if not literal(expr): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ | |
from mypy.erasetype import erase_typevars | ||
from mypy.expandtype import expand_type, expand_type_by_instance | ||
from mypy.visitor import NodeVisitor | ||
from mypy.join import join_types | ||
from mypy.join import join_types, join_type_list | ||
from mypy.treetransform import TransformVisitor | ||
from mypy.binder import ConditionalTypeBinder, get_declaration | ||
from mypy.meet import is_overlapping_types | ||
|
@@ -1604,12 +1604,14 @@ def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count: | |
def check_multi_assignment(self, lvalues: List[Lvalue], | ||
rvalue: Expression, | ||
context: Context, | ||
infer_lvalue_type: bool = True) -> None: | ||
infer_lvalue_type: bool = True, | ||
rv_type: Optional[Type] = None, | ||
undefined_rvalue: bool = False) -> None: | ||
"""Check the assignment of one rvalue to a number of lvalues.""" | ||
|
||
# Infer the type of an ordinary rvalue expression. | ||
rvalue_type = self.expr_checker.accept(rvalue) # TODO maybe elsewhere; redundant | ||
undefined_rvalue = False | ||
# TODO: maybe elsewhere; redundant. | ||
rvalue_type = rv_type or self.expr_checker.accept(rvalue) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it okay to not type check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the only way to get there is from another function like this, so the |
||
|
||
if isinstance(rvalue_type, UnionType): | ||
# If this is an Optional type in non-strict Optional code, unwrap it. | ||
|
@@ -1628,11 +1630,37 @@ def check_multi_assignment(self, lvalues: List[Lvalue], | |
self.check_multi_assignment_from_tuple(lvalues, rvalue, rvalue_type, | ||
context, undefined_rvalue, infer_lvalue_type) | ||
elif isinstance(rvalue_type, UnionType): | ||
self.check_multi_assignment_from_union(lvalues, rvalue, context, infer_lvalue_type) | ||
self.check_multi_assignment_from_union(lvalues, rvalue, rvalue_type, context, | ||
infer_lvalue_type) | ||
else: | ||
self.check_multi_assignment_from_iterable(lvalues, rvalue_type, | ||
context, infer_lvalue_type) | ||
|
||
def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: Expression, | ||
rvalue_type: UnionType, context: Context, | ||
infer_lvalue_type: bool) -> None: | ||
transposed = tuple([] for _ in lvalues) # type: Tuple[List[Type], ...] | ||
with self.binder.accumulate_type_assignments() as assignments: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add comment about why we have this. |
||
for item in rvalue_type.items: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add short comment describing what we are doing here. For example, something like "Type check the assignment separately for each union item and collect the inferred lvalue types for each union item.". |
||
self.check_multi_assignment(lvalues, rvalue, context, | ||
infer_lvalue_type=infer_lvalue_type, | ||
rv_type=item, undefined_rvalue=True) | ||
for t, lv in zip(transposed, lvalues): | ||
t.append(self.type_map.get(lv, AnyType(TypeOfAny.special_form))) | ||
union_types = tuple(join_type_list(col) for col in transposed) | ||
for expr, items in assignments.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, add comment that describes the purpose of this look in one sentence. |
||
types, declared_types = zip(*items) | ||
self.binder.assign_type(expr, | ||
join_type_list(types), | ||
join_type_list(declared_types), | ||
False) | ||
for union, lv in zip(union_types, lvalues): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One more comment for this loop. |
||
_1, _2, inferred = self.check_lvalue(lv) | ||
if inferred: | ||
self.set_inferred_type(inferred, lv, union) | ||
else: | ||
self.store_type(lv, union) | ||
|
||
def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expression, | ||
rvalue_type: TupleType, context: Context, | ||
undefined_rvalue: bool, | ||
|
@@ -1820,7 +1848,6 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue, | |
|
||
# Make the type more general (strip away function names etc.). | ||
init_type = strip_type(init_type) | ||
|
||
self.set_inferred_type(name, lvalue, init_type) | ||
|
||
def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool: | ||
|
@@ -2698,7 +2725,17 @@ def iterable_item_type(self, instance: Instance) -> Type: | |
iterable = map_instance_to_supertype( | ||
instance, | ||
self.lookup_typeinfo('typing.Iterable')) | ||
return iterable.args[0] | ||
item_type = iterable.args[0] | ||
if not isinstance(item_type, AnyType): | ||
return item_type | ||
# Try also structural typing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This relies on |
||
iter_type = find_member('__iter__', instance, instance) | ||
if (iter_type and isinstance(iter_type, CallableType) and | ||
isinstance(iter_type.ret_type, Instance)): | ||
iterator = map_instance_to_supertype(iter_type.ret_type, | ||
self.lookup_typeinfo('typing.Iterator')) | ||
item_type = iterator.args[0] | ||
return item_type | ||
|
||
def function_type(self, func: FuncBase) -> FunctionLike: | ||
return function_type(func, self.named_type('builtins.function')) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add docstring. Also mention that this is used for multi-assignment from union (and why this is needed there).