From 26113007281d1283755a0836638751aa62f28a8d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 14 Oct 2021 00:16:20 +0200 Subject: [PATCH 01/23] Add exhaustiveness checking --- mypy/checker.py | 78 +++++++++++++++++++++-------- mypy/subtypes.py | 6 +++ test-data/unit/check-python310.test | 52 +++++++++++++++++-- 3 files changed, 110 insertions(+), 26 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 851f23185f4f..d2c9947e3dd1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4089,37 +4089,62 @@ def visit_match_stmt(self, s: MatchStmt) -> None: if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) + # We have to check each pattern twice. Once ignoring the guard statement to infer + # the capture types and once with then to narrow the subject. + # In addition PatternChecker adds intersection types to the scope. We only want that + # to happen on the second pass, so we copy the SymbolTable beforehand. + curr_module = self.scope.stack[0] + assert isinstance(curr_module, MypyFile) + names = curr_module.names.copy() pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] + curr_module.names = names type_maps: List[TypeMap] = [t.captures for t in pattern_types] - self.infer_variable_types_from_type_maps(type_maps) + inferred_names = self.infer_variable_types_from_type_maps(type_maps) - for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): + for p, g, b in zip(s.patterns, s.guards, s.bodies): + current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject, + subject_type) + pattern_type = self.pattern_checker.accept(p, current_subject_type) with self.binder.frame_context(can_skip=True, fall_through=2): if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), UninhabitedType): self.push_type_map(None) + else_map: TypeMap = {} else: - self.binder.put(s.subject, pattern_type.type) + pattern_map, else_map = conditional_types_to_typemaps( + s.subject, + pattern_type.type, + pattern_type.rest_type + ) + self.remove_capture_conflicts(pattern_type.captures, + inferred_names) + self.push_type_map(pattern_map) self.push_type_map(pattern_type.captures) if g is not None: - gt = get_proper_type(self.expr_checker.accept(g)) + with self.binder.frame_context(can_skip=True, fall_through=3): + gt = get_proper_type(self.expr_checker.accept(g)) - if isinstance(gt, DeletedType): - self.msg.deleted_as_rvalue(gt, s) + if isinstance(gt, DeletedType): + self.msg.deleted_as_rvalue(gt, s) - if_map, _ = self.find_isinstance_check(g) + guard_map, guard_else_map = self.find_isinstance_check(g) + else_map = or_conditional_maps(else_map, guard_else_map) - self.push_type_map(if_map) - self.accept(b) + self.push_type_map(guard_map) + self.accept(b) + else: + self.accept(b) + self.push_type_map(else_map) # This is needed due to a quirk in frame_context. Without it types will stay narrowed # after the match. with self.binder.frame_context(can_skip=False, fall_through=2): pass - def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None: + def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]: all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) + inferred_names: Dict[Var, Type] = {} for tm in type_maps: if tm is not None: for expr, typ in tm.items(): @@ -4129,27 +4154,36 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None: all_captures[node].append((expr, typ)) for var, captures in all_captures.items(): - conflict = False + already_exists = False types: List[Type] = [] for expr, typ in captures: types.append(typ) - previous_type, _, inferred = self.check_lvalue(expr) + previous_type, _, _ = self.check_lvalue(expr) if previous_type is not None: - conflict = True - self.check_subtype(typ, previous_type, expr, - msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, - subtype_label="pattern captures type", - supertype_label="variable has type") - for type_map in type_maps: - if type_map is not None and expr in type_map: - del type_map[expr] - - if not conflict: + already_exists = True + if self.check_subtype(typ, previous_type, expr, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type"): + inferred_names[var] = previous_type + + if not already_exists: new_type = UnionType.make_union(types) # Infer the union type at the first occurrence first_occurrence, _ = captures[0] + inferred_names[var] = new_type self.infer_variable_type(var, first_occurrence, new_type, first_occurrence) + return inferred_names + + def remove_capture_conflicts(self, type_map: TypeMap, inferred_names: Dict[Var, Type]) -> None: + if type_map is not None: + for expr, typ in type_map.copy().items(): + if isinstance(expr, NameExpr): + node = expr.node + assert isinstance(node, Var) + if node not in inferred_names or not is_subtype(typ, inferred_names[node]): + del type_map[expr] def make_fake_typeinfo(self, curr_module_fullname: str, diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a261e3712328..0a84d0112ea8 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1237,11 +1237,13 @@ def _is_proper_subtype(left: Type, right: Type, *, class ProperSubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, *, ignore_promotions: bool = False, + ignore_last_known_value: bool = False, erase_instances: bool = False, keep_erased_types: bool = False) -> None: self.right = get_proper_type(right) self.orig_right = right self.ignore_promotions = ignore_promotions + self.ignore_last_known_value = ignore_last_known_value self.erase_instances = erase_instances self.keep_erased_types = keep_erased_types self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( @@ -1297,6 +1299,10 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(right, Instance): if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): return True + if not self.ignore_last_known_value: + if right.last_known_value is not None and \ + right.last_known_value != left.last_known_value: + return False if not self.ignore_promotions: for base in left.type.mro: if base._promote and self._is_proper_subtype(base._promote, right): diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 0d4f46d53924..1d64374bfc49 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1049,6 +1049,19 @@ match m: case a: reveal_type(a) # N: Revealed type is "builtins.int" +[case testCapturePatternPreexistingNarrows] +a: int +m: bool + +match m: + case a: + reveal_type(a) # N: Revealed type is "builtins.bool" + +# This is actually correct, as case a has to be taken. +reveal_type(a) # N: Revealed type is "builtins.bool" +a = 3 +reveal_type(a) # N: Revealed type is "builtins.int" + [case testCapturePatternPreexistingIncompatible] a: str m: int @@ -1057,6 +1070,8 @@ match m: case a: # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") reveal_type(a) # N: Revealed type is "builtins.str" +reveal_type(a) # N: Revealed type is "builtins.str" + [case testCapturePatternPreexistingIncompatibleLater] a: str m: object @@ -1067,6 +1082,8 @@ match m: case int(a): # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") reveal_type(a) # N: Revealed type is "builtins.str" +reveal_type(a) # N: Revealed type is "builtins.str" + -- Guards -- [case testSimplePatternGuard] @@ -1135,7 +1152,7 @@ match m: [builtins fixtures/isinstancelist.pyi] -- Exhaustiveness -- -[case testUnionNegativeNarrowing-skip] +[case testUnionNegativeNarrowing] from typing import Union m: Union[str, int] @@ -1148,7 +1165,7 @@ match m: reveal_type(b) # N: Revealed type is "builtins.int" reveal_type(m) # N: Revealed type is "builtins.int" -[case testOrPatternNegativeNarrowing-skip] +[case testOrPatternNegativeNarrowing] from typing import Union m: Union[str, bytes, int] @@ -1160,7 +1177,7 @@ match m: case b: reveal_type(b) # N: Revealed type is "builtins.int" -[case testExhaustiveReturn-skip] +[case testExhaustiveReturn] def foo(value) -> int: match value: case "bar": @@ -1168,7 +1185,7 @@ def foo(value) -> int: case _: return 2 -[case testNoneExhaustiveReturn-skip] +[case testNoneExhaustiveReturn] def foo(value) -> int: # E: Missing return statement match value: case "bar": @@ -1216,3 +1233,30 @@ class A: class B: def __enter__(self) -> B: ... def __exit__(self, x, y, z) -> None: ... + +[case testNonExhaustiveError] +from typing import NoReturn +def assert_never(x: NoReturn) -> None: ... + +value: int +match value: + case 1: + pass + case 2: + pass + case o: + assert_never(o) # E: Argument 1 to "assert_never" has incompatible type "int"; expected "NoReturn" + +[case testExhaustiveNoError] +from typing import NoReturn, Union, Literal +def assert_never(x: NoReturn) -> None: ... + +value: Union[Literal[1], Literal[2]] +match value: + case 1: + pass + case 2: + pass + case o: + assert_never(o) +[typing fixtures/typing-medium.pyi] From 8b4458525d911a7284ef30134a679870312066a6 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sat, 16 Oct 2021 15:18:05 +0200 Subject: [PATCH 02/23] Improve exhaustiveness checking and add enum support --- mypy/checker.py | 8 +++ mypy/checkpattern.py | 11 +++-- mypy/typeops.py | 14 +++++- test-data/unit/check-python310.test | 77 +++++++++++++++++++++++++---- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d2c9947e3dd1..fde264bcad43 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5671,6 +5671,14 @@ def conditional_types(current_type: Type, None means no new information can be inferred. If default is set it is returned instead.""" if proposed_type_ranges: + if len(proposed_type_ranges) == 1: + target = proposed_type_ranges[0].item + target = get_proper_type(target) + if isinstance(target, LiteralType) and target.is_enum_literal(): + enum_name = target.fallback.type.fullname + current_type = try_expanding_enum_to_union(current_type, + enum_name, + ignore_custom_equals=False) proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) if isinstance(proposed_type, AnyType): diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index fbbb4c319ccb..b45289a5e617 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -19,7 +19,8 @@ ) from mypy.plugin import Plugin from mypy.subtypes import is_subtype -from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union +from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \ + coerce_to_literal from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, TypedDictType, TupleType, NoneType, UnionType @@ -177,6 +178,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: def visit_value_pattern(self, o: ValuePattern) -> PatternType: current_type = self.type_context[-1] typ = self.chk.expr_checker.accept(o.expr) + typ = coerce_to_literal(typ) narrowed_type, rest_type = self.chk.conditional_types_with_intersection( current_type, [get_type_range(typ)], @@ -259,6 +261,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, star_position, len(inner_types)) + rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types, + star_position, + len(inner_types)) # # Calculate new type @@ -287,9 +292,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if all(is_uninhabited(typ) for typ in inner_rest_types): # All subpatterns always match, so we can apply negative narrowing - new_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, [get_type_range(new_type)], o, default=current_type - ) + rest_type = TupleType(rest_inner_types, current_type.partial_fallback) else: new_inner_type = UninhabitedType() for typ in new_inner_types: diff --git a/mypy/typeops.py b/mypy/typeops.py index 57fdfeadad9a..dfe94df76d4d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -683,7 +683,19 @@ def is_singleton_type(typ: Type) -> bool: ) -def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType: +def enum_has_custom_equals(enum: Instance): + assert enum.type.is_enum + for typ in enum.type.mro: + if typ.fullname == "enum.Enum": + return False + if "__eq__" in typ.names: + return True + + +def try_expanding_sum_type_to_union(typ: Type, + target_fullname: str, + *, + ignore_custom_equals: bool = True) -> ProperType: """Attempts to recursively expand any enum Instances with the given target_fullname into a Union of all of its component LiteralTypes. diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 1d64374bfc49..260a4c1e6756 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -14,14 +14,15 @@ m: object match m: case 1: - reveal_type(m) # N: Revealed type is "Literal[1]?" + reveal_type(m) # N: Revealed type is "Literal[1]" -[case testLiteralPatternAlreadyNarrower] +[case testLiteralPatternAlreadyNarrower-skip] m: bool match m: case 1: - reveal_type(m) # N: Revealed type is "builtins.bool" + reveal_type(m) # This should probably be unreachable, but isn't detected as such. +[builtins fixtures/primitives.pyi] [case testLiteralPatternUnreachable] # primitives are needed because otherwise mypy doesn't see that int and str are incompatible @@ -269,7 +270,7 @@ m: Tuple[object, object] match m: case [1, "str"]: - reveal_type(m) # N: Revealed type is "Tuple[Literal[1]?, Literal['str']?]" + reveal_type(m) # N: Revealed type is "Tuple[Literal[1], Literal['str']]" [builtins fixtures/list.pyi] [case testSequencePatternTupleStarred] @@ -968,7 +969,7 @@ m: object match m: case 1 | 2 as n: - reveal_type(n) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + reveal_type(n) # N: Revealed type is "Union[Literal[1], Literal[2]]" [case testAsPatternAlreadyNarrower] m: bool @@ -984,21 +985,21 @@ m: object match m: case 1 | 2: - reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" [case testOrPatternNarrowsStr] m: object match m: case "foo" | "bar": - reveal_type(m) # N: Revealed type is "Union[Literal['foo']?, Literal['bar']?]" + reveal_type(m) # N: Revealed type is "Union[Literal['foo'], Literal['bar']]" [case testOrPatternNarrowsUnion] m: object match m: case 1 | "foo": - reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal['foo']?]" + reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal['foo']]" [case testOrPatterCapturesMissing] from typing import List @@ -1057,7 +1058,6 @@ match m: case a: reveal_type(a) # N: Revealed type is "builtins.bool" -# This is actually correct, as case a has to be taken. reveal_type(a) # N: Revealed type is "builtins.bool" a = 3 reveal_type(a) # N: Revealed type is "builtins.int" @@ -1260,3 +1260,62 @@ match value: case o: assert_never(o) [typing fixtures/typing-medium.pyi] + +[case testSequencePatternNegativeNarrowing] +from typing import Union, Sequence, Tuple + +m1: Sequence[Union[int, str]] + +match m1: + case [int()]: + reveal_type(m1) # N: Revealed type is "typing.Sequence[builtins.int]" + case r: + reveal_type(m1) # N: Revealed type is "typing.Sequence[Union[builtins.int, builtins.str]]" + +m2: Tuple[Union[int, str]] + +match m2: + case (int(),): + reveal_type(m2) # N: Revealed type is "Tuple[builtins.int]" + case r2: + reveal_type(m2) # N: Revealed type is "Tuple[builtins.str]" + +m3: Tuple[Union[int, str]] + +match m3: + case (1,): + reveal_type(m3) # N: Revealed type is "Tuple[Literal[1]]" + case r2: + reveal_type(m3) # N: Revealed type is "Tuple[Union[builtins.int, builtins.str]]" + +[case testLiteralPatternEnumNegativeNarrowing] +from enum import Enum +class Medal(Enum): + gold = 1 + silver = 2 + bronze = 3 + +m: Medal + +match m: + case Medal.gold: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" + case _: + reveal_type(m) # N: Revealed type is "Union[Literal[__main__.Medal.silver], Literal[__main__.Medal.bronze]]" + +[case testLiteralPatternEnumCustomEquals] +from enum import Enum +class Medal(Enum): + gold = 1 + silver = 2 + bronze = 3 + + def __eq__(self, other) -> bool: ... + +m: Medal + +match m: + case Medal.gold: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" + case _: + reveal_type(m) # N: Revealed type is "__main__.Medal" From 33b17009a527487b8489a3b52694f365c7622902 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sat, 16 Oct 2021 15:32:48 +0200 Subject: [PATCH 03/23] Remove outdated comment --- mypy/checkpattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index b45289a5e617..c310adc7bf75 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -56,7 +56,7 @@ 'PatternType', [ ('type', Type), # The type the match subject can be narrowed to - ('rest_type', Type), # For exhaustiveness checking. Not used yet + ('rest_type', Type), ('captures', Dict[Expression, Type]), # The variables captured by the pattern ]) From f13c37754d9952f9cab4cc871ceed2855154f7ae Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 28 Oct 2021 17:46:12 +0200 Subject: [PATCH 04/23] Final cleanup --- mypy/checkpattern.py | 16 +++++++++++++++- mypy/typeops.py | 3 ++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index c310adc7bf75..5d007af705f9 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -298,7 +298,14 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: for typ in new_inner_types: new_inner_type = join_types(new_inner_type, typ) new_type = self.construct_sequence_child(current_type, new_inner_type) - if not is_subtype(new_type, current_type): + if is_subtype(new_type, current_type): + new_type, _ = self.chk.conditional_types_with_intersection( + current_type, + [get_type_range(new_type)], + o, + default=current_type + ) + else: new_type = current_type return PatternType(new_type, rest_type, captures) @@ -642,6 +649,13 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type: For example: construct_sequence_child(List[int], str) = List[str] """ + proper_type = get_proper_type(outer_type) + if isinstance(proper_type, UnionType): + types = [ + self.construct_sequence_child(item, inner_type) for item in proper_type.items + if self.can_match_sequence(get_proper_type(item)) + ] + return make_simplified_union(types) sequence = self.chk.named_generic_type("typing.Sequence", [inner_type]) if is_subtype(outer_type, self.chk.named_type("typing.Sequence")): proper_type = get_proper_type(outer_type) diff --git a/mypy/typeops.py b/mypy/typeops.py index dfe94df76d4d..aa5e11dd34c0 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -683,13 +683,14 @@ def is_singleton_type(typ: Type) -> bool: ) -def enum_has_custom_equals(enum: Instance): +def enum_has_custom_equals(enum: Instance) -> bool: assert enum.type.is_enum for typ in enum.type.mro: if typ.fullname == "enum.Enum": return False if "__eq__" in typ.names: return True + assert False def try_expanding_sum_type_to_union(typ: Type, From 63faf94d635cad7bd7e836c02f722ff066b086f3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 15:16:34 +0000 Subject: [PATCH 05/23] Fix --- mypy/checker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index fde264bcad43..39f4aeaada85 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5676,9 +5676,9 @@ def conditional_types(current_type: Type, target = get_proper_type(target) if isinstance(target, LiteralType) and target.is_enum_literal(): enum_name = target.fallback.type.fullname - current_type = try_expanding_enum_to_union(current_type, - enum_name, - ignore_custom_equals=False) + current_type = try_expanding_sum_type_to_union(current_type, + enum_name, + ignore_custom_equals=False) proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) if isinstance(proposed_type, AnyType): From 5d82b79d33a0c208f5cbb4623e5a76c5746e78f3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 15:16:49 +0000 Subject: [PATCH 06/23] Fix test case --- test-data/unit/check-python310.test | 1 + 1 file changed, 1 insertion(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 260a4c1e6756..ff45542eae78 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1287,6 +1287,7 @@ match m3: reveal_type(m3) # N: Revealed type is "Tuple[Literal[1]]" case r2: reveal_type(m3) # N: Revealed type is "Tuple[Union[builtins.int, builtins.str]]" +[builtins fixtures/tuple.pyi] [case testLiteralPatternEnumNegativeNarrowing] from enum import Enum From a36749a968c371b9f8bae04d3260dc99bfde8a75 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 15:16:55 +0000 Subject: [PATCH 07/23] Skip test case (testLiteralPatternEnumCustomEquals) --- test-data/unit/check-python310.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index ff45542eae78..04acbaa62449 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1304,7 +1304,7 @@ match m: case _: reveal_type(m) # N: Revealed type is "Union[Literal[__main__.Medal.silver], Literal[__main__.Medal.bronze]]" -[case testLiteralPatternEnumCustomEquals] +[case testLiteralPatternEnumCustomEquals-skip] from enum import Enum class Medal(Enum): gold = 1 From e14bffb6385be2ec277876586146ee7220fb9e94 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 15:29:15 +0000 Subject: [PATCH 08/23] Update test cases --- test-data/unit/check-python310.test | 115 +++++++++++++++++++++------- 1 file changed, 89 insertions(+), 26 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 04acbaa62449..688a8fe51bbe 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1185,7 +1185,7 @@ def foo(value) -> int: case _: return 2 -[case testNoneExhaustiveReturn] +[case testNonExhaustiveReturn] def foo(value) -> int: # E: Missing return statement match value: case "bar": @@ -1193,6 +1193,50 @@ def foo(value) -> int: # E: Missing return statement case 2: return 2 +[case testMoreExhaustiveReturnChecking] +def f(value: int | str | None) -> int: # E: Missing return statement + match value: + case int(): + return 0 + case None: + return 1 + +def g(value: int | None) -> int: + match value: + case int(): + return 0 + case None: + return 1 + +[case testMiscNonExhaustiveReturn-skip] +class C: + a: int | str + +def f(c: C) -> int: # E: Missing return statement + match c: + case C(a=int()): + return 0 + case C(a=str()): + return 1 + +def g(x: list[str]) -> int: # E: Missing return statement + match x: + case [a]: + return 0 + case [a, b]: + return 1 + +def h(x: dict[str, int]) -> int: # E: Missing return statement + match x: + case {'x': a}: + return 0 + +def ff(x: bool) -> int: # E: Missing return statement + match x: + case True: + return 0 +[builtins fixtures/dict.pyi] + [case testWithStatementScopeAndMatchStatement] from m import A, B @@ -1238,33 +1282,33 @@ class B: from typing import NoReturn def assert_never(x: NoReturn) -> None: ... -value: int -match value: - case 1: - pass - case 2: - pass - case o: - assert_never(o) # E: Argument 1 to "assert_never" has incompatible type "int"; expected "NoReturn" +def f(value: int) -> int: # E: Missing return statement + match value: + case 1: + return 0 + case 2: + return 1 + case o: + assert_never(o) # E: Argument 1 to "assert_never" has incompatible type "int"; expected "NoReturn" [case testExhaustiveNoError] from typing import NoReturn, Union, Literal def assert_never(x: NoReturn) -> None: ... -value: Union[Literal[1], Literal[2]] -match value: - case 1: - pass - case 2: - pass - case o: - assert_never(o) +def f(value: Literal[1] | Literal[2]) -> int: + match value: + case 1: + return 0 + case 2: + return 1 + case o: + assert_never(o) [typing fixtures/typing-medium.pyi] [case testSequencePatternNegativeNarrowing] from typing import Union, Sequence, Tuple -m1: Sequence[Union[int, str]] +m1: Sequence[int | str] match m1: case [int()]: @@ -1272,7 +1316,7 @@ match m1: case r: reveal_type(m1) # N: Revealed type is "typing.Sequence[Union[builtins.int, builtins.str]]" -m2: Tuple[Union[int, str]] +m2: Tuple[int | str] match m2: case (int(),): @@ -1296,13 +1340,23 @@ class Medal(Enum): silver = 2 bronze = 3 -m: Medal - -match m: - case Medal.gold: - reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" - case _: - reveal_type(m) # N: Revealed type is "Union[Literal[__main__.Medal.silver], Literal[__main__.Medal.bronze]]" +def f(m: Medal) -> int: + match m: + case Medal.gold: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" + return 0 + case _: + reveal_type(m) # N: Revealed type is "Union[Literal[__main__.Medal.silver], Literal[__main__.Medal.bronze]]" + return 1 + +def g(m: Medal) -> int: + match m: + case Medal.gold: + return 0 + case Medal.silver: + return 1 + case Medal.bronze: + return 2 [case testLiteralPatternEnumCustomEquals-skip] from enum import Enum @@ -1320,3 +1374,12 @@ match m: reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" case _: reveal_type(m) # N: Revealed type is "__main__.Medal" + +[case testNarrowUsingPatternGuardSpecialCase] +def f(x: int | str) -> int: # E: Missing return statement + match x: + case x if isinstance(x, str): + return 0 + case int(): + return 1 +[builtins fixtures/isinstance.pyi] From 1246bb7c8142a44507d6c6da30f3ad81922f7526 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 17:19:37 +0000 Subject: [PATCH 09/23] Add docstrings --- mypy/patterns.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mypy/patterns.py b/mypy/patterns.py index 8557fac6daf6..f7f5f56d0ed5 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -21,6 +21,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class AsPattern(Pattern): + """The pattern as """ # The python ast, and therefore also our ast merges capture, wildcard and as patterns into one # for easier handling. # If pattern is None this is a capture pattern. If name and pattern are both none this is a @@ -39,6 +40,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class OrPattern(Pattern): + """The pattern | | ...""" patterns: List[Pattern] def __init__(self, patterns: List[Pattern]) -> None: @@ -50,6 +52,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ValuePattern(Pattern): + """The pattern x.y (or x.y.z, ...)""" expr: Expression def __init__(self, expr: Expression): @@ -73,6 +76,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class SequencePattern(Pattern): + """The pattern [, ...]""" patterns: List[Pattern] def __init__(self, patterns: List[Pattern]): @@ -114,6 +118,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ClassPattern(Pattern): + """The pattern Cls(...)""" class_ref: RefExpr positionals: List[Pattern] keyword_keys: List[str] From a7a3ec9d2cc2f718e035932c879485a62f906ed8 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 17:20:00 +0000 Subject: [PATCH 10/23] Fix unreachable code after "case True" --- mypy/checkpattern.py | 4 ++ test-data/unit/check-python310.test | 67 ++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 5d007af705f9..f10e04888cd1 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -693,6 +693,10 @@ def get_var(expr: Expression) -> Var: def get_type_range(typ: Type) -> 'mypy.checker.TypeRange': + if (isinstance(typ, Instance) + and typ.last_known_value + and isinstance(typ.last_known_value.value, bool)): + typ = typ.last_known_value return mypy.checker.TypeRange(typ, is_upper_bound=False) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 688a8fe51bbe..e2626678447f 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1208,7 +1208,7 @@ def g(value: int | None) -> int: case None: return 1 -[case testMiscNonExhaustiveReturn-skip] +[case testMiscNonExhaustiveReturn] class C: a: int | str @@ -1383,3 +1383,68 @@ def f(x: int | str) -> int: # E: Missing return statement case int(): return 1 [builtins fixtures/isinstance.pyi] + +[case testNarrowingDownUnionPartially] +# flags: --strict-optional + +def f(x: int | str) -> None: + match x: + case int(): + return + reveal_type(x) # N: Revealed type is "builtins.str" + +def g(x: int | str | None) -> None: + match x: + case int() | None: + return + reveal_type(x) # N: Revealed type is "builtins.str" + +def h(x: int | str | None) -> None: + match x: + case int() | str(): + return + reveal_type(x) # N: Revealed type is "None" + +[case testNarrowDownUsingLiteralMatch] +from enum import Enum +class Medal(Enum): + gold = 1 + silver = 2 + +def b1(x: bool) -> None: + match x: + case True: + return + # Possibly we could have Literal[False] here? + reveal_type(x) # N: Revealed type is "builtins.bool" + +def b2(x: bool) -> None: + match x: + case False: + return + # Possibly we could have Literal[True] here? + reveal_type(x) # N: Revealed type is "builtins.bool" + +def e1(x: Medal) -> None: + match x: + case Medal.gold: + return + reveal_type(x) # N: Revealed type is "Literal[__main__.Medal.silver]" + +def e2(x: Medal) -> None: + match x: + case Medal.silver: + return + reveal_type(x) # N: Revealed type is "Literal[__main__.Medal.gold]" + +def i(x: int) -> None: + match x: + case 1: + return + reveal_type(x) # N: Revealed type is "builtins.int" + +def s(x: str) -> None: + match x: + case 'x': + return + reveal_type(x) # N: Revealed type is "builtins.str" From 1095852546198b80959794ab2fae3efde81a4465 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 28 Feb 2022 17:42:53 +0000 Subject: [PATCH 11/23] Support narrowing down bool values --- mypy/checker.py | 3 ++- test-data/unit/check-python310.test | 41 ++++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 39f4aeaada85..1e74d5a546e7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5674,7 +5674,8 @@ def conditional_types(current_type: Type, if len(proposed_type_ranges) == 1: target = proposed_type_ranges[0].item target = get_proper_type(target) - if isinstance(target, LiteralType) and target.is_enum_literal(): + if isinstance(target, LiteralType) and (target.is_enum_literal() + or isinstance(target.value, bool)): enum_name = target.fallback.type.fullname current_type = try_expanding_sum_type_to_union(current_type, enum_name, diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index e2626678447f..83900b10e803 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1152,6 +1152,7 @@ match m: [builtins fixtures/isinstancelist.pyi] -- Exhaustiveness -- + [case testUnionNegativeNarrowing] from typing import Union @@ -1208,6 +1209,13 @@ def g(value: int | None) -> int: case None: return 1 +def b(value: bool) -> int: + match value: + case True: + return 2 + case False: + return 3 + [case testMiscNonExhaustiveReturn] class C: a: int | str @@ -1415,15 +1423,13 @@ def b1(x: bool) -> None: match x: case True: return - # Possibly we could have Literal[False] here? - reveal_type(x) # N: Revealed type is "builtins.bool" + reveal_type(x) # N: Revealed type is "Literal[False]" def b2(x: bool) -> None: match x: case False: return - # Possibly we could have Literal[True] here? - reveal_type(x) # N: Revealed type is "builtins.bool" + reveal_type(x) # N: Revealed type is "Literal[True]" def e1(x: Medal) -> None: match x: @@ -1448,3 +1454,30 @@ def s(x: str) -> None: case 'x': return reveal_type(x) # N: Revealed type is "builtins.str" + +def union(x: str | bool) -> None: + match x: + case True: + return + reveal_type(x) # N: Revealed type is "Union[builtins.str, Literal[False]]" + +[case testMatchAssertFalseToSilenceFalsePositives] +class C: + a: int | str + +def f(c: C) -> int: + match c: + case C(a=int()): + return 0 + case C(a=str()): + return 1 + case _: + assert False + +def g(c: C) -> int: + match c: + case C(a=int()): + return 0 + case C(a=str()): + return 1 + assert False From d3d04f9e15ee9feb3ceba73fcefd631f76a433fb Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 12:08:31 +0000 Subject: [PATCH 12/23] Fix type check --- mypy/checkpattern.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index f10e04888cd1..81ab81001421 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -693,6 +693,7 @@ def get_var(expr: Expression) -> Var: def get_type_range(typ: Type) -> 'mypy.checker.TypeRange': + typ = get_proper_type(typ) if (isinstance(typ, Instance) and typ.last_known_value and isinstance(typ.last_known_value.value, bool)): From 81bf0a0bd1a6d729caee9bb3f15cee2d28da52c2 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 12:09:36 +0000 Subject: [PATCH 13/23] Remove unused function --- mypy/typeops.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index aa5e11dd34c0..2ede9194ac26 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -683,16 +683,6 @@ def is_singleton_type(typ: Type) -> bool: ) -def enum_has_custom_equals(enum: Instance) -> bool: - assert enum.type.is_enum - for typ in enum.type.mro: - if typ.fullname == "enum.Enum": - return False - if "__eq__" in typ.names: - return True - assert False - - def try_expanding_sum_type_to_union(typ: Type, target_fullname: str, *, From 2969faccdeb5438cec8a9c8f138ce2d80ec4334a Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 13:30:05 +0000 Subject: [PATCH 14/23] Rename and restructure match statement test cases --- test-data/unit/check-python310.test | 331 ++++++++++++++-------------- 1 file changed, 167 insertions(+), 164 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 83900b10e803..bd828b785055 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1,5 +1,6 @@ -- Capture Pattern -- -[case testCapturePatternType] + +[case testMatchCapturePatternType] class A: ... m: A @@ -7,16 +8,16 @@ match m: case a: reveal_type(a) # N: Revealed type is "__main__.A" - -- Literal Pattern -- -[case testLiteralPatternNarrows] + +[case testMatchLiteralPatternNarrows] m: object match m: case 1: reveal_type(m) # N: Revealed type is "Literal[1]" -[case testLiteralPatternAlreadyNarrower-skip] +[case testMatchLiteralPatternAlreadyNarrower-skip] m: bool match m: @@ -24,7 +25,7 @@ match m: reveal_type(m) # This should probably be unreachable, but isn't detected as such. [builtins fixtures/primitives.pyi] -[case testLiteralPatternUnreachable] +[case testMatchLiteralPatternUnreachable] # primitives are needed because otherwise mypy doesn't see that int and str are incompatible m: int @@ -33,9 +34,9 @@ match m: reveal_type(m) [builtins fixtures/primitives.pyi] - -- Value Pattern -- -[case testValuePatternNarrows] + +[case testMatchValuePatternNarrows] import b m: object @@ -45,7 +46,7 @@ match m: [file b.py] b: int -[case testValuePatternAlreadyNarrower] +[case testMatchValuePatternAlreadyNarrower] import b m: bool @@ -55,7 +56,7 @@ match m: [file b.py] b: int -[case testValuePatternIntersect] +[case testMatchValuePatternIntersect] import b class A: ... @@ -68,7 +69,7 @@ match m: class B: ... b: B -[case testValuePatternUnreachable] +[case testMatchValuePatternUnreachable] # primitives are needed because otherwise mypy doesn't see that int and str are incompatible import b @@ -81,9 +82,9 @@ match m: b: str [builtins fixtures/primitives.pyi] - -- Sequence Pattern -- -[case testSequencePatternCaptures] + +[case testMatchSequencePatternCaptures] from typing import List m: List[int] @@ -92,7 +93,7 @@ match m: reveal_type(a) # N: Revealed type is "builtins.int*" [builtins fixtures/list.pyi] -[case testSequencePatternCapturesStarred] +[case testMatchSequencePatternCapturesStarred] from typing import Sequence m: Sequence[int] @@ -102,7 +103,7 @@ match m: reveal_type(b) # N: Revealed type is "builtins.list[builtins.int*]" [builtins fixtures/list.pyi] -[case testSequencePatternNarrowsInner] +[case testMatchSequencePatternNarrowsInner] from typing import Sequence m: Sequence[object] @@ -110,7 +111,7 @@ match m: case [1, True]: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" -[case testSequencePatternNarrowsOuter] +[case testMatchSequencePatternNarrowsOuter] from typing import Sequence m: object @@ -118,7 +119,7 @@ match m: case [1, True]: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" -[case testSequencePatternAlreadyNarrowerInner] +[case testMatchSequencePatternAlreadyNarrowerInner] from typing import Sequence m: Sequence[bool] @@ -126,7 +127,7 @@ match m: case [1, True]: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" -[case testSequencePatternAlreadyNarrowerOuter] +[case testMatchSequencePatternAlreadyNarrowerOuter] from typing import Sequence m: Sequence[object] @@ -134,7 +135,7 @@ match m: case [1, True]: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" -[case testSequencePatternAlreadyNarrowerBoth] +[case testMatchSequencePatternAlreadyNarrowerBoth] from typing import Sequence m: Sequence[bool] @@ -142,7 +143,7 @@ match m: case [1, True]: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" -[case testNestedSequencePatternNarrowsInner] +[case testMatchNestedSequencePatternNarrowsInner] from typing import Sequence m: Sequence[Sequence[object]] @@ -150,7 +151,7 @@ match m: case [[1], [True]]: reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" -[case testNestedSequencePatternNarrowsOuter] +[case testMatchNestedSequencePatternNarrowsOuter] from typing import Sequence m: object @@ -158,7 +159,7 @@ match m: case [[1], [True]]: reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" -[case testSequencePatternDoesntNarrowInvariant] +[case testMatchSequencePatternDoesntNarrowInvariant] from typing import List m: List[object] @@ -167,7 +168,7 @@ match m: reveal_type(m) # N: Revealed type is "builtins.list[builtins.object]" [builtins fixtures/list.pyi] -[case testSequencePatternMatches] +[case testMatchSequencePatternMatches] import array, collections from typing import Sequence, Iterable @@ -230,8 +231,7 @@ match m11: [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] - -[case testSequencePatternCapturesTuple] +[case testMatchSequencePatternCapturesTuple] from typing import Tuple m: Tuple[int, str, bool] @@ -243,7 +243,7 @@ match m: reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" [builtins fixtures/list.pyi] -[case testSequencePatternTupleTooLong] +[case testMatchSequencePatternTupleTooLong] from typing import Tuple m: Tuple[int, str] @@ -254,7 +254,7 @@ match m: reveal_type(c) [builtins fixtures/list.pyi] -[case testSequencePatternTupleTooShort] +[case testMatchSequencePatternTupleTooShort] from typing import Tuple m: Tuple[int, str, bool] @@ -264,7 +264,7 @@ match m: reveal_type(b) [builtins fixtures/list.pyi] -[case testSequencePatternTupleNarrows] +[case testMatchSequencePatternTupleNarrows] from typing import Tuple m: Tuple[object, object] @@ -273,7 +273,7 @@ match m: reveal_type(m) # N: Revealed type is "Tuple[Literal[1], Literal['str']]" [builtins fixtures/list.pyi] -[case testSequencePatternTupleStarred] +[case testMatchSequencePatternTupleStarred] from typing import Tuple m: Tuple[int, str, bool] @@ -285,7 +285,7 @@ match m: reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" [builtins fixtures/list.pyi] -[case testSequencePatternTupleStarredUnion] +[case testMatchSequencePatternTupleStarredUnion] from typing import Tuple m: Tuple[int, str, float, bool] @@ -297,8 +297,7 @@ match m: reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.float, builtins.bool]" [builtins fixtures/list.pyi] - -[case testSequencePatternTupleStarredTooShort] +[case testMatchSequencePatternTupleStarredTooShort] from typing import Tuple m: Tuple[int] reveal_type(m) # N: Revealed type is "Tuple[builtins.int]" @@ -310,7 +309,7 @@ match m: reveal_type(c) [builtins fixtures/list.pyi] -[case testNonMatchingSequencePattern] +[case testMatchNonMatchingSequencePattern] from typing import List x: List[int] @@ -318,7 +317,7 @@ match x: case [str()]: pass -[case testSequenceUnion-skip] +[case testMatchSequenceUnion-skip] from typing import List, Union m: Union[List[List[str]], str] @@ -328,7 +327,8 @@ match m: [builtins fixtures/list.pyi] -- Mapping Pattern -- -[case testMappingPatternCaptures] + +[case testMatchMappingPatternCaptures] from typing import Dict import b m: Dict[str, int] @@ -342,7 +342,7 @@ match m: b: str [builtins fixtures/dict.pyi] -[case testMappingPatternCapturesWrongKeyType] +[case testMatchMappingPatternCapturesWrongKeyType] # This is not actually unreachable, as a subclass of dict could accept keys with different types from typing import Dict import b @@ -357,7 +357,7 @@ match m: b: int [builtins fixtures/dict.pyi] -[case testMappingPatternCapturesTypedDict] +[case testMatchMappingPatternCapturesTypedDict] from typing import TypedDict class A(TypedDict): @@ -378,7 +378,7 @@ match m: reveal_type(v5) # N: Revealed type is "builtins.object*" [typing fixtures/typing-typeddict.pyi] -[case testMappingPatternCapturesTypedDictWithLiteral] +[case testMatchMappingPatternCapturesTypedDictWithLiteral] from typing import TypedDict import b @@ -405,7 +405,7 @@ b: Literal["b"] = "b" o: Final[str] = "o" [typing fixtures/typing-typeddict.pyi] -[case testMappingPatternCapturesTypedDictWithNonLiteral] +[case testMatchMappingPatternCapturesTypedDictWithNonLiteral] from typing import TypedDict import b @@ -423,7 +423,7 @@ from typing import Final, Literal a: str [typing fixtures/typing-typeddict.pyi] -[case testMappingPatternCapturesTypedDictUnreachable] +[case testMatchMappingPatternCapturesTypedDictUnreachable] # TypedDict keys are always str, so this is actually unreachable from typing import TypedDict import b @@ -443,7 +443,7 @@ match m: b: int [typing fixtures/typing-typeddict.pyi] -[case testMappingPatternCaptureRest] +[case testMatchMappingPatternCaptureRest] m: object match m: @@ -451,7 +451,7 @@ match m: reveal_type(r) # N: Revealed type is "builtins.dict[builtins.object, builtins.object]" [builtins fixtures/dict.pyi] -[case testMappingPatternCaptureRestFromMapping] +[case testMatchMappingPatternCaptureRestFromMapping] from typing import Mapping m: Mapping[str, int] @@ -461,10 +461,11 @@ match m: reveal_type(r) # N: Revealed type is "builtins.dict[builtins.str*, builtins.int*]" [builtins fixtures/dict.pyi] --- Mapping patterns currently don't narrow -- +-- Mapping patterns currently do not narrow -- -- Class Pattern -- -[case testClassPatternCapturePositional] + +[case testMatchClassPatternCapturePositional] from typing import Final class A: @@ -480,7 +481,7 @@ match m: reveal_type(j) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] -[case testClassPatternMemberClassCapturePositional] +[case testMatchClassPatternMemberClassCapturePositional] import b m: b.A @@ -498,7 +499,7 @@ class A: b: int [builtins fixtures/tuple.pyi] -[case testClassPatternCaptureKeyword] +[case testMatchClassPatternCaptureKeyword] class A: a: str b: int @@ -510,7 +511,7 @@ match m: reveal_type(i) # N: Revealed type is "builtins.str" reveal_type(j) # N: Revealed type is "builtins.int" -[case testClassPatternCaptureSelf] +[case testMatchClassPatternCaptureSelf] m: object match m: @@ -538,7 +539,7 @@ match m: reveal_type(k) # N: Revealed type is "builtins.tuple[Any, ...]" [builtins fixtures/primitives.pyi] -[case testClassPatternNarrowSelfCapture] +[case testMatchClassPatternNarrowSelfCapture] m: object match m: @@ -566,7 +567,7 @@ match m: reveal_type(m) # N: Revealed type is "builtins.tuple[Any, ...]" [builtins fixtures/primitives.pyi] -[case testClassPatternCaptureDataclass] +[case testMatchClassPatternCaptureDataclass] from dataclasses import dataclass @dataclass @@ -582,7 +583,7 @@ match m: reveal_type(j) # N: Revealed type is "builtins.int" [builtins fixtures/dataclasses.pyi] -[case testClassPatternCaptureDataclassNoMatchArgs] +[case testMatchClassPatternCaptureDataclassNoMatchArgs] from dataclasses import dataclass @dataclass(match_args=False) @@ -597,7 +598,7 @@ match m: pass [builtins fixtures/dataclasses.pyi] -[case testClassPatternCaptureDataclassPartialMatchArgs] +[case testMatchClassPatternCaptureDataclassPartialMatchArgs] from dataclasses import dataclass, field @dataclass @@ -614,7 +615,7 @@ match m: reveal_type(k) # N: Revealed type is "builtins.str" [builtins fixtures/dataclasses.pyi] -[case testClassPatternCaptureNamedTupleInline] +[case testMatchClassPatternCaptureNamedTupleInline] from collections import namedtuple A = namedtuple("A", ["a", "b"]) @@ -627,7 +628,7 @@ match m: reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/list.pyi] -[case testClassPatternCaptureNamedTupleInlineTyped] +[case testMatchClassPatternCaptureNamedTupleInlineTyped] from typing import NamedTuple A = NamedTuple("A", [("a", str), ("b", int)]) @@ -640,7 +641,7 @@ match m: reveal_type(j) # N: Revealed type is "builtins.int" [builtins fixtures/list.pyi] -[case testClassPatternCaptureNamedTupleClass] +[case testMatchClassPatternCaptureNamedTupleClass] from typing import NamedTuple class A(NamedTuple): @@ -655,7 +656,7 @@ match m: reveal_type(j) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] -[case testClassPatternCaptureGeneric] +[case testMatchClassPatternCaptureGeneric] from typing import Generic, TypeVar T = TypeVar('T') @@ -670,7 +671,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.A[Any]" reveal_type(i) # N: Revealed type is "Any" -[case testClassPatternCaptureGenericAlreadyKnown] +[case testMatchClassPatternCaptureGenericAlreadyKnown] from typing import Generic, TypeVar T = TypeVar('T') @@ -685,7 +686,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.A[builtins.int]" reveal_type(i) # N: Revealed type is "builtins.int*" -[case testClassPatternCaptureFilledGenericTypeAlias] +[case testMatchClassPatternCaptureFilledGenericTypeAlias] from typing import Generic, TypeVar T = TypeVar('T') @@ -701,7 +702,7 @@ match m: case B(a=i): # E: Class pattern class must not be a type alias with type parameters reveal_type(i) -[case testClassPatternCaptureGenericTypeAlias] +[case testMatchClassPatternCaptureGenericTypeAlias] from typing import Generic, TypeVar T = TypeVar('T') @@ -717,7 +718,7 @@ match m: case B(a=i): pass -[case testClassPatternNarrows] +[case testMatchClassPatternNarrows] from typing import Final class A: @@ -734,7 +735,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.A" [builtins fixtures/tuple.pyi] -[case testClassPatternNarrowsUnion] +[case testMatchClassPatternNarrowsUnion] from typing import Final, Union class A: @@ -770,7 +771,7 @@ match m: reveal_type(l) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] -[case testClassPatternAlreadyNarrower] +[case testMatchClassPatternAlreadyNarrower] from typing import Final class A: @@ -790,7 +791,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.B" [builtins fixtures/tuple.pyi] -[case testClassPatternIntersection] +[case testMatchClassPatternIntersection] from typing import Final class A: @@ -808,7 +809,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.1" [builtins fixtures/tuple.pyi] -[case testClassPatternNonexistentKeyword] +[case testMatchClassPatternNonexistentKeyword] class A: ... m: object @@ -818,7 +819,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.A" reveal_type(j) # N: Revealed type is "Any" -[case testClassPatternDuplicateKeyword] +[case testMatchClassPatternDuplicateKeyword] class A: a: str @@ -828,7 +829,7 @@ match m: case A(a=i, a=j): # E: Duplicate keyword pattern "a" pass -[case testClassPatternDuplicateImplicitKeyword] +[case testMatchClassPatternDuplicateImplicitKeyword] from typing import Final class A: @@ -842,7 +843,7 @@ match m: pass [builtins fixtures/tuple.pyi] -[case testClassPatternTooManyPositionals] +[case testMatchClassPatternTooManyPositionals] from typing import Final class A: @@ -857,7 +858,7 @@ match m: pass [builtins fixtures/tuple.pyi] -[case testClassPatternIsNotType] +[case testMatchClassPatternIsNotType] a = 1 m: object @@ -866,7 +867,7 @@ match m: reveal_type(i) reveal_type(j) -[case testClassPatternNestedGenerics] +[case testMatchClassPatternNestedGenerics] # From cpython test_patma.py x = [[{0: 0}]] match x: @@ -878,7 +879,7 @@ reveal_type(y) # N: Revealed type is "builtins.int" reveal_type(z) # N: Revealed type is "builtins.int*" [builtins fixtures/dict.pyi] -[case testNonFinalMatchArgs] +[case testMatchNonFinalMatchArgs] class A: __match_args__ = ("a", "b") # N: __match_args__ must be final for checking of match statements to work a: str @@ -892,7 +893,7 @@ match m: reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] -[case testAnyTupleMatchArgs] +[case testMatchAnyTupleMatchArgs] from typing import Tuple, Any class A: @@ -909,7 +910,7 @@ match m: reveal_type(k) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] -[case testNonLiteralMatchArgs] +[case testMatchNonLiteralMatchArgs] from typing import Final b: str = "b" @@ -928,7 +929,7 @@ match m: reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] -[case testExternalMatchArgs] +[case testMatchExternalMatchArgs] from typing import Final, Literal args: Final = ("a", "b") @@ -947,9 +948,9 @@ class B: [builtins fixtures/tuple.pyi] [typing fixtures/typing-medium.pyi] - -- As Pattern -- -[case testAsPattern] + +[case testMatchAsPattern] m: int match m: @@ -957,51 +958,51 @@ match m: reveal_type(x) # N: Revealed type is "builtins.int" reveal_type(l) # N: Revealed type is "builtins.int" -[case testAsPatternNarrows] +[case testMatchAsPatternNarrows] m: object match m: case int() as l: reveal_type(l) # N: Revealed type is "builtins.int" -[case testAsPatternCapturesOr] +[case testMatchAsPatternCapturesOr] m: object match m: case 1 | 2 as n: reveal_type(n) # N: Revealed type is "Union[Literal[1], Literal[2]]" -[case testAsPatternAlreadyNarrower] +[case testMatchAsPatternAlreadyNarrower] m: bool match m: case int() as l: reveal_type(l) # N: Revealed type is "builtins.bool" - -- Or Pattern -- -[case testOrPatternNarrows] + +[case testMatchOrPatternNarrows] m: object match m: case 1 | 2: reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]" -[case testOrPatternNarrowsStr] +[case testMatchOrPatternNarrowsStr] m: object match m: case "foo" | "bar": reveal_type(m) # N: Revealed type is "Union[Literal['foo'], Literal['bar']]" -[case testOrPatternNarrowsUnion] +[case testMatchOrPatternNarrowsUnion] m: object match m: case 1 | "foo": reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal['foo']]" -[case testOrPatterCapturesMissing] +[case testMatchOrPatterCapturesMissing] from typing import List m: List[int] @@ -1011,7 +1012,7 @@ match m: reveal_type(y) # N: Revealed type is "builtins.int*" [builtins fixtures/list.pyi] -[case testOrPatternCapturesJoin] +[case testMatchOrPatternCapturesJoin] m: object match m: @@ -1019,9 +1020,9 @@ match m: reveal_type(x) # N: Revealed type is "typing.Iterable[Any]" [builtins fixtures/dict.pyi] - -- Interactions -- -[case testCapturePatternMultipleCases] + +[case testMatchCapturePatternMultipleCases] m: object match m: @@ -1032,7 +1033,7 @@ match m: reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" -[case testCapturePatternMultipleCaptures] +[case testMatchCapturePatternMultipleCaptures] from typing import Iterable m: Iterable[int] @@ -1042,7 +1043,7 @@ match m: reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/list.pyi] -[case testCapturePatternPreexistingSame] +[case testMatchCapturePatternPreexistingSame] a: int m: int @@ -1050,7 +1051,7 @@ match m: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testCapturePatternPreexistingNarrows] +[case testMatchCapturePatternPreexistingNarrows] a: int m: bool @@ -1062,7 +1063,7 @@ reveal_type(a) # N: Revealed type is "builtins.bool" a = 3 reveal_type(a) # N: Revealed type is "builtins.int" -[case testCapturePatternPreexistingIncompatible] +[case testMatchCapturePatternPreexistingIncompatible] a: str m: int @@ -1072,7 +1073,7 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" -[case testCapturePatternPreexistingIncompatibleLater] +[case testMatchCapturePatternPreexistingIncompatibleLater] a: str m: object @@ -1084,9 +1085,9 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" - -- Guards -- -[case testSimplePatternGuard] + +[case testMatchSimplePatternGuard] m: str def guard() -> bool: ... @@ -1095,21 +1096,21 @@ match m: case a if guard(): reveal_type(a) # N: Revealed type is "builtins.str" -[case testAlwaysTruePatternGuard] +[case testMatchAlwaysTruePatternGuard] m: str match m: case a if True: reveal_type(a) # N: Revealed type is "builtins.str" -[case testAlwaysFalsePatternGuard] +[case testMatchAlwaysFalsePatternGuard] m: str match m: case a if False: reveal_type(a) -[case testRedefiningPatternGuard] +[case testMatchRedefiningPatternGuard] # flags: --strict-optional m: str @@ -1117,14 +1118,14 @@ match m: case a if a := 1: # E: Incompatible types in assignment (expression has type "int", variable has type "str") reveal_type(a) # N: Revealed type is "" -[case testAssigningPatternGuard] +[case testMatchAssigningPatternGuard] m: str match m: case a if a := "test": reveal_type(a) # N: Revealed type is "builtins.str" -[case testNarrowingPatternGuard] +[case testMatchNarrowingPatternGuard] m: object match m: @@ -1132,7 +1133,7 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" [builtins fixtures/isinstancelist.pyi] -[case testIncompatiblePatternGuard] +[case testMatchIncompatiblePatternGuard] class A: ... class B: ... @@ -1143,7 +1144,7 @@ match m: reveal_type(a) # N: Revealed type is "__main__." [builtins fixtures/isinstancelist.pyi] -[case testUnreachablePatternGuard] +[case testMatchUnreachablePatternGuard] m: str match m: @@ -1153,7 +1154,7 @@ match m: -- Exhaustiveness -- -[case testUnionNegativeNarrowing] +[case testMatchUnionNegativeNarrowing] from typing import Union m: Union[str, int] @@ -1166,7 +1167,7 @@ match m: reveal_type(b) # N: Revealed type is "builtins.int" reveal_type(m) # N: Revealed type is "builtins.int" -[case testOrPatternNegativeNarrowing] +[case testMatchOrPatternNegativeNarrowing] from typing import Union m: Union[str, bytes, int] @@ -1178,7 +1179,7 @@ match m: case b: reveal_type(b) # N: Revealed type is "builtins.int" -[case testExhaustiveReturn] +[case testMatchExhaustiveReturn] def foo(value) -> int: match value: case "bar": @@ -1186,7 +1187,7 @@ def foo(value) -> int: case _: return 2 -[case testNonExhaustiveReturn] +[case testMatchNonExhaustiveReturn] def foo(value) -> int: # E: Missing return statement match value: case "bar": @@ -1194,14 +1195,7 @@ def foo(value) -> int: # E: Missing return statement case 2: return 2 -[case testMoreExhaustiveReturnChecking] -def f(value: int | str | None) -> int: # E: Missing return statement - match value: - case int(): - return 0 - case None: - return 1 - +[case testMatchMoreExhaustiveReturnCases] def g(value: int | None) -> int: match value: case int(): @@ -1216,77 +1210,43 @@ def b(value: bool) -> int: case False: return 3 -[case testMiscNonExhaustiveReturn] +[case testMatchMiscNonExhaustiveReturn] class C: a: int | str -def f(c: C) -> int: # E: Missing return statement +def f1(value: int | str | None) -> int: # E: Missing return statement + match value: + case int(): + return 0 + case None: + return 1 + +def f2(c: C) -> int: # E: Missing return statement match c: case C(a=int()): return 0 case C(a=str()): return 1 -def g(x: list[str]) -> int: # E: Missing return statement +def f3(x: list[str]) -> int: # E: Missing return statement match x: case [a]: return 0 case [a, b]: return 1 -def h(x: dict[str, int]) -> int: # E: Missing return statement +def f4(x: dict[str, int]) -> int: # E: Missing return statement match x: case {'x': a}: return 0 -def ff(x: bool) -> int: # E: Missing return statement +def f5(x: bool) -> int: # E: Missing return statement match x: case True: return 0 [builtins fixtures/dict.pyi] -[case testWithStatementScopeAndMatchStatement] -from m import A, B - -with A() as x: - pass -with B() as x: \ - # E: Incompatible types in assignment (expression has type "B", variable has type "A") - pass - -with A() as y: - pass -with B() as y: \ - # E: Incompatible types in assignment (expression has type "B", variable has type "A") - pass - -with A() as z: - pass -with B() as z: \ - # E: Incompatible types in assignment (expression has type "B", variable has type "A") - pass - -with A() as zz: - pass -with B() as zz: \ - # E: Incompatible types in assignment (expression has type "B", variable has type "A") - pass - -match x: - case str(y) as z: - zz = y - -[file m.pyi] -from typing import Any - -class A: - def __enter__(self) -> A: ... - def __exit__(self, x, y, z) -> None: ... -class B: - def __enter__(self) -> B: ... - def __exit__(self, x, y, z) -> None: ... - -[case testNonExhaustiveError] +[case testMatchNonExhaustiveError] from typing import NoReturn def assert_never(x: NoReturn) -> None: ... @@ -1299,7 +1259,7 @@ def f(value: int) -> int: # E: Missing return statement case o: assert_never(o) # E: Argument 1 to "assert_never" has incompatible type "int"; expected "NoReturn" -[case testExhaustiveNoError] +[case testMatchExhaustiveNoError] from typing import NoReturn, Union, Literal def assert_never(x: NoReturn) -> None: ... @@ -1313,7 +1273,7 @@ def f(value: Literal[1] | Literal[2]) -> int: assert_never(o) [typing fixtures/typing-medium.pyi] -[case testSequencePatternNegativeNarrowing] +[case testMatchSequencePatternNegativeNarrowing] from typing import Union, Sequence, Tuple m1: Sequence[int | str] @@ -1341,7 +1301,7 @@ match m3: reveal_type(m3) # N: Revealed type is "Tuple[Union[builtins.int, builtins.str]]" [builtins fixtures/tuple.pyi] -[case testLiteralPatternEnumNegativeNarrowing] +[case testMatchLiteralPatternEnumNegativeNarrowing] from enum import Enum class Medal(Enum): gold = 1 @@ -1366,7 +1326,7 @@ def g(m: Medal) -> int: case Medal.bronze: return 2 -[case testLiteralPatternEnumCustomEquals-skip] +[case testMatchLiteralPatternEnumCustomEquals-skip] from enum import Enum class Medal(Enum): gold = 1 @@ -1383,7 +1343,7 @@ match m: case _: reveal_type(m) # N: Revealed type is "__main__.Medal" -[case testNarrowUsingPatternGuardSpecialCase] +[case testMatchNarrowUsingPatternGuardSpecialCase] def f(x: int | str) -> int: # E: Missing return statement match x: case x if isinstance(x, str): @@ -1392,7 +1352,7 @@ def f(x: int | str) -> int: # E: Missing return statement return 1 [builtins fixtures/isinstance.pyi] -[case testNarrowingDownUnionPartially] +[case testMatchNarrowDownUnionPartially] # flags: --strict-optional def f(x: int | str) -> None: @@ -1413,7 +1373,7 @@ def h(x: int | str | None) -> None: return reveal_type(x) # N: Revealed type is "None" -[case testNarrowDownUsingLiteralMatch] +[case testMatchNarrowDownUsingLiteralMatch] from enum import Enum class Medal(Enum): gold = 1 @@ -1481,3 +1441,46 @@ def g(c: C) -> int: case C(a=str()): return 1 assert False + +-- Misc + +[case testMatchAndWithStatementScope] +from m import A, B + +with A() as x: + pass +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as y: + pass +with B() as y: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as z: + pass +with B() as z: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as zz: + pass +with B() as zz: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +match x: + case str(y) as z: + zz = y + +[file m.pyi] +from typing import Any + +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... From 8eb2dc9e41d9661876b01a50155fdfaabb85e32c Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 14:04:06 +0000 Subject: [PATCH 15/23] Some clean-up --- mypy/checker.py | 24 ++++++++---------------- mypy/checkpattern.py | 3 ++- test-data/unit/check-python310.test | 6 +++--- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1e74d5a546e7..a4ea32b84f51 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4089,18 +4089,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None: if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) - # We have to check each pattern twice. Once ignoring the guard statement to infer - # the capture types and once with then to narrow the subject. - # In addition PatternChecker adds intersection types to the scope. We only want that - # to happen on the second pass, so we copy the SymbolTable beforehand. - curr_module = self.scope.stack[0] - assert isinstance(curr_module, MypyFile) - names = curr_module.names.copy() pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] - curr_module.names = names type_maps: List[TypeMap] = [t.captures for t in pattern_types] - inferred_names = self.infer_variable_types_from_type_maps(type_maps) + inferred_types = self.infer_variable_types_from_type_maps(type_maps) for p, g, b in zip(s.patterns, s.guards, s.bodies): current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject, @@ -4118,7 +4110,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_type.rest_type ) self.remove_capture_conflicts(pattern_type.captures, - inferred_names) + inferred_types) self.push_type_map(pattern_map) self.push_type_map(pattern_type.captures) if g is not None: @@ -4144,7 +4136,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]: all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) - inferred_names: Dict[Var, Type] = {} + inferred_types: Dict[Var, Type] = {} for tm in type_maps: if tm is not None: for expr, typ in tm.items(): @@ -4166,23 +4158,23 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[ msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, subtype_label="pattern captures type", supertype_label="variable has type"): - inferred_names[var] = previous_type + inferred_types[var] = previous_type if not already_exists: new_type = UnionType.make_union(types) # Infer the union type at the first occurrence first_occurrence, _ = captures[0] - inferred_names[var] = new_type + inferred_types[var] = new_type self.infer_variable_type(var, first_occurrence, new_type, first_occurrence) - return inferred_names + return inferred_types - def remove_capture_conflicts(self, type_map: TypeMap, inferred_names: Dict[Var, Type]) -> None: + def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None: if type_map is not None: for expr, typ in type_map.copy().items(): if isinstance(expr, NameExpr): node = expr.node assert isinstance(node, Var) - if node not in inferred_names or not is_subtype(typ, inferred_names[node]): + if node not in inferred_types or not is_subtype(typ, inferred_types[node]): del type_map[expr] def make_fake_typeinfo(self, diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 81ab81001421..a1e060ea3721 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,4 +1,5 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" + from collections import defaultdict from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union from typing_extensions import Final @@ -56,7 +57,7 @@ 'PatternType', [ ('type', Type), # The type the match subject can be narrowed to - ('rest_type', Type), + ('rest_type', Type), # The remaining type if the pattern didn't match ('captures', Dict[Expression, Type]), # The variables captured by the pattern ]) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index bd828b785055..dcf6d9350ff0 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -64,7 +64,7 @@ m: A match m: case b.b: - reveal_type(m) # N: Revealed type is "__main__." + reveal_type(m) # N: Revealed type is "__main__.1" [file b.py] class B: ... b: B @@ -804,9 +804,9 @@ m: B match m: case A(): - reveal_type(m) # N: Revealed type is "__main__." + reveal_type(m) # N: Revealed type is "__main__.2" case A(i, j): - reveal_type(m) # N: Revealed type is "__main__.1" + reveal_type(m) # N: Revealed type is "__main__.3" [builtins fixtures/tuple.pyi] [case testMatchClassPatternNonexistentKeyword] From 1de6aa3d2cfd1b8ad6c728edb0996e6aa674541e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 14:18:04 +0000 Subject: [PATCH 16/23] More cleanup --- mypy/checker.py | 9 +++++++-- mypy/checkpattern.py | 3 +-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a4ea32b84f51..dfb1a8f532d1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4089,11 +4089,16 @@ def visit_match_stmt(self, s: MatchStmt) -> None: if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) + # We infer types of patterns twice. The first pass is used + # to infer the types of capture variables. The type of a + # capture variable may depend on multiple patterns (it + # will be a union of all capture types). This pass ignores + # guard expressions. pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] - type_maps: List[TypeMap] = [t.captures for t in pattern_types] inferred_types = self.infer_variable_types_from_type_maps(type_maps) + # The second pass narrows down the types and type checks bodies. for p, g, b in zip(s.patterns, s.guards, s.bodies): current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject, subject_type) @@ -4136,7 +4141,6 @@ def visit_match_stmt(self, s: MatchStmt) -> None: def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]: all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) - inferred_types: Dict[Var, Type] = {} for tm in type_maps: if tm is not None: for expr, typ in tm.items(): @@ -4145,6 +4149,7 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[ assert isinstance(node, Var) all_captures[node].append((expr, typ)) + inferred_types: Dict[Var, Type] = {} for var, captures in all_captures.items(): already_exists = False types: List[Type] = [] diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index a1e060ea3721..9c6e67db03e1 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -355,8 +355,7 @@ def expand_starred_pattern_types(self, star_pos: Optional[int], num_types: int ) -> List[Type]: - """ - Undoes the contraction done by contract_starred_pattern_types. + """Undoes the contraction done by contract_starred_pattern_types. For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended to lenght 4 the result is [bool, int, int, str]. From a79c4c277eaef09199ebe3ec06e369b03b4a15e8 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 14:39:20 +0000 Subject: [PATCH 17/23] Add failing test case --- test-data/unit/check-python310.test | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index dcf6d9350ff0..fcfcc1615799 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1442,6 +1442,19 @@ def g(c: C) -> int: return 1 assert False +[case testMatchAsPatternIntersection-skip] +class A: pass +class B: pass +class C: pass + +def f(x: A) -> None: + match x: + case B() as y: + reveal_type(y) # N: Revealed type is "__main__." + case C() as y: + reveal_type(y) # N: Revealed type is "__main__." + reveal_type(y) # N: Revealed type is "Union[__main__., __main__.]" + -- Misc [case testMatchAndWithStatementScope] From 14fd60c60a155b37403dba3315a94800d5db37bb Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 14:49:40 +0000 Subject: [PATCH 18/23] Add break/continue test case --- test-data/unit/check-python310.test | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index fcfcc1615799..784c4171c6e0 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1455,6 +1455,19 @@ def f(x: A) -> None: reveal_type(y) # N: Revealed type is "__main__." reveal_type(y) # N: Revealed type is "Union[__main__., __main__.]" +[case testMatchWithBreakAndContinue] +# flags: --strict-optional +def f(x: int | str | None) -> None: + i = int() + while i: + match x: + case int(): + continue + case str(): + break + reveal_type(x) # N: Revealed type is "None" + reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str, None]" + -- Misc [case testMatchAndWithStatementScope] From 4a96bca6f76dea7dbbe367dfbd63de28a0bd7f31 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 14:52:07 +0000 Subject: [PATCH 19/23] Add test case --- test-data/unit/check-python310.test | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 784c4171c6e0..350ccfddcd9d 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1442,6 +1442,14 @@ def g(c: C) -> int: return 1 assert False +[case testMatchAsPatternExhaustiveness] +def f(x: int | str) -> int: + match x: + case int() as n: + return n + case str() as s: + return 1 + [case testMatchAsPatternIntersection-skip] class A: pass class B: pass From e21b1ce996a02df83045e293782622ae1ec7e854 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 15:00:23 +0000 Subject: [PATCH 20/23] Add failing test case --- test-data/unit/check-python310.test | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 350ccfddcd9d..9d56aeb468f7 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1476,6 +1476,16 @@ def f(x: int | str | None) -> None: reveal_type(x) # N: Revealed type is "None" reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str, None]" +[case testMatchNarrowDownWithStarred-skip] +from typing import List +def f(x: List[int] | int) -> None: + match x: + case [*y]: + reveal_type(y) # N: Revealed type is "builtins.list[builtins.int*]" + return + reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + -- Misc [case testMatchAndWithStatementScope] From 0892388304c98caa3ab05d86e4f25b8d4d107fbd Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 1 Mar 2022 16:53:42 +0000 Subject: [PATCH 21/23] Small tweaks --- mypy/checker.py | 5 ++--- mypy/typeops.py | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index dfb1a8f532d1..ccdc3b35340f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4174,7 +4174,7 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[ return inferred_types def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None: - if type_map is not None: + if type_map: for expr, typ in type_map.copy().items(): if isinstance(expr, NameExpr): node = expr.node @@ -5675,8 +5675,7 @@ def conditional_types(current_type: Type, or isinstance(target.value, bool)): enum_name = target.fallback.type.fullname current_type = try_expanding_sum_type_to_union(current_type, - enum_name, - ignore_custom_equals=False) + enum_name) proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) if isinstance(proposed_type, AnyType): diff --git a/mypy/typeops.py b/mypy/typeops.py index 2ede9194ac26..57fdfeadad9a 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -683,10 +683,7 @@ def is_singleton_type(typ: Type) -> bool: ) -def try_expanding_sum_type_to_union(typ: Type, - target_fullname: str, - *, - ignore_custom_equals: bool = True) -> ProperType: +def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType: """Attempts to recursively expand any enum Instances with the given target_fullname into a Union of all of its component LiteralTypes. From 1d2a037115645bc58119b109f9a17e08b4edaa0b Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Thu, 3 Mar 2022 11:23:40 +0000 Subject: [PATCH 22/23] Remove unneeded change in subtype checking --- mypy/subtypes.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 0a84d0112ea8..a261e3712328 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1237,13 +1237,11 @@ def _is_proper_subtype(left: Type, right: Type, *, class ProperSubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, *, ignore_promotions: bool = False, - ignore_last_known_value: bool = False, erase_instances: bool = False, keep_erased_types: bool = False) -> None: self.right = get_proper_type(right) self.orig_right = right self.ignore_promotions = ignore_promotions - self.ignore_last_known_value = ignore_last_known_value self.erase_instances = erase_instances self.keep_erased_types = keep_erased_types self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( @@ -1299,10 +1297,6 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(right, Instance): if TypeState.is_cached_subtype_check(self._subtype_kind, left, right): return True - if not self.ignore_last_known_value: - if right.last_known_value is not None and \ - right.last_known_value != left.last_known_value: - return False if not self.ignore_promotions: for base in left.type.mro: if base._promote and self._is_proper_subtype(base._promote, right): From a2fa4aefca09f3df1662699a114bbba7872ced32 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Thu, 3 Mar 2022 11:25:04 +0000 Subject: [PATCH 23/23] Address feedback --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index ccdc3b35340f..31dcf985200b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4175,7 +4175,7 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None: if type_map: - for expr, typ in type_map.copy().items(): + for expr, typ in list(type_map.items()): if isinstance(expr, NameExpr): node = expr.node assert isinstance(node, Var)