8000 Merge branch 'collection-type-infer' · python/mypy@b470951 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b470951

Browse files
committed
Merge branch 'collection-type-infer'
Partially addresses #1055.
2 parents 1d04eb7 + 8af8d14 commit b470951

19 files changed

+360
-60
lines changed

mypy/checker.py

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from mypy.types import (
3030
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
3131
Instance, NoneTyp, UnboundType, ErrorType, TypeTranslator, strip_type,
32-
UnionType, TypeVarType,
32+
UnionType, TypeVarType, PartialType
3333
)
3434
from mypy.sametypes import is_same_type
3535
from mypy.messages import MessageBuilder
@@ -332,7 +332,8 @@ class TypeChecker(NodeVisitor[Type]):
332332
breaking_out = False
333333
# Do weak type checking in this file
334334
weak_opts = set() # type: Set[str]
335-
335+
# Stack of collections of variables with partial types
336+
partial_types = None # type: List[Dict[Var, Context]]
336337
globals = None # type: SymbolTable
337338
locals = None # type: SymbolTable
338339
modules = None # type: Dict[str, MypyFile]
@@ -358,6 +359,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile],
358359
self.dynamic_funcs = []
359360
self.function_stack = []
360361
self.weak_opts = set() # type: Set[str]
362+
self.partial_types = []
361363

362364
def visit_file(self, file_node: MypyFile, path: str) -> None:
363365
"""Type check a mypy file with the given path."""
@@ -367,10 +369,12 @@ def visit_file(self, file_node: MypyFile, path: str) -> None:
367369
self.globals = file_node.names
368370
self.locals = None
369371
self.weak_opts = file_node.weak_opts
372+
self.enter_partial_types()
370373

371374
for d in file_node.defs:
372375
self.accept(d)
373376

377+
self.leave_partial_types()
374378
self.errors.set_ignored_lines(set())
375379

