10000 Allow narrowing enum values using == by hauntsaninja · Pull Request #11521 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Allow narrowing enum values using == #11521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
10000 Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4802,7 +4802,15 @@ def refine_identity_comparison_expression(self,
"""
should_coerce = True
if coerce_only_in_literal_context:
should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices)

def should_coerce_inner(typ: Type) -> bool:
typ = get_proper_type(typ)
return is_literal_type_like(typ) or (
isinstance(typ, Instance)
and typ.type.is_enum
)

should_coerce = any(should_coerce_inner(operand_types[i]) for i in chain_indices)

target: Optional[Type] = None
possible_target_indices = []
Expand Down
57 changes: 33 additions & 24 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -699,47 +699,47 @@ class FlipFlopStr:
def mutate(self) -> None:
self.state = "state-2" if self.state == "state-1" else "state-1"

def test1(switch: FlipFlopEnum) -> None:

def test1(switch: FlipFlopStr) -> None:
# Naively, we might assume the 'assert' here would narrow the type to
# Literal[State.A]. However, doing this ends up breaking a fair number of real-world
# Literal["state-1"]. However, doing this ends up breaking a fair number of real-world
# code (usually test cases) that looks similar to this function: e.g. checks
# to make sure a field was mutated to some particular value.
#
# And since mypy can't really reason about state mutation, we take a conservative
# approach and avoid narrowing anything here.

assert switch.state == State.A
reveal_type(switch.state) # N: Revealed type is "__main__.State"
assert switch.state == "state-1"
reveal_type(switch.state) # N: Revealed type is "builtins.str"

switch.mutate()

assert switch.state == State.B
reveal_type(switch.state) # N: Revealed type is "__main__.State"
assert switch.state == "state-2"
reveal_type(switch.state) # N: Revealed type is "builtins.str"

def test2(switch: FlipFlopEnum) -> None:
# So strictly speaking, we ought to do the same thing with 'is' comparisons
# for the same reasons as above. But in practice, not too many people seem to
# know that doing 'some_enum is MyEnum.Value' is idiomatic. So in practice,
# this is probably good enough for now.
# This is the same thing as 'test1', except we use enums, which we allow to be narrowed
# to literals.

assert switch.state is State.A
assert switch.state == State.A
reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]"

switch.mutate()

assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
assert switch.state == State.B # E: Non-overlapping equality check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
reveal_type(switch.state) # E: Statement is unreachable

def test3(switch: FlipFlopStr) -> None:
# This is the same thing as 'test1', except we try using str literals.
def test3(switch: FlipFlopEnum) -> None:
# Same thing, but using 'is' comparisons. Previously mypy's behaviour differed
# here, narrowing when using 'is', but not when using '=='.

assert switch.state == "state-1"
reveal_type(switch.state) # N: Revealed type is "builtins.str"
assert switch.state is State.A
reveal_type(switch.state) # N: Revealed type is "Literal[__main__.State.A]"

switch.mutate()

assert switch.state == "state-2"
reveal_type(switch.state) # N: Revealed type is "builtins.str"
assert switch.state is State.B # E: Non-overlapping identity check (left operand type: "Literal[State.A]", right operand type: "Literal[State.B]")
reveal_type(switch.state) # E: Statement is unreachable
[builtins fixtures/primitives.pyi]

[case testNarrowingEqualityRequiresExplicitStrLiteral]
Expand Down Expand Up @@ -791,6 +791,7 @@ reveal_type(x_union) # N: Revealed type is "Union[Literal['A'], Literal['B'

[case testNarrowingEqualityRequiresExplicitEnumLiteral]
# flags: --strict-optional
from typing import Union
from typing_extensions import Literal, Final
from enum import Enum

Expand All @@ -801,26 +802,34 @@ class Foo(Enum):
A_final: Final = Foo.A
A_literal: Literal[Foo.A]

# See comments in testNarrowingEqualityRequiresExplicitStrLiteral and
# testNarrowingEqualityFlipFlop for more on why we can't narrow here.
# Note this is unlike testNarrowingEqualityRequiresExplicitStrLiteral
# See also testNarrowingEqualityFlipFlop
x1: Foo
if x1 == Foo.A:
reveal_type(x1) # N: Revealed type is "__main__.Foo"
reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.A]"
else:
reveal_type(x1) # N: Revealed type is "__main__.Foo"
reveal_type(x1) # N: Revealed type is "Literal[__main__.Foo.B]"

x2: Foo
if x2 == A_final:
reveal_type(x2) # N: Revealed type is "__main__.Foo"
reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.A]"
else:
reveal_type(x2) # N: Revealed type is "__main__.Foo"
reveal_type(x2) # N: Revealed type is "Literal[__main__.Foo.B]"

# But we let this narrow since there's an explicit literal in the RHS.
x3: Foo
if x3 == A_literal:
reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.A]"
else:
reveal_type(x3) # N: Revealed type is "Literal[__main__.Foo.B]"


class SingletonFoo(Enum):
A = "A"

def bar(x: Union[SingletonFoo, Foo], y: SingletonFoo) -> None:
if x == y:
reveal_type(x) # N: Revealed type is "Literal[__main__.SingletonFoo.A]"
[builtins fixtures/primitives.pyi]

[case testNarrowingEqualityDisabledForCustomEquality]
Expand Down
0