8000 Always flatten unions on creation (#3298) · python/mypy@27a8437 · GitHub
[go: up one dir, main page]

Skip to content

Commit 27a8437

Browse files
ilevkivskyigvanrossum
authored andcommitted
Always flatten unions on creation (#3298)
The idea is that we should have a normalized internal representation of unions to simplify reasoning. This also turns out to fix #3196.
1 parent ff9abd8 commit 27a8437

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mypy/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,7 @@ class UnionType(Type):
10121012
items = None # type: List[Type]
10131013

10141014
def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None:
1015-
self.items = items
1015+
self.items = flatten_nested_unions(items)
10161016
self.can_be_true = any(item.can_be_true for item in items)
10171017
self.can_be_false = any(item.can_be_false for item in items)
10181018
super().__init__(line, column)
@@ -1771,6 +1771,17 @@ def get_type_vars(typ: Type) -> List[TypeVarType]:
17711771
return tvars
17721772

17731773

1774+
def flatten_nested_unions(types: Iterable[Type]) -> List[Type]:
1775+
"""Flatten nested unions in a type list."""
1776+
flat_items = [] # type: List[Type]
1777+
for tp in types:
1778+
if isinstance(tp, UnionType):
1779+
flat_items.extend(flatten_nested_unions(tp.items))
1780+
else:
1781+
flat_items.append(tp)
1782+
return flat_items
1783+
1784+
17741785
def union_items(typ: Type) -> List[Type]:
17751786
"""Return the flattened items of a union type.
17761787

test-data/unit/check-unions.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,21 @@ reveal_type(l) \
453453
# E: Revealed type is 'builtins.list[Union[builtins.bool, builtins.int, builtins.float, builtins.str]]'
454454
[builtins fixtures/list.pyi]
455455

456+
[case testNestedUnionsProcessedCorrectly]
457+
from typing import Union
458+
459+
class A: pass
460+
class B: pass
461+
class C: pass
462+
463+
def foo(bar: Union[Union[A, B], C]) -> None:
464+
if isinstance(bar, A):
465+
reveal_type(bar) # E: Revealed type is '__main__.A'
466+
else:
467+
reveal_type(bar) # E: Revealed type is 'Union[__main__.B, __main__.C]'
468+
[builtins fixtures/isinstance.pyi]
469+
[out]
470+
456471
[case testAssignAnyToUnion]
457472
from typing import Union, Any
458473
x: Union[int, str]

0 commit comments

Comments
 (0)
0