18
18
import sys
19
19
20
20
from typing import Tuple , Union , TypeVar , Callable , Sequence , Optional , Any , cast , List
21
- from mypy .sharedparse import special_function_elide_names , argument_elide_name
21
+ from mypy .sharedparse import (
22
+ special_function_elide_names , argument_elide_name , is_overload_part ,
23
+ )
22
24
from mypy .nodes import (
23
25
MypyFile , Node , ImportBase , Import , ImportAll , ImportFrom , FuncDef , OverloadedFuncDef ,
24
26
ClassDef , Decorator , Block , Var , OperatorAssignmentStmt ,
@@ -209,19 +211,27 @@ def as_block(self, stmts: List[ast27.stmt], lineno: int) -> Block:
209
211
210
212
def fix_function_overloads (self , stmts : List [Statement ]) -> List [Statement ]:
211
213
ret = [] # type: List[Statement]
212
- current_overload = []
214
+ current_overload = [] # type: List[Decorator]
213
215
current_overload_name = None
214
216
# mypy doesn't actually check that the decorator is literally @overload
215
217
for stmt in stmts :
216
- if isinstance (stmt , Decorator ) and stmt .name () == current_overload_name :
218
+ if (isinstance (stmt , Decorator )
219
+ and is_overload_part (stmt )
220
+ and stmt .name () == current_overload_name ):
217
221
current_overload .append (stmt )
222
+ elif (isinstance (stmt , FuncDef )
223
+ and stmt .name () == current_overload_name
224
+ and stmt .name () is not None ):
225
+ ret .append (OverloadedFuncDef (current_overload , stmt ))
226
+ current_overload = []
227
+ current_overload_name = None
218
228
else :
219
229
if len (current_overload ) == 1 :
220
230
ret .append (current_overload [0 ])
221
231
elif len (current_overload ) > 1 :
222
- ret .append (OverloadedFuncDef (current_overload ))
232
+ ret .append (OverloadedFuncDef (current_overload , None ))
223
233
224
- if isinstance (stmt , Decorator ):
234
+ if isinstance (stmt , Decorator ) and is_overload_part ( stmt ) :
225
235
current_overload = [stmt ]
226
236
current_overload_name = stmt .name ()
227
237
else :
@@ -232,7 +242,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
232
242
if len (current_overload ) == 1 :
233
243
ret .append (current_overload [0 ])
234
244
elif len (current_overload ) > 1 :
235
- ret .append (OverloadedFuncDef (current_overload ))
245
+ ret .append (OverloadedFuncDef (current_overload , None ))
236
246
return ret
237
247
238
248
def in_class (self ) -> bool :
0 commit comments