8000 Merge branch 'none-initializer-inference' · python/mypy@5a0f7ec · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a0f7ec

Browse files
committed
Merge branch 'none-initializer-inference'
2 parents cdcb777 + 7334c8f commit 5a0f7ec

File tree

5 files changed

+195
-55
lines changed

5 files changed

+195
-55
lines changed

mypy/checker.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,21 @@ def check_assignment(self, lvalue: Node, rvalue: Node, infer_lvalue_type: bool =
10091009
else:
10101010
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
10111011
if lvalue_type:
1012+
if isinstance(lvalue_type, PartialType) and lvalue_type.type is None:
1013+
# Try to infer a proper type for a variable with a partial None type.
1014+
rvalue_type = self.accept(rvalue)
1015+
if isinstance(rvalue_type, NoneTyp):
1016+
# This doesn't actually provide any additional information -- multiple
1017+
# None initializers preserve the partial None type.
1018+
return
1019+
if is_valid_inferred_type(rvalue_type):
1020+
lvalue_type.var.type = rvalue_type
1021+
partial_types = self.partial_types[-1]
1022+
del partial_types[lvalue_type.var]
1023+
# Try to infer a partial type. No need to check the return value, as
1024+
# an error will be reported elsewhere.
1025+
self.infer_partial_type(lvalue_type.var, lvalue, rvalue_type)
1026+
return
10121027
rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, lvalue)
10131028

