8000 Can understand functools.total_ordering · python/mypy@cf81f1c · GitHub
[go: up one dir, main page]

Skip to content

Commit cf81f1c

Browse files
committed
Can understand functools.total_ordering
1 parent fdcbb74 commit cf81f1c

File tree

4 files changed

+334
-0
lines changed

4 files changed

+334
-0
lines changed

mypy/plugins/default.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def get_class_decorator_hook(self, fullname: str
9090
) -> Optional[Callable[[ClassDefContext], None]]:
9191
from mypy.plugins import attrs
9292
from mypy.plugins import dataclasses
93+
from mypy.plugins import functools
9394

9495
if fullname in attrs.attr_class_makers:
9596
return attrs.attr_class_maker_callback
@@ -100,6 +101,10 @@ def get_class_decorator_hook(self, fullname: str
100101
)
101102
elif fullname in dataclasses.dataclass_makers:
102103
return dataclasses.dataclass_class_maker_callback
104+
105+
if fullname in functools.functools_total_ordering_makers:
106+
return functools.functools_total_ordering_maker_callback
107+
103108
return None
104109

105110

mypy/plugins/functools.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Plugin for supporting the functools standard library module."""
2+
from typing import Dict, Optional
3+
from typing_extensions import Final
4+
5+
import mypy.plugin
6+
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
7+
from mypy.plugins.common import add_method
8+
from mypy.types import CallableType, Type
9+
10+
11+
functools_total_ordering_makers: Final = {
12+
'functools.total_ordering',
13+
}
14+
15+
_ORDERING_METHODS: Final = {
16+
'__lt__',
17+
'__le__',
18+
'__gt__',
19+
'__ge__',
20+
}
21+
22+
23+
def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext,
24+
auto_attribs_default: bool = False) -> None:
25+
"""Add dunder methods to classes decorated with functools.total_ordering."""
26+
if ctx.api.options.python_version < (3, 2):
27+
ctx.api.fail("functools.total_ordering is not supported in Python 2 or 3.1", ctx.reason)
28+
return
29+
30+
comparison_methods = _analyze_class(ctx)
31+
if not comparison_methods:
32+
ctx.api.fail('must define at least one ordering operation: < > <= >=', ctx.reason)
33+
return
34+
35+
root = max(comparison_methods) # prefer __lt__ to __le__ to __gt__ to __ge__
36+
root_method = comparison_methods[root]
37+
other_type = _find_other_type(root_method)
38+
if isinstance(root_method.type, CallableType):
39+
ret_type = root_method.type.ret_type
40+
else:
41+
ret_type = ctx.api.named_type('__builtins__.bool')
42+
for additional_op in _ORDERING_METHODS - {root}:
43+
if additional_op not in comparison_methods:
44+
args = [Argument(Var('other', other_type), other_type, None, ARG_POS)]
45+
add_method(ctx, additional_op, args, ret_type)
46+
47+
48+
def _find_other_type(method: FuncItem) -> Optional[Type]:
49+
"""Find the type of the ``other`` argument in a comparision method."""
50+
first_arg_pos = 0 if method.is_static else 1
51+
cur_pos_arg = 0
52+
other_arg = None
53+
for arg in method.arguments:
54+
if arg.kind == ARG_POS:
55+
if cur_pos_arg == first_arg_pos:
56+
other_arg = arg
57+
break
58+
59+
cur_pos_arg += 1
60+
elif arg.kind != ARG_STAR2:
61+
other_arg = arg
62+
break
63+
return None if other_arg is None else other_arg.type_annotation
64+
65+
66+
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> Dict[str, FuncItem]:
67+
"""Analyze the class body, its parents, and return the comparison methods found."""
68+
# Traverse the MRO and collect ordering methods.
69+
comparison_methods = {} # type: Dict[str, FuncItem]
70+
# Skip object because total_ordering does not use methods from object
71+
for cls in ctx.cls.info.mro[:-1]:
72+
for name in _ORDERING_METHODS:
73+
if name in cls.names and name not in comparison_methods:
74+
node = cls.names[name].node
75+
if isinstance(node, FuncItem):
76+
comparison_methods[name] = node
77+
78+
return comparison_methods

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
'check-reports.test',
8989
'check-errorcodes.test',
9090
'check-annotated.test',
91+
'check-functools.test',
9192
]
9293

9394
# Tests that use Python 3.8-only AST features (like expression-scoped ignores):

test-data/unit/check-functools.test

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
[case testTotalOrderingEqLt]
2+
from functools import total_ordering
3+
4+
@total_ordering
5+
class Ord:
6+
def __eq__(self, other: object) -> bool:
7+
return False
8+
9+
def __lt__(self, other: "Ord") -> bool:
10+
return False
11+
12+
Ord() < Ord()
13+
Ord() <= Ord()
14+
Ord() == Ord()
15+
Ord() > Ord()
16+
Ord() >= Ord()
17+
18+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
19+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
20+
Ord() == 1
21+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
22+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
23+
[builtins fixtures/ops.pyi]
24+
[builtins fixtures/dict.pyi]
25+
26+
[case testTotalOrderingEqLe]
27+
from functools import total_ordering
28+
29+
@total_ordering
30+
class Ord:
31+
def __eq__(self, other: object) -> bool:
32+
return False
33+
34+
def __le__(self, other: "Ord") -> bool:
35+
return False
36+
37+
Ord() < Ord()
38+
Ord() <= Ord()
39+
Ord() == Ord()
40+
Ord() > Ord()
41+
Ord() >= Ord()
42+
43+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
44+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
45+
Ord() == 1
46+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
47+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
48+
[builtins fixtures/ops.pyi]
49+
[builtins fixtures/dict.pyi]
50+
51+
[case testTotalOrderingEqGt]
52+
from functools import total_ordering
53+
54+
@total_ordering
55+
class Ord:
56+
def __eq__(self, other: object) -> bool:
57+
return False
58+
59+
def __gt__(self, other: "Ord") -> bool:
60+
return False
61+
62+
Ord() < Ord()
63+
Ord() <= Ord()
64+
Ord() == Ord()
65+
Ord() > Ord()
66+
Ord() >= Ord()
67+
68+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
69+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
70+
Ord() == 1
71+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
72+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
73+
[builtins fixtures/ops.pyi]
74+
[builtins fixtures/dict.pyi]
75+
76+
[case testTotalOrderingEqGe]
77+
from functools import total_ordering
78+
79+
@total_ordering
80+
class Ord:
81+
def __eq__(self, other: object) -> bool:
82+
return False
83+
84+
def __ge__(self, other: "Ord") -> bool:
85+
return False
86+
87+
Ord() < Ord()
88+
Ord() <= Ord()
89+
Ord() == Ord()
90+
Ord() > Ord()
91+
Ord() >= Ord()
92+
93+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
94+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
95+
Ord() == 1
96+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
97+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
98+
[builtins fixtures/ops.pyi]
99+
[builtins fixtures/dict.pyi]
100+
101+
[case testTotalOrderingEqGe]
102+
from functools import total_ordering
103+
104+
@total_ordering
105+
class Ord:
106+
def __eq__(self, other: object) -> bool:
107+
return False
108+
109+
def __ge__(self, other: "Ord") -> bool:
110+
return False
111+
112+
Ord() < Ord()
113+
Ord() <= Ord()
114+
Ord() == Ord()
115+
Ord() > Ord()
116+
Ord() >= Ord()
117+
118+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
119+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
120+
Ord() == 1
121+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
122+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
123+
[builtins fixtures/ops.pyi]
124+
[builtins fixtures/dict.pyi]
125+
126+
[case testTotalOrderingLt]
127+
from functools import total_ordering
128+
129+
@total_ordering
130+
class Ord:
131+
def __lt__(self, other: "Ord") -> bool:
132+
return False
133+
134+
Ord() < Ord()
135+
Ord() <= Ord()
136+
Ord() == Ord()
137+
Ord() > Ord()
138+
Ord() >= Ord()
139+
140+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
141+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
142+
Ord() == 1
143+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
144+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
145+
[builtins fixtures/ops.pyi]
146+
[builtins fixtures/dict.pyi]
147+
148+
[case testTotalOrderingLe]
149+
from functools import total_ordering
150+
151+
@total_ordering
152+
class Ord:
153+
def __le__(self, other: "Ord") -> bool:
154+
return False
155+
156+
Ord() < Ord()
157+
Ord() <= Ord()
158+
Ord() == Ord()
159+
Ord() > Ord()
160+
Ord() >= Ord()
161+
162+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
163+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
164+
Ord() == 1
165+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
166+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
167+
[builtins fixtures/ops.pyi]
168+
[builtins fixtures/dict.pyi]
169+
170+
[case testTotalOrderingGt]
171+
from functools import total_ordering
172+
173+
@total_ordering
174+
class Ord:
175+
def __gt__(self, other: "Ord") -> bool:
176+
return False
177+
178+
Ord() < Ord()
179+
Ord() <= Ord()
180+
Ord() == Ord()
181+
Ord() > Ord()
182+
Ord() >= Ord()
183+
184+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
185+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
186+
Ord() == 1
187+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
188+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
189+
[builtins fixtures/ops.pyi]
190+
[builtins fixtures/dict.pyi]
191+
192+
[case testTotalOrderingGe]
193+
from functools import total_ordering
194+
195+
@total_ordering
196+
class Ord:
197+
def __ge__(self, other: "Ord") -> bool:
198+
return False
199+
200+
Ord() < Ord()
201+
Ord() <= Ord()
202+
Ord() == Ord()
203+
Ord() > Ord()
204+
Ord() >= Ord()
205+
206+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
207+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
208+
Ord() == 1
209+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
210+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
211+
[builtins fixtures/ops.pyi]
212+
[builtins fixtures/dict.pyi]
213+
214+
[case testTotalOrderingGe]
215+
from functools import total_ordering
216+
217+
@total_ordering
218+
class Ord:
219+
def __ge__(self, other: "Ord") -> bool:
220+
return False
221+
222+
Ord() < Ord()
223+
Ord() <= Ord()
224+
Ord() == Ord()
225+
Ord() > Ord()
226+
Ord() >= Ord()
227+
228+
Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
229+
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
230+
Ord() == 1
231+
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
232+
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
233+
[builtins fixtures/ops.pyi]
234+
[builtins fixtures/dict.pyi]
235+
236+
[case testTotalOrderingEq]
237+
from functools import total_ordering
238+
239+
@total_ordering # E: must define at least one ordering operation: < > <= >=
240+
class Ord:
241+
def __eq__(self, other: object) -> bool:
242+
return False
243+
244+
Ord() < Ord() # E: Unsupported left operand type for < ("Ord")
245+
Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord")
246+
Ord() == Ord()
247+
Ord() > Ord() # E: Unsupported left operand type for > ("Ord")
248+
Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord")
249+
[builtins fixtures/ops.pyi]
250+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)
0