18
18
from mypy .types import (
19
19
Type , AnyType , CallableType , Overloaded , NoneTyp , TypeVarDef ,
20
20
TupleType , TypedDictType , Instance , TypeVarType , ErasedType , UnionType ,
21
- PartialType , DeletedType , UninhabitedType , TypeType , TypeOfAny , LiteralType ,
21
+ PartialType , DeletedType , UninhabitedType , TypeType , TypeOfAny , LiteralType , LiteralValue ,
22
22
true_only , false_only , is_named_instance , function_type , callable_type , FunctionLike ,
23
23
StarType , is_optional , remove_optional , is_generic_instance
24
24
)
@@ -139,6 +139,16 @@ def __init__(self,
139
139
self .msg = msg
140
140
self .plugin = plugin
141
141
self .type_context = [None ]
142
+
143
+ # Set to 'True' whenever we are checking the expression in some 'Final' declaration.
144
+ # For example, if we're checking the "3" in a statement like "var: Final = 3".
145
+ #
146
+ # This flag changes the type that eventually gets inferred for "var". Instead of
147
+ # inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
148
+ # of the underlying literal value. See the comments in Instance's constructors for
149
+ # more details.
150
+ self .in_final_declaration = False
151
+
142
152
# Temporary overrides for expression types. This is currently
143
153
# used by the union math in overloads.
144
154
# TODO: refactor this to use a pattern similar to one in
@@ -210,10 +220,12 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
210
220
211
221
def analyze_var_ref (self , var : Var , context : Context ) -> Type :
212
222
if var .type :
213
- if is_literal_type_like (self .type_context [- 1 ]) and var .name () in {'True' , 'False' }:
214
- return LiteralType (var .name () == 'True' , self .named_type ('builtins.bool' ))
215
- else :
216
- return var .type
223
+ if isinstance (var .type , Instance ):
224
+ if self .is_literal_context () and var .type .final_value is not None :
225
+ return var .type .final_value
226
+ if var .name () in {'True' , 'False' }:
227
+ return self .infer_literal_expr_type (var .name () == 'True' , 'builtins.bool' )
228
+ return var .type
217
229
else :
218
230
if not var .is_ready and self .chk .in_checked_function ():
219
231
self .chk .handle_cannot_determine_type (var .name (), context )
@@ -691,7 +703,8 @@ def check_call(self,
691
703
elif isinstance (callee , Instance ):
692
704
call_function = analyze_member_access ('__call__' , callee , context ,
693
705
False , False , False , self .msg ,
694
- original_type = callee , chk = self .chk )
706
+ original_type = callee , chk = self .chk ,
707
+ in_literal_context = self .is_literal_context ())
695
708
return self .check_call (call_function , args , arg_kinds , context , arg_names ,
696
709
callable_node , arg_messages )
697
710
elif isinstance (callee , TypeVarType ):
@@ -1755,7 +1768,8 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
1755
1768
original_type = self .accept (e .expr )
1756
1769
member_type = analyze_member_access (
1757
1770
e .name , original_type , e , is_lvalue , False , False ,
1758
- self .msg , original_type = original_type , chk = self .chk )
1771
+ self .msg , original_type = original_type , chk = self .chk ,
1772
+ in_literal_context = self .is_literal_context ())
1759
1773
return member_type
1760
1774
1761
1775
def analyze_external_member_access (self , member : str , base_type : Type ,
@@ -1765,35 +1779,57 @@ def analyz
10000
e_external_member_access(self, member: str, base_type: Type,
1765
1779
"""
1766
1780
# TODO remove; no private definitions in mypy
1767
1781
return analyze_member_access (member , base_type , context , False , False , False ,
1768
- self .msg , original_type = base_type , chk = self .chk )
1782
+ self .msg , original_type = base_type , chk = self .chk ,
1783
+ in_literal_context = self .is_literal_context ())
1784
+
1785
+ def is_literal_context (self ) -> bool :
1786
+ return is_literal_type_like (self .type_context [- 1 ])
1787
+
1788
+ def infer_literal_expr_type (self , value : LiteralValue , fallback_name : str ) -> Type :
1789
+ """Analyzes the given literal expression and determines if we should be
1790
+ inferring an Instance type, a Literal[...] type, or an Instance that
1791
+ remembers the original literal. We...
1792
+
1793
+ 1. ...Infer a normal Instance in most circumstances.
1794
+
1795
+ 2. ...Infer a Literal[...] if we're in a literal context. For example, if we
1796
+ were analyzing the "3" in "foo(3)" where "foo" has a signature of
1797
+ "def foo(Literal[3]) -> None", we'd want to infer that the "3" has a
1798
+ type of Literal[3] instead of Instance.
1799
+
1800
+ 3. ...Infer an Instance that remembers the original Literal if we're declaring
1801
+ a Final variable with an inferred type -- for example, "bar" in "bar: Final = 3"
1802
+ would be assigned an Instance that remembers it originated from a '3'. See
1803
+ the comments in Instance's constructor for more details.
1804
+ """
1805
+ typ = self .named_type (fallback_name )
1806
+ if self .is_literal_context ():
1807
+ return LiteralType (value = value , fallback = typ )
1808
+ elif self .in_final_declaration :
1809
+ return typ .copy_modified (final_value = LiteralType (
1810
+ value = value ,
1811
+ fallback = typ ,
1812
+ line = typ .line ,
1813
+ column = typ .column ,
1814
+ ))
1815
+ else :
1816
+ return typ
1769
1817
1770
1818
def visit_int_expr (self , e : IntExpr ) -> Type :
1771
1819
"""Type check an integer literal (trivial)."""
1772
- typ = self .named_type ('builtins.int' )
1773
- if is_literal_type_like (self .type_context [- 1 ]):
1774
- return LiteralType (value = e .value , fallback = typ )
1775
- return typ
1820
+ return self .infer_literal_expr_type (e .value , 'builtins.int' )
1776
1821
1777
1822
def visit_str_expr (self , e : StrExpr ) -> Type :
1778
1823
"""Type check a string literal (trivial)."""
1779
- typ = self .named_type ('builtins.str' )
1780
- if is_literal_type_like (self .type_context [- 1 ]):
1781
- return LiteralType (value = e .value , fallback = typ )
1782
- return typ
1824
+ return self .infer_literal_expr_type (e .value , 'builtins.str' )
1783
1825
1784
1826
def visit_bytes_expr (self , e : BytesExpr ) -> Type :
1785
1827
"""Type check a bytes literal (trivial)."""
1786
- typ = self .named_type ('builtins.bytes' )
1787
- if is_literal_type_like (self .type_context [- 1 ]):
1788
- return LiteralType (value = e .value , fallback = typ )
1789
- return typ
1828
+ return self .infer_literal_expr_type (e .value , 'builtins.bytes' )
1790
1829
1791
1830
def visit_unicode_expr (self , e : UnicodeExpr ) -> Type :
1792
1831
"""Type check a unicode literal (trivial)."""
1793
- typ = self .named_type ('builtins.unicode' )
1794
- if is_literal_type_like (self .type_context [- 1 ]):
1795
- return LiteralType (value = e .value , fallback = typ )
1796
- return typ
1832
+ return self .infer_literal_expr_type (e .value , 'builtins.unicode' )
1797
1833
1798
1834
def visit_float_expr (self , e : FloatExpr ) -> Type :
1799
1835
"""Type check a float literal (trivial)."""
@@ -1930,7 +1966,8 @@ def check_method_call_by_name(self,
1930
1966
"""
1931
1967
local_errors = local_errors or self .msg
1932
1968
method_type = analyze_member_access (method , base_type , context , False , False , True ,
1933
- local_errors , original_type = base_type , chk = self .chk )
1969
+ local_errors , original_type = base_type , chk = self .chk ,
1970
+ in_literal_context = self .is_literal_context ())
1934
1971
return self .check_method_call (
1935
1972
method , base_type , method_type , args , arg_kinds , context , local_errors )
1936
1973
@@ -1994,6 +2031,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
1994
2031
context = context ,
1995
2032
msg = local_errors ,
1996
2033
chk = self .chk ,
2034
+ in_literal_context = self .is_literal_context ()
1997
2035
)
1998
2036
if local_errors .is_errors ():
1999
2037
return None
@@ -2950,7 +2988,8 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
2950
2988
override_info = base ,
2951
2989
context = e ,
2952
2990
msg = self .msg ,
2953
- chk = self .chk )
2991
+ chk = self .chk ,
2992
+ in_literal_context = self .is_literal_context ())
2954
2993
assert False , 'unreachable'
2955
2994
else :
2956
2995
# Invalid super. This has been reported by the semantic analyzer.
@@ -3117,13 +3156,16 @@ def accept(self,
3117
3156
type_context : Optional [Type ] = None ,
3118
3157
allow_none_return : bool = False ,
3119
3158
always_allow_any : bool = False ,
3159
+ in_final_declaration : bool = False ,
3120
3160
) -> Type :
3121
3161
"""Type check a node in the given type context. If allow_none_return
3122
3162
is True and this expression is a call, allow it to return None. This
3123
3163
applies only to this expression and not any subexpressions.
3124
3164
"""
3125
3165
if node in self .type_overrides :
3126
3166
return self .type_overrides [node ]
3167
+ old_in_final_declaration = self .in_final_declaration
3168
+ self .in_final_declaration = in_final_declaration
3127
3169
self .type_context .append (type_context )
3128
3170
try :
3129
3171
if allow_none_return and isinstance (node , CallExpr ):
@@ -3136,6 +3178,7 @@ def accept(self,
3136
3178
report_internal_error (err , self .chk .errors .file ,
3137
3179
node .line , self .chk .errors , self .chk .options )
3138
3180
self .type_context .pop ()
3181
+ self .in_final_declaration = old_in_final_declaration
3139
3182
assert typ is not None
3140
3183
self .chk .store_type (node , typ )
3141
3184
0 commit comments