From db11aa5c13884bf3a9005767d5a00c428eabae29 Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Thu, 11 Nov 2021 01:33:19 -0800 Subject: [PATCH] Allow narrowing enum values using == Resolves #10915, resolves #9786 See the discussion in #10915. I'm sympathetic to the difference between identity and equality here being surprising and that mypy doesn't usually make concessions to mutability when type checking. The old test cases are pretty explicit about their intentions and are worth reading. Curious to see what people (and mypy-primer) have to say about this. --- mypy/checker.py | 10 ++++- test-data/unit/check-narrowing.test | 57 +++++++++++++++++------------ 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 95789831fd6f..376d84cbcac0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4802,7 +4802,15 @@ def refine_identity_comparison_expression(self, """ should_coerce = True if coerce_only_in_literal_context: - should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices) + + def should_coerce_inner(typ: Type) -> bool: + typ = get_proper_type(typ) + return is_literal_type_like(typ) or ( + isinstance(typ, Instance) + and typ.type.is_enum + ) + + should_coerce = any(should_coerce_inner(operand_types[i]) for i in chain_indices) target: Optional[Type] = None possible_target_indices = [] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index d6b25ef456d9..5651ac7d5d90 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -699,47 +699,47 @@ class FlipFlopStr: def mutate(self) -> None: self.state = "state-2" if self.state == "state-1" else "state-1" -def test1(switch: FlipFlopEnum) -> None: + +def test1(switch: FlipFlopStr) -> None: # Naively, we might assume the 'assert' here would narrow the type to - # Literal[State.A]. However, doing this ends up breaking a fair number of real-world + # Literal["state-1"]. However, doing this ends up breaking a fair number of real-world # code (usually test cases) that looks similar to this function: e.g. checks # to make sure a field was mutated to some particular value. # # And since mypy can't really reason about state mutation, we take a conservative # approach and avoid narrowing anything here. - assert switch.state == State.A - reveal_type(switch.state) # N: Revealed type is "__main__.State" + assert switch.state == "state-1" + reveal_type(switch.state) # N: Revealed type is "builtins.str" switch.mutate() - assert switch.state == State.B - reveal_type(switch.state) # N: Revealed type is "__main__.State" + assert switch.state == "state-2" + reveal_type(switch.state) # N: Revealed type is "builtins.str" def test2(switch: FlipFlopEnum) -> None: - # So strictly speaking, we ought to do the same thing with 'is' comparisons - # for the same reasons as above. But in practice, not too many people seem to - # know that doing 'some_enum is MyEnum.Value' is idiomatic. So in practice, - # this is probably good enough for now. + # This is the same thing as 'test1', except we use enums, which we allow to be narrowed + # to literals. - assert switch.state is State.A + assert switch.state == State.A reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]" switch.mutate() - assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]") + assert switch.state == State.B # E: Non-overlapping equality check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]") reveal_type(switch.state) # E: Statement is unreachable -def test3(switch: FlipFlopStr) -> None: - # This is the same thing as 'test1', except we try using str literals. +def test3(switch: FlipFlopEnum) -> None: + # Same thing, but using 'is' comparisons. Previously mypy's behaviour differed + # here, narrowing when using 'is', but not when using '=='. - assert switch.state == "state-1" - reveal_type(switch.state) # N: Revealed type is "builtins.str" + assert switch.state is State.A + reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]" switch.mutate() - assert switch.state == "state-2" - reveal_type(switch.state) # N: Revealed type is "builtins.str" + assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]") + reveal_type(switch.state) # E: Statement is unreachable [builtins fixtures/primitives.pyi] [case testNarrowingEqualityRequiresExplicitStrLiteral] @@ -791,6 +791,7 @@ reveal_type(x_union) # N: Revealed type is "Union[Literal['A'], Literal['B' [case testNarrowingEqualityRequiresExplicitEnumLiteral] # flags: --strict-optional +from typing import Union from typing_extensions import Literal, Final from enum import Enum @@ -801,19 +802,19 @@ class Foo(Enum): A_final: Final = Foo.A A_literal: Literal[Foo.A] -# See comments in testNarrowingEqualityRequiresExplicitStrLiteral and -# testNarrowingEqualityFlipFlop for more on why we can't narrow here. +# Note this is unlike testNarrowingEqualityRequiresExplicitStrLiteral +# See also testNarrowingEqualityFlipFlop x1: Foo if x1 == Foo.A: - reveal_type(x1) # N: Revealed type is "__main__.Foo" + reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.A]" else: - reveal_type(x1) # N: Revealed type is "__main__.Foo" + reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.B]" x2: Foo if x2 == A_final: - reveal_type(x2) # N: Revealed type is "__main__.Foo" + reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.A]" else: - reveal_type(x2) # N: Revealed type is "__main__.Foo" + reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.B]" # But we let this narrow since there's an explicit literal in the RHS. x3: Foo @@ -821,6 +822,14 @@ if x3 == A_literal: reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.A]" else: reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.B]" + + +class SingletonFoo(Enum): + A = "A" + +def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None: + if x == y: + reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]" [builtins fixtures/primitives.pyi] [case testNarrowingEqualityDisabledForCustomEquality]