diff --git a/mypy/types.py b/mypy/types.py index 91f75f6c592f..090abea1aad8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -987,7 +987,7 @@ class UnionType(Type): items = None # type: List[Type] def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None: - self.items = items + self.items = flatten_nested_unions(items) self.can_be_true = any(item.can_be_true for item in items) self.can_be_false = any(item.can_be_false for item in items) super().__init__(line, column) @@ -1732,6 +1732,17 @@ def get_type_vars(typ: Type) -> List[TypeVarType]: return tvars +def flatten_nested_unions(types: Iterable[Type]) -> List[Type]: + """Flatten nested unions in a type list.""" + flat_items = [] # type: List[Type] + for tp in types: + if isinstance(tp, UnionType): + flat_items.extend(flatten_nested_unions(tp.items)) + else: + flat_items.append(tp) + return flat_items + + def union_items(typ: Type) -> List[Type]: """Return the flattened items of a union type. diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index 8a89dabad82b..f7774d5e5bcb 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -453,6 +453,21 @@ reveal_type(l) \ # E: Revealed type is 'builtins.list[Union[builtins.bool, builtins.int, builtins.float, builtins.str]]' [builtins fixtures/list.pyi] +[case testNestedUnionsProcessedCorrectly] +from typing import Union + +class A: pass +class B: pass +class C: pass + +def foo(bar: Union[Union[A, B], C]) -> None: + if isinstance(bar, A): + reveal_type(bar) # E: Revealed type is '__main__.A' + else: + reveal_type(bar) # E: Revealed type is 'Union[__main__.B, __main__.C]' +[builtins fixtures/isinstance.pyi] +[out] + [case testAssignAnyToUnion] from typing import Union, Any x: Union[int, str]