8000 Allow E[<str>] where E is an Enum type. · python/mypy@5543b5a · GitHub
[go: up one dir, main page]

Skip to content

Commit 5543b5a

Browse files
author
Guido van Rossum
committed
Allow E[<str>] where E is an Enum type.
Fixes #1381.
1 parent 41bae63 commit 5543b5a

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

mypy/checkexpr.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,9 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
14391439
return AnyType()
14401440
elif isinstance(left_type, TypedDictType):
14411441
return self.visit_typeddict_index_expr(left_type, e.index)
1442+
elif (isinstance(left_type, CallableType)
1443+
and left_type.is_type_obj() and left_type.type_object()):
1444+
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
14421445
else:
14431446
result, method_type = self.check_op('__getitem__', left_type, e.index, e)
14441447
e.method_type = method_type
@@ -1497,6 +1500,16 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
14971500
return AnyType()
14981501
return item_type
14991502

1503+
def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
1504+
context: Context) -> Type:
1505+
string_type = self.named_type('builtins.str') # type: Type
1506+
if self.chk.options.python_version[0] < 3:
1507+
string_type = UnionType.make_union([string_type,
1508+
self.named_type('builtins.unicode')])
1509+
self.chk.check_subtype(self.accept(index), string_type, context,
1510+
"Enum index should be a string", "actual index type")
1511+
return Instance(enum_type, [])
1512+
15001513
def visit_cast_expr(self, expr: CastExpr) -> Type:
15011514
"""Type check a cast expression."""
15021515
source_type = self.accept(expr.expr, context=AnyType())

mypy/semanal.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2455,7 +2455,11 @@ def visit_unary_expr(self, expr: UnaryExpr) -> None:
24552455

24562456
def visit_index_expr(self, expr: IndexExpr) -> None:
24572457
expr.base.accept(self)
2458-
if isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS:
2458+
if (isinstance(expr.base, RefExpr)
2459+
and isinstance(expr.base.node, TypeInfo)
2460+
and expr.base.node.is_enum):
2461+
expr.index.accept(self)
2462+
elif isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS:
24592463
# Special form -- subscripting a generic type alias.
24602464
# Perform the type substitution and create a new alias.
24612465
res = analyze_type_alias(expr,

test-data/unit/pythoneval-enum.test

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ def takes_some_ext_int_enum(s: SomeExtIntEnum):
119119
pass
120120
takes_some_ext_int_enum(SomeExtIntEnum.x)
121121

122-
123122
[case testNamedTupleEnum]
124123
from typing import NamedTuple
125124
from enum import Enum
@@ -132,3 +131,21 @@ class E(N, Enum):
132131
def f(x: E) -> None: pass
133132

134133
f(E.X)
134+
135+
[case testEnumCall]
136+
from enum import IntEnum
137+
class E(IntEnum):
138+
a = 1
139+
x = None # type: int
140+
reveal_type(E(x))
141+
[out]
142+
_program.py:5: error: Revealed type is '_testEnumCall.E'
143+
144+
[case testEnumIndex]
145+
from enum import IntEnum
146+
class E(IntEnum):
147+
a = 1
148+
s = None # type: str
149+
reveal_type(E[s])
150+
[out]
151+
_program.py:5: error: Revealed type is '_testEnumIndex.E'

0 commit comments

Comments
 (0)
0