8000 [suggest] Support refining existing type annotations (#7838) · python/mypy@a4f4ffe · GitHub
[go: up one dir, main page]

Skip to content

Commit a4f4ffe

Browse files
authored
[suggest] Support refining existing type annotations (#7838)
1 parent 38e0f5d commit a4f4ffe

File tree

2 files changed

+256
-6
lines changed

2 files changed

+256
-6
lines changed

mypy/suggestions.py

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
288288
AnyType(TypeOfAny.special_form),
289289
self.builtin_type('builtins.function'))
290290

291+
def get_starting_type(self, fdef: FuncDef) -> CallableType:
292+
if isinstance(fdef.type, CallableType):
293+
return fdef.type
294+
else:
295+
return self.get_trivial_type(fdef)
296+
291297
def get_args(self, is_method: bool,
292298
base: CallableType, defaults: List[Optional[Type]],
293299
callsites: List[Callsite],
@@ -356,11 +362,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
356362
"""
357363
options = self.get_args(is_method, base, defaults, callsites, uses)
358364
options = [self.add_adjustments(tps) for tps in options]
359-
return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)]
365+
return [refine_callable(base, base.copy_modified(arg_types=list(x)))
366+
for x in itertools.product(*options)]
360367

361368
def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
362369
"""Find all call sites of a function."""
363-
new_type = self.get_trivial_type(func)
370+
new_type = self.get_starting_type(func)
364371

365372
collector_plugin = SuggestionPlugin(func.fullname())
366373

@@ -413,7 +420,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
413420
with strict_optional_set(graph[mod].options.strict_optional):
414421
guesses = self.get_guesses(
415422
is_method,
416-
self.get_trivial_type(node),
423+
self.get_starting_type(node),
417424
self.get_default_arg_types(graph[mod], node),
418425
callsites,
419426
uses,
@@ -432,7 +439,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
432439
else:
433440
ret_types = [NoneType()]
434441

435-
guesses = [best.copy_modified(ret_type=t) for t in ret_types]
442+
guesses = [best.copy_modified(ret_type=refine_type(best.ret_type, t)) for t in ret_types]
436443
guesses = self.filter_options(guesses, is_method)
437444
best, errors = self.find_best(node, guesses)
438445

@@ -593,8 +600,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
593600
"""
594601
old = func.unanalyzed_type
595602
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
596-
# We don't modify type because it isn't necessary and it
597-
# would mess up the snapshotting.
603+
# We set type to None to ensure that the type always changes during
604+
# reprocessing.
605+
func.type = None
598606
func.unanalyzed_type = typ
599607
try:
600608
res = self.fgmanager.trigger(func.fullname())
@@ -682,6 +690,8 @@ def score_type(self, t: Type, arg_pos: bool) -> int:
682690
if isinstance(t, UnionType):
683691
if any(isinstance(x, AnyType) for x in t.items):
684692
return 20
693+
if any(has_any_type(x) for x in t.items):
694+
return 15
685695
if not is_optional(t):
686696
return 10
687697
if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)):
@@ -840,6 +850,105 @@ def count_errors(msgs: List[str]) -> int:
840850
return len([x for x in msgs if ' error: ' in x])
841851

842852

853+
def refine_type(ti: Type, si: Type) -> Type:
854+
"""Refine `ti` by replacing Anys in it with information taken from `si`
855+
856+
This basically works by, when the types have the same structure,
857+
traversing both of them in parallel and replacing Any on the left
858+
with whatever the type on the right is. If the types don't have the
859+
same structure (or aren't supported), the left type is chosen.
860+
861+
For example:
862+
refine(Any, T) = T, for all T
863+
refine(float, int) = float
864+
refine(List[Any], List[int]) = List[int]
865+
refine(Dict[int, Any], Dict[Any, int]) = Dict[int, int]
866+
refine(Tuple[int, Any], Tuple[Any, int]) = Tuple[int, int]
867+
868+
refine(Callable[[Any], Any], Callable[[int], int]) = Callable[[int], int]
869+
refine(Callable[..., int], Callable[[int, float], Any]) = Callable[[int, float], int]
870+
871+
refine(Optional[Any], int) = Optional[int]
872+
refine(Optional[Any], Optional[int]) = Optional[int]
873+
refine(Optional[Any], Union[int, str]) = Optional[Union[int, str]]
874+
refine(Optional[List[Any]], List[int]) = List[int]
875+
876+
"""
877+
t = get_proper_type(ti)
878+
s = get_proper_type(si)
879+
880+
if isinstance(t, AnyType):
881+
return s
882+
883+
if isinstance(t, Instance) and isinstance(s, Instance) and t.type == s.type:
884+
return t.copy_modified(args=[refine_type(ta, sa) for ta, sa in zip(t.args, s.args)])
885+
886+
if (
887+
isinstance(t, TupleType)
888+
and isinstance(s, TupleType)
889+
and t.partial_fallback == s.partial_fallback
890+
and len(t.items) == len(s.items)
891+
):
892+
return t.copy_modified(items=[refine_type(ta, sa) for ta, sa in zip(t.items, s.items)])
893+
894+
if isinstance(t, CallableType) and isinstance(s, CallableType):
895+
return refine_callable(t, s)
896+
897+
if isinstance(t, UnionType):
898+
return refine_union(t, s)
899+
900+
# TODO: Refining of builtins.tuple, Type?
901+
902+
return t
903+
904+
905+
def refine_union(t: UnionType, s: ProperType) -> Type:
906+
"""Refine a union type based on another type.
907+
908+
This is done by refining every component of the union against the
909+
right hand side type (or every component of its union if it is
910+
one). If an element of the union is succesfully refined, we drop it
911+
from the union in favor of the refined versions.
912+
"""
913+
rhs_items = s.items if isinstance(s, UnionType) else [s]
914+
915+
new_items = []
916+
for lhs in t.items:
917+
refined = False
918+
for rhs in rhs_items:
919+
new = refine_type(lhs, rhs)
920+
if new != lhs:
921+
new_items.append(new)
922+
refined = True
923+
if not refined:
924+
new_items.append(lhs)
925+
926+
# Turn strict optional on when simplifying the union since we
927+
# don't want to drop Nones.
928+
with strict_optional_set(True):
929+
return make_simplified_union(new_items)
930+
931+
932+
def refine_callable(t: CallableType, s: CallableType) -> CallableType:
933+
"""Refine a callable based on another.
934+
935+
See comments for refine_type.
936+
"""
937+
if t.fallback != s.fallback:
938+
return t
939+
940+
if t.is_ellipsis_args and not is_tricky_callable(s):
941+
return s.copy_modified(ret_type=refine_type(t.ret_type, s.ret_type))
942+
943+
if is_tricky_callable(t) or t.arg_kinds != s.arg_kinds:
944+
return t
945+
946+
return t.copy_modified(
947+
arg_types=[refine_type(ta, sa) for ta, sa in zip(t.arg_types, s.arg_types)],
948+
ret_type=refine_type(t.ret_type, s.ret_type),
949+
)
950+
951+
843952
T = TypeVar('T')
844953