376380
def accept(self, node: Node, type_context: Type = None) -> Type:
@@ -461,6 +465,8 @@ def check_func_item(self, defn: FuncItem,
461465
if fdef:
462466
self.errors.push_function(fdef.name())
463467

468+
self.enter_partial_types()
469+
464470
typ = self.function_type(defn)
465471
if type_override:
466472
typ = type_override
@@ -469,6 +475,8 @@ def check_func_item(self, defn: FuncItem,
469475
else:
470476
raise RuntimeError('Not supported')
471477

478+
self.leave_partial_types()
479+
472480
if fdef:
473481
self.errors.pop_function()
474482

@@ -864,12 +872,14 @@ def visit_class_def(self, defn: ClassDef) -> Type:
864872
"""Type check a class definition."""
865873
typ = defn.info
866874
self.errors.push_type(defn.name)
875+
self.enter_partial_types()
867876
old_binder = self.binder
868877
self.binder = ConditionalTypeBinder()
869878
self.binder.push_frame()
870879
self.accept(defn.defs)
871880
self.binder = old_binder
872881
self.check_multiple_inheritance(typ)
882+
self.leave_partial_types()
873883
self.errors.pop_type()
874884

875885
def check_multiple_inheritance(self, typ: TypeInfo) -> None:
@@ -1237,11 +1247,14 @@ def infer_variable_type(self, name: Var, lvalue: Node,
12371247
elif isinstance(init_type, Void):
12381248
self.check_not_void(init_type, context)
12391249
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
1240-
elif not self.is_valid_inferred_type(init_type):
1241-
# We cannot use the type of the initialization expression for type
1242-
# inference (it's not specific enough).
1243-
self.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
1244-
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
1250+
elif not is_valid_inferred_type(init_type):
1251+
# We cannot use the type of the initialization expression for full type
1252+
# inference (it's not specific enough), but we might be able to give
1253+
# partial type which will be made more specific later. A partial type
1254+
# gets generated in assignment like 'x = []' where item type is not known.
1255+
if not self.infer_partial_type(name, lvalue, init_type):
1256+
self.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
1257+
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
12451258
else:
12461259
# Infer type of the target.
12471260

@@ -1250,6 +1263,21 @@ def infer_variable_type(self, name: Var, lvalue: Node,
12501263

12511264
self.set_inferred_type(name, lvalue, init_type)
12521265

1266+
def infer_partial_type(self, name: Var, lvalue: Node, init_type: Type) -> bool:
1267+
if not isinstance(init_type, Instance):
1268+
return False
1269+
fullname = init_type.type.fullname()
1270+
if ((fullname == 'builtins.list' or fullname == 'builtins.set' or
1271+
fullname == 'builtins.dict')
1272+
and isinstance(init_type.args[0], NoneTyp)
1273+
and (fullname != 'builtins.dict' or isinstance(init_type.args[1], NoneTyp))
1274+
and isinstance(lvalue, NameExpr)):
1275+
partial_type = PartialType(init_type.type, name)
1276+
self.set_inferred_type(name, lvalue, partial_type)
1277+
self.partial_types[-1][name] = lvalue
1278+
return True
1279+
return False
1280+
12531281
def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None:
12541282
"""Store inferred variable type.
12551283
@@ -1275,23 +1303,6 @@ def set_inference_error_fallback_type(self, var: Var, lvalue: Node, type: Type,
12751303
if context.get_line() in self.errors.ignored_lines:
12761304
self.set_inferred_type(var, lvalue, AnyType())
12771305

1278-
def is_valid_inferred_type(self, typ: Type) -> bool:
1279-
"""Is an inferred type invalid?
1280-
1281-
Examples include the None type or a type with a None component.
1282-
"""
1283-
if is_same_type(typ, NoneTyp()):
1284-
return False
1285-
elif isinstance(typ, Instance):
1286-
for arg in typ.args:
1287-
if not self.is_valid_inferred_type(arg):
1288-
return False
1289-
elif isinstance(typ, TupleType):
1290-
for item in typ.items:
1291-
if not self.is_valid_inferred_type(item):
1292-
return False
1293-
return True
1294-
12951306
def narrow_type_from_binder(self, expr: Node, known_type: Type) -> Type:
12961307
if expr.literal >= LITERAL_TYPE:
12971308
restriction = self.binder.get(expr)
F987
@@ -1323,6 +1334,7 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
13231334
13241335
The lvalue argument is the base[index] expression.
13251336
"""
1337+
self.try_infer_partial_type_from_indexed_assignment(lvalue, rvalue)
13261338
basetype = self.accept(lvalue.base)
13271339
method_type = self.expr_checker.analyze_external_member_access(
13281340
'__setitem__', basetype, context)
@@ -1331,6 +1343,26 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
13311343
[nodes.ARG_POS, nodes.ARG_POS],
13321344
context)
13331345

1346+
def try_infer_partial_type_from_indexed_assignment(
1347+
self, lvalue: IndexExpr, rvalue: Node) -> None:
1348+
# TODO: Should we share some of this with try_infer_partial_type?
1349+
partial_types = self.partial_types[-1]
1350+
if not partial_types:
1351+
# Fast path leave -- no partial types in the current scope.
1352+
return
1353+
if isinstance(lvalue.base, RefExpr):
1354+
var = lvalue.base.node
1355+
if var in partial_types:
1356+
var = cast(Var, var)
1357+
typename = cast(Instance, var.type).type.fullname()
1358+
if typename == 'builtins.dict':
1359+
# TODO: Don't infer things twice.
1360+
key_type = self.accept(lvalue.index)
1361+
value_type = self.accept(rvalue)
1362+
if is_valid_inferred_type(key_type) and is_valid_inferred_type(value_type):
1363+
var.type = self.named_generic_type('builtins.dict', [key_type, value_type])
1364+
del partial_types[var]
1365+
13341366
def visit_expression_stmt(self, s: ExpressionStmt) -> Type:
13351367
self.accept(s.expr)
13361368

@@ -2032,6 +2064,21 @@ def enter(self) -> None:
20322064
def leave(self) -> None:
20332065
self.locals = None
20342066

2067+
def enter_partial_types(self) -> None:
2068+
"""Push a new scope for collecting partial types."""
2069+
self.partial_types.append({})
2070+
2071+
def leave_partial_types(self) -> None:
2072+
"""Pop partial type scope.
2073+
2074+
Also report errors for variables which still have partial
2075+
types, i.e. we couldn't infer a complete type.
2076+
"""
2077+
partial_types = self.partial_types.pop()
2078+
for var, context in partial_types.items():
2079+
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
2080+
var.type = AnyType()
2081+
20352082
def is_within_function(self) -> bool:
20362083
"""Are we currently type checking within a function?
20372084
@@ -2289,3 +2336,21 @@ def infer_operator_assignment_method(type: Type, operator: str) -> str:
22892336
if type.type.has_readable_member(inplace):
22902337
method = inplace
22912338
return method
2339+
2340+
2341+
def is_valid_inferred_type(typ: Type) -> bool:
2342+
"""Is an inferred type valid?
2343+
2344+
Examples of invalid types include the None type or a type with a None component.
2345+
"""
2346+
if is_same_type(typ, NoneTyp()):
2347+
return False
2348+
elif isinstance(typ, Instance):
2349+
for arg in typ.args:
2350+
if not is_valid_inferred_type(arg):
2351+
return False
2352+
elif isinstance(typ, TupleType):
2353+
for item in typ.items:
2354+
if not is_valid_inferred_type(item):
2355+
return False
2356+
return True

mypy/checkexpr.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
7-
TupleType, Instance, TypeVarType, TypeTranslator, ErasedType, FunctionLike, UnionType
7+
TupleType, Instance, TypeVarType, TypeTranslator, ErasedType, FunctionLike, UnionType,
8+
PartialType
89
)
910
from mypy.nodes import (
1011
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
@@ -14,7 +15,7 @@
1415
ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator,
1516
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
1617
DictionaryComprehension, ComplexExpr, EllipsisExpr, LITERAL_TYPE,
17-
TypeAliasExpr, YieldExpr, BackquoteExpr
18+
TypeAliasExpr, YieldExpr, BackquoteExpr, ARG_POS
1819
)
1920
from mypy.errors import Errors
2021
from mypy.nodes import function_type
@@ -72,6 +73,11 @@ def analyze_ref_expr(self, e: RefExpr) -> Type:
7273
if isinstance(node, Var):
7374
# Variable reference.
7475
result = self.analyze_var_ref(node, e)
76+
if isinstance(result, PartialType):
77+
partial_types = self.chk.partial_types[-1]
78+
context = partial_types[node]
79+
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
80+
result = AnyType()
7581
elif isinstance(node, FuncDef):
7682
# Reference to a global function.
7783
result = function_type(node, self.named_type('builtins.function'))
@@ -110,13 +116,35 @@ def visit_call_expr(self, e: CallExpr) -> Type:
110116
if e.analyzed:
111117
# It's really a special form that only looks like a call.
112118
return self.accept(e.analyzed)
119+
self.try_infer_partial_type(e)
113120
self.accept(e.callee)
114121
# Access callee type directly, since accept may return the Any type
115122
# even if the type is known (in a dynamically typed function). This
116123
# way we get a more precise callee in dynamically typed functions.
117124
callee_type = self.chk.type_map[e.callee]
118125
return self.check_call_expr_with_callee_type(callee_type, e)
119126

127+
def try_infer_partial_type(self, e: CallExpr) -> None:
128+
partial_types = self.chk.partial_types[-1]
129+
if not partial_types:
130+
# Fast path leave -- no partial types in the current scope.
131+
return
132+
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
133+
var = e.callee.expr.node
134+
if var in partial_types:
135+
var = cast(Var, var)
136+
typename = cast(Instance, var.type).type.fullname()
137+
methodname = e.callee.name
138+
if (((typename == 'builtins.list' and methodname == 'append') or
139+
(typename == 'builtins.set' and methodname == 'add'))
140+
and e.arg_kinds == [ARG_POS]):
141+
# We can infer a full type for a partial List type.
142+
# TODO: Don't infer argument expression twice.
143+
item_type = self.accept(e.args[0])
144+
if mypy.checker.is_valid_inferred_type(item_type):
145+
var.type = self.chk.named_generic_type(typename, [item_type])
146+
del partial_types[var]
147+
120148
def check_call_expr_with_callee_type(self, callee_type: Type,
121149
e: CallExpr) -> Type:
122150
"""Type check call expression.

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef,
7-
Overloaded, TypeVarType, TypeTranslator, UnionType
7+
Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType
88
)
99
from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context
1010
from mypy.nodes import ARG_POS, function_type, Decorator, OverloadedFuncDef

mypy/constraints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType,
7-
Instance, TupleType, UnionType, Overloaded, ErasedType, is_named_instance
7+
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, is_named_instance
88
)
99
from mypy.expandtype import expand_caller_var_args
1010
from mypy.maptype import map_instance_to_supertype
@@ -151,6 +151,12 @@ def visit_none_type(self, template: NoneTyp) -> List[Constraint]:
151151
def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
152152
return []
153153

154+
# Errors
155+
156+
def visit_partial_type(self, template: PartialType) -> List[Constraint]:
157+
# We can't do anything useful with a partial type here.
158+
assert False, "Internal error"
159+
154160
# Non-trivial leaf type
155161

156162
def visit_type_var(self, template: TypeVarType) -> List[Constraint]:

mypy/erasetype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mypy.types import (
44
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp,
55
Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType,
6-
TypeTranslator, TypeList
6+
TypeTranslator, TypeList, PartialType
77
)
88

99

@@ -46,6 +46,10 @@ def visit_erased_type(self, t: ErasedType) -> Type:
4646
# Should not get here.
4747
raise RuntimeError()
4848

49+
def visit_partial_type(self, t: PartialType) -> Type:
50+
# Should not get here.
51+
raise RuntimeError()
52+
4953
def visit_instance(self, t: Instance) -> Type:
5054
return Instance(t.type, [AnyType()] * len(t.args), t.line)
5155

mypy/expandtype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from mypy.types import (
44
Type, Instance, CallableType, TypeVisitor, UnboundType, ErrorType, AnyType,
5-
Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList
5+
Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList,
6+
PartialType
67
)
78

89

@@ -92,6 +93,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
9293
def visit_union_type(self, t: UnionType) -> Type:
9394
return UnionType(self.expand_types(t.items), t.line)
9495

96+
def visit_partial_type(self, t: PartialType) -> Type:
97+
return t
98+
9599
def expand_types(self, types: List[Type]) -> List[Type]:
96100
a = [] # type: List[Type]
97101
for t in types:

mypy/join.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.types import (
66
Type, AnyType, NoneTyp, Void, TypeVisitor, Instance, UnboundType,
77
ErrorType, TypeVarType, CallableType, TupleType, ErasedType, TypeList,
8-
UnionType, FunctionLike, Overloaded
8+
UnionType, FunctionLike, Overloaded, PartialType
99
)
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars
@@ -196,6 +196,11 @@ def visit_tuple_type(self, t: TupleType) -> Type:
196196
else:
197197
return self.default(self.s)
198198

199+
def visit_partial_type(self, t: PartialType) -> Type:
200+
# We only have partial information so we can't decide the join result. We should
201+
# never get here.
202+
assert False, "Internal error"
203+
199204
def join(self, s: Type, t: Type) -> Type:
200205
return join_types(s, t)
201206

0 commit comments

Comments
 (0)
0