8000 Merge pull request #1388 from python/decorators-in-cycles · python/mypy@7b048e6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7b048e6

Browse files
committed
Merge pull request #1388 from python/decorators-in-cycles
Support more decorators in cycles
2 parents b9bc14a + 47a721c commit 7b048e6

File tree

4 files changed

+206
-13
lines changed

4 files changed

+206
-13
lines changed

mypy/semanal.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from mypy.lex import lex
7777
from mypy.parsetype import parse_type
7878
from mypy.sametypes import is_same_type
79+
from mypy.erasetype import erase_typevars
7980
from mypy import defaults
8081

8182

@@ -2426,8 +2427,12 @@ def visit_class_def(self, tdef: ClassDef) -> None:
24262427
def visit_decorator(self, dec: Decorator) -> None:
24272428
"""Try to infer the type of the decorated function.
24282429
2429-
This helps us resolve forward references to decorated
2430-
functions during type checking.
2430+
This lets us resolve references to decorated functions during
2431+
type checking when there are cyclic imports, as otherwise the
2432+
type might not be available when we need it.
2433+
2434+
This basically uses a simple special-purpose type inference
2435+
engine just for decorators.
24312436
"""
24322437
super().visit_decorator(dec)
24332438
if dec.var.is_property:
@@ -2453,13 +2458,21 @@ def visit_decorator(self, dec: Decorator) -> None:
24532458
decorator_preserves_type = False
24542459
break
24552460
if decorator_preserves_type:
2456-
# No non-special decorators left. We can trivially infer the type
2461+
# No non-identity decorators left. We can trivially infer the type
24572462
# of the function here.
24582463
dec.var.type = function_type(dec.func, self.builtin_type('function'))
2459-
if dec.decorators and returns_any_if_called(dec.decorators[0]):
2460-
# The outermost decorator will return Any so we know the type of the
2461-
# decorated function.
2462-
dec.var.type = AnyType()
2464+
if dec.decorators:
2465+
if returns_any_if_called(dec.decorators[0]):
2466+
# The outermost decorator will return Any so we know the type of the
2467+
# decorated function.
2468+
dec.var.type = AnyType()
2469+
sig = find_fixed_callable_return(dec.decorators[0])
2470+
if sig:
2471+
# The outermost decorator always returns the same kind of function,
2472+
# so we know that this is the type of the decoratored function.
2473+
orig_sig = function_type(dec.func, self.builtin_type('function'))
2474+
sig.name = orig_sig.items()[0].name
2475+
dec.var.type = sig
24632476

24642477
def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
24652478
self.analyze(s.type)
@@ -2673,3 +2686,23 @@ def returns_any_if_called(expr: Node) -> bool:
26732686
elif isinstance(expr, CallExpr):
26742687
return returns_any_if_called(expr.callee)
26752688
return False
2689+
2690+
2691+
def find_fixed_callable_return(expr: Node) -> Optional[CallableType]:
2692+
if isinstance(expr, RefExpr):
2693+
if isinstance(expr.node, FuncDef):
2694+
typ = expr.node.type
2695+
if typ:
2696+
if isinstance(typ, CallableType) and has_no_typevars(typ.ret_type):
2697+
if isinstance(typ.ret_type, CallableType):
2698+
return typ.ret_type
2699+
elif isinstance(expr, CallExpr):
2700+
t = find_fixed_callable_return(expr.callee)
2701+
if t:
2702+
if isinstance(t.ret_type, CallableType):
2703+
return t.ret_type
2704+
return None
2705+
2706+
2707+
def has_no_typevars(typ: Type) -> bool:
2708+
return is_same_type(typ, erase_typevars(typ))

mypy/test/data/check-functions.test

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,146 @@ def g(): pass
863863
def dec(f): pass
864864

865865

