8000 [BE][Easy] Fix `PYI034`: non-self-return-type in tensor method hints · XuehaiPan/pytorch@15bc35b · GitHub
[go: up one dir, main page]

Skip to content

Commit 15bc35b

Browse files
committed
[BE][Easy] Fix PYI034: non-self-return-type in tensor method hints
ghstack-source-id: 9326463 Pull Request resolved: pytorch#129886
1 parent 7f0736a commit 15bc35b

File tree

2 files changed

+92
-52
lines changed

2 files changed

+92
-52
lines changed

tools/pyi/gen_pyi.py

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,17 @@
3939
)
4040

4141
from torchgen.api.python import (
42+
all_ops,
43+
binary_ops,
44+
comparison_ops,
4245
format_function_signature as defs,
46+
inplace_binary_ops,
4347
PythonSignatureGroup,
4448
PythonSignatureNativeFunctionPair,
4549
returns_structseq_pyi,
50+
symmetric_comparison_ops,
51+
to_py_type_ops,
52+
unary_ops,
4653
)
4754
from torchgen.gen import parse_native_yaml, parse_tags_yaml
4855
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
@@ -182,50 +189,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
182189
"copy_",
183190
]
184191

185-
binary_ops = (
186-
"add",
187-
"sub",
188-
"mul",
189-
"div",
190-
"pow",
191-
"lshift",
192-
"rshift",
193-
"mod",
194-
"truediv",
195-
"matmul",
196-
"floordiv",
197-
"radd",
198-
"rsub",
199-
"rmul",
200-
"rtruediv",
201-
"rfloordiv",
202-
"rpow", # reverse arithmetic
203-
"and",
204-
"or",
205-
"xor",
206-
"rand",
207-
"ror",
208-
"rxor", # logic
209-
"iadd",
210-
"iand",
211-
"idiv",
212-
"ilshift",
213-
"imul",
214-
"ior",
215-
"irshift",
216-
"isub",
217-
"ixor",
218-
"ifloordiv",
219-
"imod", # inplace ops
220-
)
221-
symmetric_comparison_ops = ("eq", "ne")
222-
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
223-
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
224-
225-
unary_ops = ("neg", "abs", "invert")
226-
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
227-
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
228-
229192

230193
def sig_for_ops(opname: str) -> list[str]:
231194
"""sig_for_ops(opname : str) -> list[str]
@@ -237,17 +200,25 @@ def sig_for_ops(opname: str) -> list[str]:
237200
assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
238201

239202
name = opname[2:-2]
240-
if name in binary_ops:
203+
if name in symmetric_comparison_ops:
204+
# e.g.: `__eq__`, `__ne__`
205+
# unsafe override https://github.com/python/mypy/issues/5704
206+
# PYI032 any-eq-ne-annotation https://docs.astral.sh/ruff/rules/any-eq-ne-annotation
207+
return [
208+
f"def {opname}(self, other: Any) -> Tensor: ... # type: ignore[override] # noqa: PYI032"
209+
]
210+
if name in inplace_binary_ops:
211+
# e.g.: `__iadd__`, `__imul__`
212+
# Use `Self` as return type instead of `Tensor` to allow for subclasses
213+
return [f"def {opname}(self, other: Any) -> Self: ..."]
214+
if name in binary_ops or name in comparison_ops:
215+
# e.g.: `__add__`, `__mul__` and `__le__`, `__gt__`
241216
return [f"def {opname}(self, other: Any) -> Tensor: ..."]
242-
if name in comparison_ops:
243-
sig = f"def {opname}(self, other: Any) -> Tensor: ..."
244-
if name in symmetric_comparison_ops:
245-
# unsafe override https://github.com/python/mypy/issues/5704
246-
sig += " # type: ignore[override]"
247-
return [sig]
248217
if name in unary_ops:
218+
# e.g.: `__pos__`, `__neg__`, `__abs__`
249219
return [f"def {opname}(self) -> Tensor: ..."]
250220
if name in to_py_type_ops:
221+
# e.g.: `__int__`, `__index__`, `__float__`, `__bool__`
251222
if name in {"bool", "float", "complex"}:
252223
tname = name
253224
elif name == "nonzero":

torchgen/api/python.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,66 @@
197197
# For examples, only pyi signatures include return types.
198198

199199

200+
inplace_binary_ops = {
201+
"iadd",
202+
"iand",
203+
"ifloordiv",
204+
"ilshift",
205+
"imatmul",
206+
"imod",
207+
"imul",
208+
"ior",
209+
"ipow",
210+
"irshift",
211+
"isub",
212+
"itruediv",
213+
"ixor",
214+
}
215+
binary_ops = inplace_binary_ops | {
216+
"add",
217+
"sub",
218+
"mul",
219+
"div",
220+
"pow",
221+
"lshift",
222+
"rshift",
223+
"mod",
224+
"truediv",
225+
"matmul",
226+
"floordiv",
227+
"radd",
228+
"rsub",
229+
"rmul",
230+
"rtruediv",
231+
"rfloordiv",
232+
"rpow", # reverse arithmetic
233+
"and",
234+
"or",
235+
"xor",
236+
"rand",
237+
"ror",
238+
"rxor", # logic
239+
"iadd",
240+
"iand",
241+
"idiv",
242+
"ilshift",
243+
"imul",
244+
"ior",
245+
"irshift",
246+
"isub",
247+
"ixor",
248+
"ifloordiv",
249+
"imod", # inplace ops
250+
}
251+
symmetric_comparison_ops = {"eq", "ne"}
252+
asymmetric_comparison_ops = {"ge", "gt", "lt", "le"}
253+
comparison_ops = symmetric_comparison_ops | asymmetric_comparison_ops
254+
255+
unary_ops = {"pos", "neg", "abs", "invert"}
256+
to_py_type_ops = {"bool", "float", "complex", "long", "index", "int", "nonzero"}
257+
all_ops = binary_ops | comparison_ops | unary_ops | to_py_type_ops
258+
259+
200260
def format_function_signature(
201261
name: str, arguments: Iterable[str] = (), return_type: str | None = None
202262
) -> str:
@@ -1070,14 +1130,23 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
10701130

10711131

10721132
def returns_str_pyi(signature: PythonSignature) -> str:
1133+
name = signature.name
10731134
field_names = structseq_fieldnames(signature.returns.returns)
10741135
if field_names:
1075-
return f"torch.return_types.{signature.name}"
1136+
return f"torch.return_types.{name}"
10761137

10771138
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
10781139
if len(python_returns) > 1:
10791140
return "tuple[" + ", ".join(python_returns) + "]"
10801141
if len(python_returns) == 1:
1142+
if (
1143+
name.startswith("__")
1144+
and name.endswith("__")
1145+
and name[2:-2] in inplace_binary_ops # e.g.: `__iadd__`, `__imul__`
1146+
):
1147+
# Got in-place dunder magic method
1148+
# use `Self` as return type to allow for subclasses
1149+
return "Self"
10811150
return python_returns[0]
10821151
return "None"
10831152

0 commit comments

Comments
 (0)
0