10141029
if rvalue_type and infer_lvalue_type:
@@ -1215,7 +1230,7 @@ def check_lvalue(self, lvalue: Node) -> Tuple[Type, IndexExpr, Var]:
12151230
True)
12161231
self.store_type(lvalue, lvalue_type)
12171232
elif isinstance(lvalue, NameExpr):
1218-
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue)
1233+
lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
12191234
self.store_type(lvalue, lvalue_type)
12201235
elif isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr):
12211236
lv = cast(Union[TupleExpr, ListExpr], lvalue)
@@ -1270,19 +1285,23 @@ def infer_variable_type(self, name: Var, lvalue: Node,
12701285
self.set_inferred_type(name, lvalue, init_type)
12711286

12721287
def infer_partial_type(self, name: Var, lvalue: Node, init_type: Type) -> bool:
1273-
if not isinstance(init_type, Instance):
1288+
if isinstance(init_type, NoneTyp):
1289+
partial_type = PartialType(None, name)
1290+
elif isinstance(init_type, Instance):
1291+
fullname = init_type.type.fullname()
1292+
if ((fullname == 'builtins.list' or fullname == 'builtins.set' or
1293+
fullname == 'builtins.dict')
1294+
and isinstance(init_type.args[0], NoneTyp)
1295+
and (fullname != 'builtins.dict' or isinstance(init_type.args[1], NoneTyp))
1296+
and isinstance(lvalue, NameExpr)):
1297+
partial_type = PartialType(init_type.type, name)
1298+
else:
1299+
return False
1300+
else:
12741301
return False
1275-
fullname = init_type.type.fullname()
1276-
if ((fullname == 'builtins.list' or fullname == 'builtins.set' or
1277-
fullname == 'builtins.dict')
1278-
and isinstance(init_type.args[0], NoneTyp)
1279-
and (fullname != 'builtins.dict' or isinstance(init_type.args[1], NoneTyp))
1280-
and isinstance(lvalue, NameExpr)):
1281-
partial_type = PartialType(init_type.type, name)
1282-
self.set_inferred_type(name, lvalue, partial_type)
1283-
self.partial_types[-1][name] = lvalue
1284-
return True
1285-
return False
1302+
self.set_inferred_type(name, lvalue, partial_type)
1303+
self.partial_types[-1][name] = lvalue
1304+
return True
12861305

12871306
def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None:
12881307
"""Store inferred variable type.

mypy/checkexpr.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,24 @@ def visit_name_expr(self, e: NameExpr) -> Type:
6767
result = self.analyze_ref_expr(e)
6868
return self.chk.narrow_type_from_binder(e, result)
6969

70-
def analyze_ref_expr(self, e: RefExpr) -> Type:
70+
def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
7171
result = None # type: Type
7272
node = e.node
7373
if isinstance(node, Var):
7474
# Variable reference.
7575
result = self.analyze_var_ref(node, e)
7676
if isinstance(result, PartialType):
77-
partial_types = self.chk.partial_types[-1]
78-
if node in partial_types:
79-
context = partial_types[node]
80-
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
81-
result = AnyType()
77+
if result.type is None:
78+
# 'None' partial type. It has a well-defined type. In an lvalue context
79+
# we want to preserve the knowledge of it being a partial type.
80+
if not lvalue:
81+
result = NoneTyp()
82+
else:
83+
partial_types = self.chk.partial_types[-1]
84+
if node in partial_types:
85+
context = partial_types[node]
86+
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
87+
result = AnyType()
8288
elif isinstance(node, FuncDef):
8389
# Reference to a global function.
8490
result = function_type(node, self.named_type('builtins.function'))
@@ -143,7 +149,11 @@ def try_infer_partial_type(self, e: CallExpr) -> None:
143149
var = e.callee.expr.node
144150
if var in partial_types:
145151
var = cast(Var, var)
146-
typename = cast(Instance, var.type).type.fullname()
152+
partial_type_type = cast(PartialType, var.type).type
153+
if partial_type_type is None:
154+
# A partial None type -> can't infer anything.
155+
return
156+
typename = partial_type_type.fullname()
147157
methodname = e.callee.name
148158
# Sometimes we can infer a full type for a partial List, Dict or Set type.
149159
# TODO: Don't infer argument expression twice.

mypy/parse.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def parse_import(self) -> Import:
189189
while True:
190190
id = self.parse_qualified_name()
191191
translated = self.translate_module_id(id)
192-
as_id = None # type: Optional[str]
192+
as_id = None
193193
if self.current_str() == 'as':
194194
self.expect('as')
195195
name_tok = self.expect_type(Name)
@@ -319,7 +319,7 @@ def parse_class_def(self) -> ClassDef:
319319
self.is_class_body = True
320320

321321
self.expect('class')
322-
metaclass = None # type: str
322+
metaclass = None
323323

324324
try:
325325
commas, base_types = [], [] # type: List[Token], List[Node]
@@ -520,7 +520,7 @@ def parse_args(self, no_type_checks: bool=False) -> Tuple[List[Argument],
520520
self.skip()
521521
if no_type_checks:
522522
self.parse_expression()
523-
ret_type = None # type: Type
523+
ret_type = None
524524
else:
525525
ret_type = self.parse_type()
526526
else:
@@ -640,7 +640,7 @@ def parse_asterisk_arg(self,
640640
else:
641641
kind = nodes.ARG_STAR2
642642

643-
type = None # type: Type
643+
type = None
644644
if no_type_checks:
645645
self.parse_parameter_annotation()
646646
else:
@@ -674,7 +674,7 @@ def parse_tuple_arg(self, index: int) -> Tuple[Argument, AssignmentStmt, List[st
674674
decompose = AssignmentStmt([paren_arg], rvalue)
675675
decompose.set_line(line)
676676
kind = nodes.ARG_POS
677-
initializer = None # type: Optional[Node]
677+
initializer = None
678678
if self.current_str() == '=':
679679
self.expect('=')
680680
initializer = self.parse_expression(precedence[','])
@@ -707,7 +707,7 @@ def parse_normal_arg(self, require_named: bool,
707707
name = self.expect_type(Name)
708708
variable = Var(name.string)
709709

710-
type = None # type: Type
710+
type = None
711711
if no_type_checks:
712712
self.parse_parameter_annotation()
713713
else:
@@ -936,7 +936,7 @@ def parse_assignment(self, lvalue: Any) -> Node:
936936

937937
def parse_return_stmt(self) -> ReturnStmt:
938938
self.expect('return')
939-
expr = None # type: Node
939+
expr = None
940940
current = self.current()
941941
if current.string == 'yield':
942942
self.parse_error()
@@ -947,8 +947,8 @@ def parse_return_stmt(self) -> ReturnStmt:
947947

948948
def parse_raise_stmt(self) -> RaiseStmt:
949949
self.expect('raise')
950-
expr = None # type: Node
951-
from_expr = None # type: Node
950+
expr = None
951+
from_expr = None
952952
if not isinstance(self.current(), Break):
953953
expr = self.parse_expression()
954954
if self.current_str() == 'from':
@@ -965,7 +965,7 @@ def parse_assert_stmt(self) -> AssertStmt:
965965

966966
def parse_yield_stmt(self) -> Union[YieldStmt, YieldFromStmt]:
967967
self.expect('yield')
968-
expr = None # type: Node
968+
expr = None
969969
node = YieldStmt(expr)
970970
if not isinstance(self.current(), Break):
971971
if self.current_str() == "from":
@@ -1205,7 +1205,7 @@ def parse_with_stmt(self) -> WithStmt:
12051205
def parse_print_stmt(self) -> PrintStmt:
12061206
self.expect('print')
12071207
args = []
1208-
target = None # type: Node
1208+
target = None
12091209
if self.current_str() == '>>':
12101210
self.skip()
12111211
target = self.parse_expression(precedence[','])
@@ -1230,8 +1230,8 @@ def parse_print_stmt(self) -> PrintStmt:
12301230
def parse_exec_stmt(self) -> ExecStmt:
12311231
self.expect('exec')
12321232
expr = self.parse_expression(precedence['in'])
1233-
variables1 = None # type: Optional[Node]
1234-
variables2 = None # type: Optional[Node]
1233+
variables1 = None
1234+
variables2 = None
12351235
if self.current_str() == 'in':
12361236
self.skip()
12371237
variables1 = self.parse_expression(precedence[','])
@@ -1419,7 +1419,7 @@ def parse_comp_for(self) -> Tuple[List[Node], List[Node], List[List[Node]]]:
14191419
sequences = []
14201420
condlists = [] # type: List[List[Node]]
14211421
while self.current_str() == 'for':
1422-
conds = [] # type: List[Node]
1422+
conds = []
14231423
self.expect('for')
14241424
index = self.parse_for_index_variables()
14251425
indices.append(index)
@@ -1549,9 +1549,8 @@ def parse_str_expr(self) -> Node:
15491549
elif isinstance(token, UnicodeLit):
15501550
value += token.parsed()
15511551
is_unicode = True
1552-
node = None # type: Node
15531552
if is_unicode or (self.pyversion[0] == 2 and 'unicode_literals' in self.future_options):
1554-
node = UnicodeExpr(value)
1553+
node = UnicodeExpr(value) # type: Node
15551554
else:
15561555
node = StrExpr(value)
15571556
return node
@@ -1649,11 +1648,10 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str]]:
16491648
def parse_member_expr(self, expr: Any) -> Node:
16501649
self.expect('.')
16511650
name = self.expect_type(Name)
1652-
node = None # type: Node
16531651
if (isinstance(expr, CallExpr) and isinstance(expr.callee, NameExpr)
16541652
and cast(NameExpr, expr.callee).name == 'super'):
16551653
# super() expression
1656-
node = SuperExpr(name.string)
1654+
node = SuperExpr(name.string) # type: Node
16571655
else:
16581656
node = MemberExpr(expr, name.string)
16591657
return node
@@ -1695,7 +1693,7 @@ def parse_slice_item(self) -> Node:
16951693
end_index = self.parse_expression(precedence[','])
16961694
else:
16971695
end_index = None
1698-
stride = None # type: Node
1696+
stride = None
16991697
if self.current_str() == ':':
17001698
self.expect(':')
17011699
if self.current_str() not in (']', ','):
@@ -1714,7 +1712,7 @@ def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr:
17141712
return node
17151713

17161714
def parse_comparison_expr(self, left: Node, prec: int) -> ComparisonExpr:
1717-
operators_str = [] # type: List[str]
1715+
operators_str = []
17181716
operands = [left]
17191717

17201718
while True:

0 commit comments

Comments
 (0)
0