Add typing on Tensor dunder methods for binary arithmetic ops #103394
```diff
@@ -5,7 +5,9 @@
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import ParamSpec  # Python 3.10+
 
 import torch
 import torch._C as _C
@@ -27,12 +29,18 @@
 )
 from torch.utils.dlpack import DLDeviceType
 
+T = TypeVar("T")
+P = ParamSpec("P")
+Self = TypeVar("Self", bound="Tensor")  # replacement of PEP-673 Self
+
 
-def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
+def _handle_torch_function_and_wrap_type_error_to_not_implemented(
+    f: Callable[P, T]
+) -> Callable[P, T]:
     assigned = functools.WRAPPER_ASSIGNMENTS
 
     @functools.wraps(f, assigned=assigned)
-    def wrapped(*args, **kwargs):
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
         try:
             # See https://github.com/pytorch/pytorch/issues/75462
             if has_torch_function(args):
```
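For context, the `ParamSpec`-based signature above is what lets type checkers see through the decorator. Here is a minimal, self-contained sketch of the same pattern; the decorator and function names below are made up for illustration and are not PyTorch code:

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # lives in `typing` on Python 3.10+

T = TypeVar("T")
P = ParamSpec("P")


def to_not_implemented_on_type_error(f: Callable[P, T]) -> Callable[P, T]:
    """Toy decorator typed like the PR: P captures *args/**kwargs, T the return."""

    @functools.wraps(f)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return f(*args, **kwargs)
        except TypeError:
            return NotImplemented  # type: ignore[return-value]

    return wrapped


@to_not_implemented_on_type_error
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


# mypy now sees `scale` as (x: float, factor: float = ...) -> float,
# so a call like scale("a", "b") is flagged; with an untyped wrapper the
# decorated function would degrade to (*args: Any, **kwargs: Any) -> Any.
```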
```diff
@@ -848,25 +856,42 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None
     )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rsub__(self, other):
+    def __rsub__(self, other: Any) -> "Tensor":
         return _C._VariableFunctions.rsub(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rdiv__(self, other):
+    def __rdiv__(self, other: Any) -> "Tensor":
         return self.reciprocal() * other
 
     __rtruediv__ = __rdiv__
     __itruediv__ = _C._TensorBase.__idiv__
 
-    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow
-    )
+    # TODO Currently there's no standard solution to combine ParamSpec and overload.
+    # mypy will only pass the first overload to ParamSpec, making `__pow__` and
+    # `__ipow__` only accept `exponent: Tensor`. Therefore, we explicitly annotate
+    # them and force the linter to accept so that end-users can see it correct.
+    # However, we directly use Any here, otherwise there'll be too many to suppress.
+
+    def __pow__(self, exponent: Any) -> "Tensor":
+        ...
+
+    __pow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
+            _C._TensorBase.pow
+        )
+    )
-    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow_
-    )
+
+    def __ipow__(self: Self, exponent: Any) -> Self:
+        ...
+
+    __ipow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(  # pass UFMT
```
Review thread on the `# pass UFMT` comment above:
- Why …
- Suggested change: …
- Otherwise, UFMT requires …
- Another option is to change …
- Sorry, seems I didn't positively answer this question. They are different in that line 889 has a … However, I tried a little more and found out that we can remove the …
```diff
+            _C._TensorBase.pow_  # type: ignore[arg-type]
+        )
+    )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmod__(self, other):
+    def __rmod__(self, other: Any) -> "Tensor":
         return torch.remainder(other, self)
 
     def __format__(self, format_spec):
```
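The `__pow__`/`__ipow__` pattern above (an annotated stub `def` that is immediately overwritten by the wrapped C implementation) can be hard to read out of context. A rough stand-alone sketch of the same trick, with made-up names rather than the real Tensor internals:

```python
from typing import Any


def _wrap(f):
    # Stand-in for _handle_torch_function_and_wrap_type_error_to_not_implemented.
    def wrapped(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except TypeError:
            return NotImplemented

    return wrapped


def _raw_pow(self: "Box", exponent: Any) -> "Box":
    return Box(self.value**exponent)


class Box:
    def __init__(self, value: float) -> None:
        self.value = value

    # The stub carries the annotation that mypy and IDEs report to users ...
    def __pow__(self, exponent: Any) -> "Box":
        ...

    # ... and is then overwritten by the real (wrapped) implementation.
    # The F811 "redefinition" warning has to be silenced, exactly as in the PR.
    __pow__ = _wrap(_raw_pow)  # noqa: F811


print((Box(2.0) ** 3).value)  # 8.0 -- dispatches through the wrapped _raw_pow
```

The stub never runs; its only job is to give type checkers a signature they can attach to the name before the assignment replaces it.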
```diff
@@ -877,28 +902,28 @@ def __format__(self, format_spec):
         return object.__format__(self, format_spec)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rpow__(self, other):
+    def __rpow__(self, other: Any) -> "Tensor":
         dtype = torch.result_type(other, self)
         return torch.tensor(other, dtype=dtype, device=self.device) ** self
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __floordiv__(self, other):
+    def __floordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rfloordiv__(self, other):
+    def __rfloordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rlshift__(self, other):
+    def __rlshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_left_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rrshift__(self, other):
+    def __rrshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_right_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmatmul__(self, other):
+    def __rmatmul__(self, other: Any) -> "Tensor":
        return torch.matmul(other, self)
 
     __pos__ = _C._TensorBase.positive
```
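As a quick runtime sanity check of what these reflected methods do (assuming a working `torch` install; values in the comments are from a typical build): Python tries the left operand's operator first and falls back to the Tensor's `__r*__` method when that returns NotImplemented.

```python
import torch

t = torch.arange(3)  # tensor([0, 1, 2]), dtype int64

print(2 - t)       # int.__sub__ returns NotImplemented, so Tensor.__rsub__ runs:
                   # tensor([2, 1, 0])
print(2 ** t)      # Tensor.__rpow__: tensor([1, 2, 4])
print(7 // t[1:])  # Tensor.__rfloordiv__: tensor([7, 3])
```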
I think this is out of the scope of this PR, as I'm not sure why we use `Tensor` (instead of `int`) for `grid_size` here.

The error is that there is no overload for `min(Tensor, int)`. This is because `Tensor.__lt__()` returns `Tensor`, which is incompatible with typeshed's annotation on `min`, which requires:

https://github.com/python/typeshed/blob/052d2b9f3a9afaa56b86f0c84bc9d8769458d96a/stdlib/_typeshed/__init__.pyi#L53-L54

This requirement was introduced in python/typeshed#7093, where they require that `__lt__` return `bool` and not any other bool-compatible type.
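To make the typeshed constraint concrete, here is a made-up minimal class (not Tensor) whose comparisons return a non-bool result; it works at runtime but fails `min()`'s overloads under mypy for the same reason described above:

```python
from typing import Any


class Elementwise:
    """Comparisons return another Elementwise (like Tensor), not a plain bool."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __lt__(self, other: Any) -> "Elementwise":
        return Elementwise(int(self.value < other))

    def __gt__(self, other: Any) -> "Elementwise":
        return Elementwise(int(self.value > other))

    def __bool__(self) -> bool:
        return bool(self.value)


x = Elementwise(5)
print(min(x, 3))  # runs fine at runtime (3 wins via __gt__ plus __bool__), but
                  # mypy rejects the call: typeshed's SupportsDunderLT/GT
                  # protocols require __lt__/__gt__ to return bool, which
                  # Elementwise (like Tensor) does not satisfy.
```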