8000 Add interactions between Literal and Final (#6081) · python/mypy@94fe11c · GitHub
[go: up one dir, main page]

Skip to content

Commit 94fe11c

Browse files
authored
Add interactions between Literal and Final (#6081)
This pull request adds logic to handle interactions between Literal and Final. In short, if the user were to define a variable like `x: Final = 3` and latter do `some_func(x)`, mypy will attempt to type-check the code almost as if the user had done `some_func(3)` instead. This normally does not make a difference, except when type-checking code using literal types. For example, if `some_func` accepts a `Literal[3]` up above, the code would type-check since `x` cannot be anything other then a `3`. Or to put it another way, this pull request makes variables that use `Final` with the type omitted context-sensitive.
1 parent fd048ab commit 94fe11c

15 files changed

+661
-69
lines changed

mypy/checker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,8 +1810,11 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
18101810
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
18111811

18121812
if inferred:
1813-
self.infer_variable_type(inferred, lvalue, self.expr_checker.accept(rvalue),
1814-
rvalue)
1813+
rvalue_type = self.expr_checker.accept(
1814+
rvalue,
1815+
in_final_declaration=inferred.is_final,
1816+
)
1817+
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
18151818

18161819
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
18171820
rvalue: Expression) -> bool:

mypy/checkexpr.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from mypy.types import (
1919
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
2020
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
21-
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType,
21+
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
2222
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
2323
StarType, is_optional, remove_optional, is_generic_instance
2424
)
@@ -139,6 +139,16 @@ def __init__(self,
139139
self.msg = msg
140140
self.plugin = plugin
141141
self.type_context = [None]
142+
143+
# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
144+
# For example, if we're checking the "3" in a statement like "var: Final = 3".
145+
#
146+
# This flag changes the type that eventually gets inferred for "var". Instead of
147+
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
148+
# of the underlying literal value. See the comments in Instance's constructors for
149+
# more details.
150+
self.in_final_declaration = False
151+
142152
# Temporary overrides for expression types. This is currently
143153
# used by the union math in overloads.
144154
# TODO: refactor this to use a pattern similar to one in
@@ -210,10 +220,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
210220

211221
def analyze_var_ref(self, var: Var, context: Context) -> Type:
212222
if var.type:
213-
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
214-
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
215-
else:
216-
return var.type
223+
if isinstance(var.type, Instance):
224+
if self.is_literal_context() and var.type.final_value is not None:
225+
return var.type.final_value
226+
if var.name() in {'True', 'False'}:
227+
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
228+
return var.type
217229
else:
218230
if not var.is_ready and self.chk.in_checked_function():
219231
self.chk.handle_cannot_determine_type(var.name(), context)
@@ -691,7 +703,8 @@ def check_call(self,
691703
elif isinstance(callee, Instance):
692704
call_function = analyze_member_access('__call__', callee, context,
693705
False, False, False, self.msg,
694-
original_type=callee, chk=self.chk)
706+
original_type=callee, chk=self.chk,
707+
in_literal_context=self.is_literal_context())
695708
return self.check_call(call_function, args, arg_kinds, context, arg_names,
696709
callable_node, arg_messages)
697710
elif isinstance(callee, TypeVarType):
@@ -1755,7 +1768,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
17551768
original_type = self.accept(e.expr)
17561769
member_type = analyze_member_access(
17571770
e.name, original_type, e, is_lvalue, False, False,
1758-
self.msg, original_type=original_type, chk=self.chk)
1771+
self.msg, original_type=original_type, chk=self.chk,
1772+
in_literal_context=self.is_literal_context())
17591773
return member_type
17601774

17611775
def analyze_external_member_access(self, member: str, base_type: Type,
@@ -1765,35 +1779,57 @@ def analyz 10000 e_external_member_access(self, member: str, base_type: Type,
17651779
"""
17661780
# TODO remove; no private definitions in mypy
17671781
return analyze_member_access(member, base_type, context, False, False, False,
1768-
self.msg, original_type=base_type, chk=self.chk)
1782+
self.msg, original_type=base_type, chk=self.chk,
1783+
in_literal_context=self.is_literal_context())
1784+
1785+
def is_literal_context(self) -> bool:
1786+
return is_literal_type_like(self.type_context[-1])
1787+
1788+
def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type:
1789+
"""Analyzes the given literal expression and determines if we should be
1790+
inferring an Instance type, a Literal[...] type, or an Instance that
1791+
remembers the original literal. We...
1792+
1793+
1. ...Infer a normal Instance in most circumstances.
1794+
1795+
2. ...Infer a Literal[...] if we're in a literal context. For example, if we
1796+
were analyzing the "3" in "foo(3)" where "foo" has a signature of
1797+
"def foo(Literal[3]) -> None", we'd want to infer that the "3" has a
1798+
type of Literal[3] instead of Instance.
1799+
1800+
3. ...Infer an Instance that remembers the original Literal if we're declaring
1801+
a Final variable with an inferred type -- for example, "bar" in "bar: Final = 3"
1802+
would be assigned an Instance that remembers it originated from a '3'. See
1803+
the comments in Instance's constructor for more details.
1804+
"""
1805+
typ = self.named_type(fallback_name)
1806+
if self.is_literal_context():
1807+
return LiteralType(value=value, fallback=typ)
1808+
elif self.in_final_declaration:
1809+
return typ.copy_modified(final_value=LiteralType(
1810+
value=value,
1811+
fallback=typ,
1812+
line=typ.line,
1813+
column=typ.column,
1814+
))
1815+
else:
1816+
return typ
17691817

17701818
def visit_int_expr(self, e: IntExpr) -> Type:
17711819
"""Type check an integer literal (trivial)."""
1772-
typ = self.named_type('builtins.int')
1773-
if is_literal_type_like(self.type_context[-1]):
1774-
return LiteralType(value=e.value, fallback=typ)
1775-
return typ
1820+
return self.infer_literal_expr_type(e.value, 'builtins.int')
17761821

17771822
def visit_str_expr(self, e: StrExpr) -> Type:
17781823
"""Type check a string literal (trivial)."""
1779-
typ = self.named_type('builtins.str')
1780-
if is_literal_type_like(self.type_context[-1]):
1781-
return LiteralType(value=e.value, fallback=typ)
1782-
return typ
1824+
return self.infer_literal_expr_type(e.value, 'builtins.str')
17831825

17841826
def visit_bytes_expr(self, e: BytesExpr) -> Type:
17851827
"""Type check a bytes literal (trivial)."""
1786-
typ = self.named_type('builtins.bytes')
1787-
if is_literal_type_like(self.type_context[-1]):
1788-
return LiteralType(value=e.value, fallback=typ)
1789-
return typ
1828+
return self.infer_literal_expr_type(e.value, 'builtins.bytes')
17901829

17911830
def visit_unicode_expr(self, e: UnicodeExpr) -> Type:
17921831
"""Type check a unicode literal (trivial)."""
1793-
typ = self.named_type('builtins.unicode')
1794-
if is_literal_type_like(self.type_context[-1]):
1795-
return LiteralType(value=e.value, fallback=typ)
1796-
return typ
1832+
return self.infer_literal_expr_type(e.value, 'builtins.unicode')
17971833

17981834
def visit_float_expr(self, e: FloatExpr) -> Type:
17991835
"""Type check a float literal (trivial)."""
@@ -1930,7 +1966,8 @@ def check_method_call_by_name(self,
19301966
"""
19311967
local_errors = local_errors or self.msg
19321968
method_type = analyze_member_access(method, base_type, context, False, False, True,
1933-
local_errors, original_type=base_type, chk=self.chk)
1969+
local_errors, original_type=base_type, chk=self.chk,
1970+
in_literal_context=self.is_literal_context())
19341971
return self.check_method_call(
19351972
method, base_type, method_type, args, arg_kinds, context, local_errors)
19361973

@@ -1994,6 +2031,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
19942031
context=context,
19952032
msg=local_errors,
19962033
chk=self.chk,
2034+
in_literal_context=self.is_literal_context()
19972035
)
19982036
if local_errors.is_errors():
19992037
return None
@@ -2950,7 +2988,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
29502988
override_info=base,
29512989
context=e,
29522990
msg=self.msg,
2953-
chk=self.chk)
2991+
chk=self.chk,
2992+
in_literal_context=self.is_literal_context())
29542993
assert False, 'unreachable'
29552994
else:
29562995
# Invalid super. This has been reported by the semantic analyzer.
@@ -3117,13 +3156,16 @@ def accept(self,
31173156
type_context: Optional[Type] = None,
31183157
allow_none_return: bool = False,
31193158
always_allow_any: bool = False,
3159+
in_final_declaration: bool = False,
31203160
) -> Type:
31213161
"""Type check a node in the given type context. If allow_none_return
31223162
is True and this expression is a call, allow it to return None. This
31233163
applies only to this expression and not any subexpressions.
31243164
"""
31253165
if node in self.type_overrides:
31263166
return self.type_overrides[node]
3167+
old_in_final_declaration = self.in_final_declaration
3168+
self.in_final_declaration = in_final_declaration
31273169
self.type_context.append(type_context)
31283170
try:
31293171
if allow_none_return and isinstance(node, CallExpr):
@@ -3136,6 +3178,7 @@ def accept(self,
31363178
report_internal_error(err, self.chk.errors.file,
31373179
node.line, self.chk.errors, self.chk.options)
31383180
self.type_context.pop()
3181+
self.in_final_declaration = old_in_final_declaration
31393182
assert typ is not None
31403183
self.chk.store_type(node, typ)
31413184

mypy/checkmember.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def analyze_member_access(name: str,
7171
msg: MessageBuilder, *,
7272
original_type: Type,
7373
chk: 'mypy.checker.TypeChecker',
74-
override_info: Optional[TypeInfo] = None) -> Type:
74+
override_info: Optional[TypeInfo] = None,
75+
in_literal_context: bool = False) -> Type:
7576
"""Return the type of attribute 'name' of 'typ'.
7677
7778
The actual implementation is in '_analyze_member_access' and this docstring
@@ -96,7 +97,11 @@ def analyze_member_access(name: str,
9697
context,
9798
msg,
9899
chk=chk)
99-
return _analyze_member_access(name, typ, mx, override_info)
100+
result = _analyze_member_access(name, typ, mx, override_info)
101+
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
102+
return result.final_value
103+
else:
104+
return result
100105

101106

102107
def _analyze_member_access(name: str,

mypy/fixup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
155155
base.accept(self)
156156
for a in inst.args:
157157
a.accept(self)
158+
if inst.final_value is not None:
159+
inst.final_value.accept(self)
158160

159161
def visit_any(self, o: Any) -> None:
160162
pass # Nothing to descend into.

mypy/sametypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def visit_deleted_type(self, left: DeletedType) -> bool:
7777
def visit_instance(self, left: Instance) -> bool:
7878
return (isinstance(self.right, Instance) and
7979
left.type == self.right.type and
80-
is_same_types(left.args, self.right.args))
80+
is_same_types(left.args, self.right.args) and
81+
left.final_value == self.right.final_value)
8182

8283
def visit_type_var(self, left: TypeVarType) -> bool:
8384
return (isinstance(self.right, TypeVarType) and

mypy/semanal.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from mypy.messages import CANNOT_ASSIGN_TO_TYPE, MessageBuilder
6666
from mypy.types import (
6767
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
68-
CallableType, Overloaded, Instance, Type, AnyType,
68+
CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue,
6969
TypeTranslator, TypeOfAny, TypeType, NoneTyp,
7070
)
7171
from mypy.nodes import implicit_module_attrs
@@ -1760,9 +1760,9 @@ def final_cb(keep_final: bool) -> None:
17601760
self.type and self.type.is_protocol and not self.is_func_scope()):
17611761
self.fail('All protocol members must have explicitly declared types', s)
17621762
# Set the type if the rvalue is a simple literal (even if the above error occurred).
1763-
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr):
1763+
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr):
17641764
if s.lvalues[0].is_inferred_def:
1765-
s.type = self.analyze_simple_literal_type(s.rvalue)
1765+
s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def)
17661766
if s.type:
17671767
# Store type into nodes.
17681768
for lvalue in s.lvalues:
@@ -1900,8 +1900,10 @@ def unbox_literal(self, e: Expression) -> Optional[Union[int, float, bool, str]]
19001900
return True if e.name == 'True' else False
19011901
return None
19021902

1903-
def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
1904-
"""Return builtins.int if rvalue is an int literal, etc."""
1903+
def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Optional[Type]:
1904+
"""Return builtins.int if rvalue is an int literal, etc.
1905+
1906+
If this is a 'Final' context, we return "Literal[...]" instead."""
19051907
if self.options.semantic_analysis_only or self.function_stack:
19061908
# Skip this if we're only doing the semantic analysis pass.
19071909
# This is mostly to avoid breaking unit tests.
@@ -1910,16 +1912,31 @@ def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
19101912
# inside type variables with value restrictions (like
19111913
# AnyStr).
19121914
return None
1913-
if isinstance(rvalue, IntExpr):
1914-
return self.named_type_or_none('builtins.int')
19151915
if isinstance(rvalue, FloatExpr):
19161916
return self.named_type_or_none('builtins.float')
1917+
1918+
value = None # type: LiteralValue
1919+
type_name = None # type: Optional[str]
1920+
if isinstance(rvalue, IntExpr):
1921+
value, type_name = rvalue.value, 'builtins.int'
19171922
if isinstance(rvalue, StrExpr):
1918-
return self.named_type_or_none('builtins.str')
1923+
value, type_name = rvalue.value, 'builtins.str'
19191924
if isinstance(rvalue, BytesExpr):
1920-
return self.named_type_or_none('builtins.bytes')
1925+
value, type_name = rvalue.value, 'builtins.bytes'
19211926
if isinstance(rvalue, UnicodeExpr):
1922-
return self.named_type_or_none('builtins.unicode')
1927+
value, type_name = rvalue.value, 'builtins.unicode'
1928+
1929+
if type_name is not None:
1930+
typ = self.named_type_or_none(type_name)
1931+
if typ and is_final:
1932+
return typ.copy_modified(final_value=LiteralType(
1933+
value=value,
1934+
fallback=typ,
1935+
line=typ.line,
1936+
column=typ.column,
1937+
))
1938+
return typ
1939+
19231940
return None
19241941

19251942
def analyze_alias(self, rvalue: Expression) -> Tuple[Optional[Type], List[str],

mypy/server/astdiff.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
284284
def visit_instance(self, typ: Instance) -> SnapshotItem:
285285
return ('Instance',
286286
typ.type.fullname(),
287-
snapshot_types(typ.args))
287+
snapshot_types(typ.args),
288+
None if typ.final_value is None else snapshot_type(typ.final_value))
288289

289290
def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
290291
return ('TypeVar',

mypy/server/astmerge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
342342
typ.type = self.fixup(typ.type)
343343
for arg in typ.args:
344344
arg.accept(self)
345+
if typ.final_value:
346+
typ.final_value.accept(self)
345347

346348
def visit_any(self, typ: AnyType) -> None:
347349
pass

mypy/server/deps.py

Lines changed: 2 additions & 0 deletions
538E
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
882882
triggers = [trigger]
883883
for arg in typ.args:
884884
triggers.extend(self.get_type_triggers(arg))
885+
if typ.final_value:
886+
triggers.extend(self.get_type_triggers(typ.final_value))
885887
return triggers
886888

887889
def visit_any(self, typ: AnyType) -> List[str]:

mypy/type_visitor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from abc import abstractmethod
1515
from collections import OrderedDict
16-
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable
16+
from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional
1717
from mypy_extensions import trait
1818

1919
T = TypeVar('T')
@@ -159,7 +159,18 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
159159
return t
160160

161161
def visit_instance(self, t: Instance) -> Type:
162-
return Instance(t.type, self.translate_types(t.args), t.line, t.column)
162+
final_value = None # type: Optional[LiteralType]
163+
if t.final_value is not None:
164+
raw_final_value = t.final_value.accept(self)
165+
assert isinstance(raw_final_value, LiteralType)
166+
final_value = raw_final_value
167+
return Instance(
168+
typ=t.type,
169+
args=self.translate_types(t.args),
170+
line=t.line,
171+
column=t.column,
172+
final_value=final_value,
173+
)
163174

164175
def visit_type_var(self, t: TypeVarType) -> Type:
165176
return t

mypy/typeanal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
697697
elif isinstance(arg, (NoneTyp, LiteralType)):
698698
# Types that we can just add directly to the literal/potential union of literals.
699699
return [arg]
700+
elif isinstance(arg, Instance) and arg.final_value is not None:
701+
# Types generated from declarations like "var: Final = 4".
702+
return [arg.final_value]
700703
elif isinstance(arg, UnionType):
701704
out = []
702705
for union_arg in arg.items:

0 commit comments

Comments
 (0)
0