39
39
)
40
40
41
41
from torchgen .api .python import (
42
+ all_ops ,
43
+ binary_ops ,
44
+ comparison_ops ,
42
45
format_function_signature as defs ,
46
+ inplace_binary_ops ,
43
47
PythonSignatureGroup ,
44
48
PythonSignatureNativeFunctionPair ,
45
49
returns_structseq_pyi ,
50
+ symmetric_comparison_ops ,
51
+ to_py_type_ops ,
52
+ unary_ops ,
46
53
)
47
54
from torchgen .gen import parse_native_yaml , parse_tags_yaml
48
55
from torchgen .model import _TorchDispatchModeKey , DispatchKey , Variant
@@ -182,50 +189,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
182
189
"copy_" ,
183
190
]
184
191
185
- binary_ops = (
186
- "add" ,
187
- "sub" ,
188
- "mul" ,
189
- "div" ,
190
- "pow" ,
191
- "lshift" ,
192
- "rshift" ,
193
- "mod" ,
194
- "truediv" ,
195
- "matmul" ,
196
- "floordiv" ,
197
- "radd" ,
198
- "rsub" ,
199
- "rmul" ,
200
- "rtruediv" ,
201
- "rfloordiv" ,
202
- "rpow" , # reverse arithmetic
203
- "and" ,
204
- "or" ,
205
- "xor" ,
206
- "rand" ,
207
- "ror" ,
208
- "rxor" , # logic
209
- "iadd" ,
210
- "iand" ,
211
- "idiv" ,
212
- "ilshift" ,
213
- "imul" ,
214
- "ior" ,
215
- "irshift" ,
216
- "isub" ,
217
- "ixor" ,
218
- "ifloordiv" ,
219
- "imod" , # inplace ops
220
- )
221
- symmetric_comparison_ops = ("eq" , "ne" )
222
- asymmetric_comparison_ops = ("ge" , "gt" , "lt" , "le" )
223
- comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
224
-
225
- unary_ops = ("neg" , "abs" , "invert" )
226
- to_py_type_ops = ("bool" , "float" , "complex" , "long" , "index" , "int" , "nonzero" )
227
- all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
228
-
229
192
230
193
def sig_for_ops (opname : str ) -> list [str ]:
231
194
"""sig_for_ops(opname : str) -> list[str]
@@ -237,17 +200,25 @@ def sig_for_ops(opname: str) -> list[str]:
237
200
assert opname .endswith ("__" ) and opname .startswith ("__" ), f"Unexpected op { opname } "
238
201
239
202
name = opname [2 :- 2 ]
240
- if name in binary_ops :
203
+ if name in symmetric_comparison_ops :
204
+ # e.g.: `__eq__`, `__ne__`
205
+ # unsafe override https://github.com/python/mypy/issues/5704
206
+ # PYI032 any-eq-ne-annotation https://docs.astral.sh/ruff/rules/any-eq-ne-annotation
207
+ return [
208
+ f"def { opname } (self, other: Any) -> Tensor: ... # type: ignore[override] # noqa: PYI032"
209
+ ]
210
+ if name in inplace_binary_ops :
211
+ # e.g.: `__iadd__`, `__imul__`
212
+ # Use `Self` as return type instead of `Tensor` to allow for subclasses
213
+ return [f"def { opname } (self, other: Any) -> Self: ..." ]
214
+ if name in binary_ops or name in comparison_ops :
215
+ # e.g.: `__add__`, `__mul__` and `__le__`, `__gt__`
241
216
return [f"def { opname } (self, other: Any) -> Tensor: ..." ]
242
- if name in comparison_ops :
243
- sig = f"def { opname } (self, other: Any) -> Tensor: ..."
244
- if name in symmetric_comparison_ops :
245
- # unsafe override https://github.com/python/mypy/issues/5704
246
- sig += " # type: ignore[override]"
247
- return [sig ]
248
217
if name in unary_ops :
218
+ # e.g.: `__pos__`, `__neg__`, `__abs__`
249
219
return [f"def { opname } (self) -> Tensor: ..." ]
250
220
if name in to_py_type_ops :
221
+ # e.g.: `__int__`, `__index__`, `__float__`, `__bool__`
251
222
if name in {"bool" , "float" , "complex" }:
252
223
tname = name
253
224
elif name == "nonzero" :
0 commit comments