866+
-- Decorator functions in import cycles
867+
-- ------------------------------------
868+
869+
870+
[case testDecoratorWithIdentityTypeInImportCycle]
871+
import a
872+
873+
[file a.py]
874+
import b
875+
from d import dec
876+
@dec
877+
def f(x: int) -> None: pass
878+
b.g(1) # E
879+
880+
[file b.py]
881+
import a
882+
from d import dec
883+
@dec
884+
def g(x: str) -> None: pass
885+
a.f('')
886+
887+
[file d.py]
888+
from typing import TypeVar
889+
T = TypeVar('T')
890+
def dec(f: T) -> T: return f
891+
892+
[out]
893+
tmp/a.py:1: note: In module imported here,
894+
main:1: note: ... from here:
895+
tmp/b.py:5: error: Argument 1 to "f" has incompatible type "str"; expected "int"
896+
main:1: note: In module imported here:
897+
tmp/a.py:5: error: Argument 1 to "g" has incompatible type "int"; expected "str"
898+
899+
[case testDecoratorWithNoAnnotationInImportCycle]
900+
import a
901+
902+
[file a.py]
903+
import b
904+
from d import dec
905+
@dec
906+
def f(x: int) -> None: pass
907+
b.g(1, z=4)
908+
909+
[file b.py]
910+
import a
911+
from d import dec
912+
@dec
913+
def g(x: str) -> None: pass
914+
a.f('', y=2)
915+
916+
[file d.py]
917+
def dec(f): return f
918+
919+
[case testDecoratorWithFixedReturnTypeInImportCycle]
920+
import a
921+
922+
[file a.py]
923+
import b
924+
from d import dec
925+
@dec
926+
def f(x: int) -> str: pass
927+
b.g(1)()
928+
929+
[file b.py]
930+
import a
931+
from d import dec
932+
@dec
933+
def g(x: int) -> str: pass
934+
a.f(1)()
935+
936+
[file d.py]
937+
from typing import Callable
938+
def dec(f: Callable[[int], str]) -> Callable[[int], str]: return f
939+
940+
[out]
941+
tmp/a.py:1: note: In module imported here,
942+
main:1: note: ... from here:
943+
tmp/b.py:5: error: "str" not callable
944+
main:1: note: In module imported here:
945+
tmp/a.py:5: error: "str" not callable
946+
947+
[case testDecoratorWithCallAndFixedReturnTypeInImportCycle]
948+
import a
949+
950+
[file a.py]
951+
import b
952+
from d import dec
953+
@dec()
954+
def f(x: int) -> str: pass
955+
b.g(1)()
956+
957+
[file b.py]
958+
import a
959+
from d import dec
960+
@dec()
961+
def g(x: int) -> str: pass
962+
a.f(1)()
963+
964+
[file d.py]
965+
from typing import Callable
966+
def dec() -> Callable[[Callable[[int], str]], Callable[[int], str]]: pass
967+
968+
[out]
969+
tmp/a.py:1: note: In module imported here,
970+
main:1: note: ... from here:
971+
tmp/b.py:5: error: "str" not callable
972+
main:1: note: In module imported here:
973+
tmp/a.py:5: error: "str" not callable
974+
975+
[case testDecoratorWithCallAndFixedReturnTypeInImportCycleAndDecoratorArgs]
976+
import a
977+
978+
[file a.py]
979+
import b
980+
from d import dec
981+
@dec(1)
982+
def f(x: int) -> str: pass
983+
b.g(1)()
984+
985+
[file b.py]
986+
import a
987+
from d import dec
988+
@dec(1)
989+
def g(x: int) -> str: pass
990+
a.f(1)()
991+
992+
[file d.py]
993+
from typing import Callable
994+
def dec(x: str) -> Callable[[Callable[[int], str]], Callable[[int], str]]: pass
995+
996+
[out]
997+
tmp/a.py:1: note: In module imported here,
998+
main:1: note: ... from here:
999+
tmp/b.py:3: error: Argument 1 to "dec" has incompatible type "int"; expected "str"
1000+
tmp/b.py:5: error: "str" not callable
1001+
main:1: note: In module imported here:
1002+
tmp/a.py:3: error: Argument 1 to "dec" has incompatible type "int"; expected "str"
1003+
tmp/a.py:5: error: "str" not callable
1004+
1005+
8661006
-- Conditional function definition
8671007
-- -------------------------------
8681008

@@ -1384,3 +1524,19 @@ with a:
13841524
def f() -> None:
13851525
pass
13861526
f(1) # E: Too many arguments for "f"
1527+
1528+
1529+
[case testNameForDecoratorMethod]
1530+
from typing import Callable
1531+
1532+
class A:
1533+
def f(self) -> None:
1534+
# In particular, test that the error message contains "g" of "A".
1535+
self.g() # E: Too few arguments for "g" of "A"
1536+
self.g(1)
1537+
@dec
1538+
def g(self, x: str) -> None: pass
1539+
1540+
def dec(f: Callable[[A, str], None]) -> Callable[[A, int], None]: pass
1541+
[out]
1542+
main: note: In member "f" of class "A":

mypy/test/data/check-inference.test

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,9 @@ x = 1
14061406
[out]
14071407

14081408
[case testMultipassAndDecoratedMethod]
1409-
from typing import Callable
1409+
from typing import Callable, TypeVar
1410+
1411+
T = TypeVar('T')
14101412

14111413
class A:
14121414
def f(self) -> None:
@@ -1415,7 +1417,7 @@ class A:
14151417
@dec
14161418
def g(self, x: str) -> None: pass
14171419

1418-
def dec(f: Callable[[A, str], None]) -> Callable[[A, int], None]: pass
1420+
def dec(f: Callable[[A, str], T]) -> Callable[[A, int], T]: pass
14191421
[out]
14201422
main: note: In member "f" of class "A":
14211423

mypy/test/data/check-multiple-inheritance.test

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ class A(G[int]):
233233
class B(A, int): pass
234234

235235
[case testCannotDetermineTypeInMultipleInheritance]
236-
from typing import Callable
236+
from typing import Callable, TypeVar
237+
T = TypeVar('T')
237238
class A(B, C):
238239
def f(self): pass
239240
class B:
@@ -242,8 +243,9 @@ class B:
242243
class C:
243244
@dec
244245
def f(self): pass
245-
def dec(f) -> Callable[[], None]: return None
246+
def dec(f: Callable[..., T]) -> Callable[..., T]:
247+
return f
246248
[out]
247249
main: note: In class "A":
248-
main:2: error: Cannot determine type of 'f' in base class 'B'
249-
main:2: error: Cannot determine type of 'f' in base class 'C'
250+
main:3: error: Cannot determine type of 'f' in base class 'B'
251+
main:3: error: Cannot determine type of 'f' in base class 'C'

0 commit comments

Comments
 (0)
0