845954

test-data/unit/fine-grained-suggest.test

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,3 +906,144 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse
906906
==
907907
foo.py:4: error: unexpected EOF while parsing
908908
-- )
909+
910+
[case testSuggestRefine]
911+
# suggest: foo.foo
912+
# suggest: foo.spam
913+
# suggest: foo.eggs
914+
# suggest: foo.take_l
915+
# suggest: foo.union
916+
# suggest: foo.callable1
917+
# suggest: foo.callable2
918+
# suggest: foo.optional1
919+
# suggest: foo.optional2
920+
# suggest: foo.optional3
921+
# suggest: foo.optional4
922+
# suggest: foo.optional5
923+
# suggest: foo.dict1
924+
# suggest: foo.tuple1
925+
[file foo.py]
926+
from typing import Any, List, Union, Callable, Optional, Set, Dict, Tuple
927+
928+
def bar():
929+
return 10
930+
931+
def foo(x: int, y):
932+
return x + y
933+
934+
foo(bar(), 10)
935+
936+
def spam(x: int, y: Any) -> Any:
937+
return x + y
938+
939+
spam(bar(), 20)
940+
941+
def eggs(x: int) -> List[Any]:
942+
a = [x]
943+
return a
944+
945+
def take_l(x: List[Any]) -> Any:
946+
return x[0]
947+
948+
test = [10, 20]
949+
take_l(test)
950+
951+
def union(x: Union[int, str]):
952+
pass
953+
954+
union(10)
955+
956+
def add1(x: float) -> int:
957+
pass
958+
959+
def callable1(f: Callable[[int], Any]):
960+
return f(10)
961+
962+
callable1(add1)
963+
964+
def callable2(f: Callable[..., Any]):
965+
return f(10)
966+
967+
callable2(add1)
968+
969+
def optional1(x: Optional[Any]):
970+
pass
971+
972+
optional1(10)
973+
974+
def optional2(x: Union[None, int, Any]):
975+
if x is None:
976+
pass
977+
elif isinstance(x, str):
978+
pass
979+
else:
980+
add1(x)
981+
982+
optional2(10)
983+
optional2('test')
984+
985+
def optional3(x: Optional[List[Any]]):
986+
assert not x
987+
return x[0]
988+
989+
optional3(test)
990+
991+
set_test = {1, 2}
992+
993+
def optional4(x: Union[Set[Any], List[Any]]):
994+
pass
995+
996+
optional4(test)
997+
optional4(set_test)
998+
999+
def optional5(x: Optional[Any]):
1000+
pass
1001+
1002+
optional5(10)
1003+
optional5(None)
1004+
1005+
def dict1(d: Dict[int, Any]):
1006+
pass
1007+
1008+
d: Dict[Any, int]
1009+
dict1(d)
1010+
1011+
def tuple1(d: Tuple[int, Any]):
1012+
pass
1013+
1014+
t: Tuple[Any, int]
1015+
tuple1(t)
1016+
1017+
[builtins fixtures/isinstancelist.pyi]
1018+
[out]
1019+
(int, int) -> int
1020+
(int, int) -> int
1021+
(int) -> foo.List[int]
1022+
(foo.List[int]) -> int
1023+
(Union[int, str]) -> None
1024+
(Callable[[int], int]) -> int
1025+
(Callable[[float], int]) -> int
1026+
(Optional[int]) -> None
1027+
(Union[None, int, str]) -> None
1028+
(Optional[foo.List[int]]) -> int
1029+
(Union[foo.Set[int], foo.List[int]]) -> None
1030+
(Optional[int]) -> None
1031+
(foo.Dict[int, int]) -> None
1032+
(Tuple[int, int]) -> None
1033+
==
1034+
1035+
[case testSuggestRefine2]
1036+
# suggest: foo.optional5
1037+
[file foo.py]
1038+
from typing import Optional, Any
1039+
1040+
def optional5(x: Optional[Any]):
1041+
pass
1042+
1043+
optional5(10)
1044+
optional5(None)
1045+
1046+
[builtins fixtures/isinstancelist.pyi]
1047+
[out]
1048+
(Optional[int]) -> None
1049+
==

0 commit comments

Comments
 (0)
0