@@ -288,6 +288,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
288
288
AnyType (TypeOfAny .special_form ),
289
289
self .builtin_type ('builtins.function' ))
290
290
291
+ def get_starting_type (self , fdef : FuncDef ) -> CallableType :
292
+ if isinstance (fdef .type , CallableType ):
293
+ return fdef .type
294
+ else :
295
+ return self .get_trivial_type (fdef )
296
+
291
297
def get_args (self , is_method : bool ,
292
298
base : CallableType , defaults : List [Optional [Type ]],
293
299
callsites : List [Callsite ],
@@ -356,11 +362,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
356
362
"""
357
363
options = self .get_args (is_method , base , defaults , callsites , uses )
358
364
options = [self .add_adjustments (tps ) for tps in options ]
359
- return [base .copy_modified (arg_types = list (x )) for x in itertools .product (* options )]
365
+ return [refine_callable (base , base .copy_modified (arg_types = list (x )))
366
+ for x in itertools .product (* options )]
360
367
361
368
def get_callsites (self , func : FuncDef ) -> Tuple [List [Callsite ], List [str ]]:
362
369
"""Find all call sites of a function."""
363
- new_type = self .get_trivial_type (func )
370
+ new_type = self .get_starting_type (func )
364
371
365
372
collector_plugin = SuggestionPlugin (func .fullname ())
366
373
@@ -413,7 +420,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
413
420
with strict_optional_set (graph [mod ].options .strict_optional ):
414
421
guesses = self .get_guesses (
415
422
is_method ,
416
- self .get_trivial_type (node ),
423
+ self .get_starting_type (node ),
417
424
self .get_default_arg_types (graph [mod ], node ),
418
425
callsites ,
419
426
uses ,
@@ -432,7 +439,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
432
439
else :
433
440
ret_types = [NoneType ()]
434
441
435
- guesses = [best .copy_modified (ret_type = t ) for t in ret_types ]
442
+ guesses = [best .copy_modified (ret_type = refine_type ( best . ret_type , t ) ) for t in ret_types ]
436
443
guesses = self .filter_options (guesses , is_method )
437
444
best , errors = self .find_best (node , guesses )
438
445
@@ -593,8 +600,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
593
600
"""
594
601
old = func .unanalyzed_type
595
602
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
596
- # We don't modify type because it isn't necessary and it
597
- # would mess up the snapshotting.
603
+ # We set type to None to ensure that the type always changes during
604
+ # reprocessing.
605
+ func .type = None
598
606
func .unanalyzed_type = typ
599
607
try :
600
608
res = self .fgmanager .trigger (func .fullname ())
@@ -682,6 +690,8 @@ def score_type(self, t: Type, arg_pos: bool) -> int:
682
690
if isinstance (t , UnionType ):
683
691
if any (isinstance (x , AnyType ) for x in t .items ):
684
692
return 20
693
+ if any (has_any_type (x ) for x in t .items ):
694
+ return 15
685
695
if not is_optional (t ):
686
696
return 10
687
697
if isinstance (t , CallableType ) and (has_any_type (t ) or is_tricky_callable (t )):
@@ -840,6 +850,105 @@ def count_errors(msgs: List[str]) -> int:
840
850
return len ([x for x in msgs if ' error: ' in x ])
841
851
842
852
853
+ def refine_type (ti : Type , si : Type ) -> Type :
854
+ """Refine `ti` by replacing Anys in it with information taken from `si`
855
+
856
+ This basically works by, when the types have the same structure,
857
+ traversing both of them in parallel and replacing Any on the left
858
+ with whatever the type on the right is. If the types don't have the
859
+ same structure (or aren't supported), the left type is chosen.
860
+
861
+ For example:
862
+ refine(Any, T) = T, for all T
863
+ refine(float, int) = float
864
+ refine(List[Any], List[int]) = List[int]
865
+ refine(Dict[int, Any], Dict[Any, int]) = Dict[int, int]
866
+ refine(Tuple[int, Any], Tuple[Any, int]) = Tuple[int, int]
867
+
868
+ refine(Callable[[Any], Any], Callable[[int], int]) = Callable[[int], int]
869
+ refine(Callable[..., int], Callable[[int, float], Any]) = Callable[[int, float], int]
870
+
871
+ refine(Optional[Any], int) = Optional[int]
872
+ refine(Optional[Any], Optional[int]) = Optional[int]
873
+ refine(Optional[Any], Union[int, str]) = Optional[Union[int, str]]
874
+ refine(Optional[List[Any]], List[int]) = List[int]
875
+
876
+ """
877
+ t = get_proper_type (ti )
878
+ s = get_proper_type (si )
879
+
880
+ if isinstance (t , AnyType ):
881
+ return s
882
+
883
+ if isinstance (t , Instance ) and isinstance (s , Instance ) and t .type == s .type :
884
+ return t .copy_modified (args = [refine_type (ta , sa ) for ta , sa in zip (t .args , s .args )])
885
+
886
+ if (
887
+ isinstance (t , TupleType )
888
+ and isinstance (s , TupleType )
889
+ and t .partial_fallback == s .partial_fallback
890
+ and len (t .items ) == len (s .items )
891
+ ):
892
+ return t .copy_modified (items = [refine_type (ta , sa ) for ta , sa in zip (t .items , s .items )])
893
+
894
+ if isinstance (t , CallableType ) and isinstance (s , CallableType ):
895
+ return refine_callable (t , s )
896
+
897
+ if isinstance (t , UnionType ):
898
+ return refine_union (t , s )
899
+
900
+ # TODO: Refining of builtins.tuple, Type?
901
+
902
+ return t
903
+
904
+
905
+ def refine_union (t : UnionType , s : ProperType ) -> Type :
906
+ """Refine a union type based on another type.
907
+
908
+ This is done by refining every component of the union against the
909
+ right hand side type (or every component of its union if it is
910
+ one). If an element of the union is succesfully refined, we drop it
911
+ from the union in favor of the refined versions.
912
+ """
913
+ rhs_items = s .items if isinstance (s , UnionType ) else [s ]
914
+
915
+ new_items = []
916
+ for lhs in t .items :
917
+ refined = False
918
+ for rhs in rhs_items :
919
+ new = refine_type (lhs , rhs )
920
+ if new != lhs :
921
+ new_items .append (new )
922
+ refined = True
923
+ if not refined :
924
+ new_items .append (lhs )
925
+
926
+ # Turn strict optional on when simplifying the union since we
927
+ # don't want to drop Nones.
928
+ with strict_optional_set (True ):
929
+ return make_simplified_union (new_items )
930
+
931
+
932
+ def refine_callable (t : CallableType , s : CallableType ) -> CallableType :
933
+ """Refine a callable based on another.
934
+
935
+ See comments for refine_type.
936
+ """
937
+ if t .fallback != s .fallback :
938
+ return t
939
+
940
+ if t .is_ellipsis_args and not is_tricky_callable (s ):
941
+ return s .copy_modified (ret_type = refine_type (t .ret_type , s .ret_type ))
942
+
943
+ if is_tricky_callable (t ) or t .arg_kinds != s .arg_kinds :
944
+ return t
945
+
946
+ return t .copy_modified (
947
+ arg_types = [refine_type (ta , sa ) for ta , sa in zip (t .arg_types , s .arg_types )],
948
+ ret_type = refine_type (t .ret_type , s .ret_type ),
949
+ )
950
+
951
+
843
952
T = TypeVar ('T' )
844
953
845
954
0 commit comments