Add typing on `Tensor` dunder methods for binary arithmetic ops by lkct · Pull Request #103394 · pytorch/pytorch

Status: Closed · wants to merge 2 commits
2 changes: 1 addition & 1 deletion torch/_prims/rng_prims.py

```diff
@@ -70,7 +70,7 @@ def philox_rand_offset(
     device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
     blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
     grid_size = (numel + block_size - 1) // block_size
-    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
+    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)  # type: ignore[call-overload]
     offset = (
         (numel - 1) // (block_size * grid_size * unroll) + 1
     ) * curand4_engine_calls
```
@lkct (Contributor, Author) commented on the changed line, Jun 11, 2023:
I think this is out of scope for this PR, as I'm not sure why we use a Tensor (instead of an int) for grid_size here.

The error is that no overload of min matches (Tensor, int). This is because Tensor.__lt__() returns Tensor, which is incompatible with typeshed's annotation of min, which requires:
https://github.com/python/typeshed/blob/052d2b9f3a9afaa56b86f0c84bc9d8769458d96a/stdlib/_typeshed/__init__.pyi#L53-L54

```python
class SupportsDunderLT(Protocol[_T_contra]):
    def __lt__(self, __other: _T_contra) -> bool: ...
```

This requirement was introduced in python/typeshed#7093, which demands that __lt__ return bool and not just any bool-compatible type.
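To make the failure mode concrete, here is a toy reproduction (my illustration, not PyTorch code — `ElementwiseBox` is a made-up stand-in for Tensor):

```python
# Like Tensor, ElementwiseBox's rich comparisons return a new object
# rather than bool, so it does not satisfy typeshed's SupportsDunderLT
# protocol and min() has no matching overload.


class ElementwiseBox:
    def __lt__(self, other: object) -> "ElementwiseBox":
        return ElementwiseBox()

    def __gt__(self, other: object) -> "ElementwiseBox":
        return ElementwiseBox()


box = ElementwiseBox()
# This runs (the comparison result object is truthy), but mypy reports:
#   No overload variant of "min" matches argument types
#   "ElementwiseBox", "int"  [call-overload]
smallest = min(box, 3)  # type: ignore[call-overload]
```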

57 changes: 41 additions & 16 deletions torch/_tensor.py

```diff
@@ -5,7 +5,9 @@
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import ParamSpec  # Python 3.10+

 import torch
 import torch._C as _C
@@ -27,12 +29,18 @@
 )
 from torch.utils.dlpack import DLDeviceType

+T = TypeVar("T")
+P = ParamSpec("P")
+Self = TypeVar("Self", bound="Tensor")  # replacement of PEP-673 Self
+

-def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
+def _handle_torch_function_and_wrap_type_error_to_not_implemented(
+    f: Callable[P, T]
+) -> Callable[P, T]:
     assigned = functools.WRAPPER_ASSIGNMENTS

     @functools.wraps(f, assigned=assigned)
-    def wrapped(*args, **kwargs):
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
         try:
             # See https://github.com/pytorch/pytorch/issues/75462
             if has_torch_function(args):
```
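For background (this sketch is mine, not part of the diff): `ParamSpec` is what lets the decorator return a callable with the same signature as its input, so every dunder it wraps keeps precise parameter and return types. A minimal standalone version, using a hypothetical shorter name `wrap_type_error_to_not_implemented`:

```python
# Illustrative sketch only -- a simplified version of the decorator above.
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


def wrap_type_error_to_not_implemented(f: Callable[P, T]) -> Callable[P, T]:
    # P.args / P.kwargs tie *args / **kwargs to the exact parameters of f,
    # so the wrapper is typed as accepting precisely what f accepts.
    @functools.wraps(f)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return f(*args, **kwargs)
        except TypeError:
            # Binary dunders signal "try the reflected operation" this way.
            # NotImplemented is not a T, hence the suppression.
            return NotImplemented  # type: ignore[return-value]

    return wrapped


@wrap_type_error_to_not_implemented
def scaled_sum(x: int, y: int) -> int:
    return x + y


# mypy still sees scaled_sum as (x: int, y: int) -> int, so a call like
# scaled_sum("a", 1) is flagged -- the signature is not erased to (*Any).
```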
```diff
@@ -848,25 +856,42 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None
         )

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rsub__(self, other):
+    def __rsub__(self, other: Any) -> "Tensor":
         return _C._VariableFunctions.rsub(self, other)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rdiv__(self, other):
+    def __rdiv__(self, other: Any) -> "Tensor":
         return self.reciprocal() * other

     __rtruediv__ = __rdiv__
     __itruediv__ = _C._TensorBase.__idiv__

-    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow
-    )
-    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow_
-    )
+    # TODO Currently there's no standard solution to combine ParamSpec and overload.
+    # mypy will only pass the first overload to ParamSpec, making `__pow__` and
+    # `__ipow__` only accept `exponent: Tensor`. Therefore, we explicitly annotate
+    # them and force the linter to accept so that end-users can see it correct.
+    # However, we directly use Any here, otherwise there'll be too many to suppress.
+
+    def __pow__(self, exponent: Any) -> "Tensor":
+        ...
+
+    __pow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
+            _C._TensorBase.pow
+        )
+    )
+
+    def __ipow__(self: Self, exponent: Any) -> Self:
+        ...
+
+    __ipow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(  # pass UFMT
```
A reviewer (Contributor) commented on the line above:

Why is `# pass UFMT` needed here but not on line 879?

Suggested change:

```diff
-        _handle_torch_function_and_wrap_type_error_to_not_implemented(  # pass UFMT
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
```

@lkct (Contributor, Author) replied:

Otherwise, UFMT requires the `# noqa: F811` to be moved onto line 888, while FLAKE8 and RUFF need that noqa on line 887. So the comment is needed to pass UFMT.

@lkct (Contributor, Author) added:

Another option is to rename `_handle_torch_function_and_wrap_type_error_to_not_implemented` to something shorter. Any suggestions?

@lkct (Contributor, Author) added:

> Why is `# pass UFMT` needed here but not on line 879?

Sorry, it seems I didn't directly answer this question. The two sites differ in that line 889 carries a `# type: ignore[arg-type]` but line 880 does not; that is why the comment is added on line 888 but not on 879.

However, I experimented a little more and found that the `type: ignore` can be removed once #103376 lands.
```diff
+            _C._TensorBase.pow_  # type: ignore[arg-type]
+        )
+    )
```
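The TODO in the diff above describes the workaround pattern: write a `def` stub carrying the annotation users should see, then immediately rebind the name to the wrapped C implementation; the rebinding is exactly what flake8/ruff report as F811. A self-contained sketch of the same pattern (`Scaler` and `_wrap` are invented stand-ins, not PyTorch code):

```python
from typing import Any, Callable, TypeVar

Self = TypeVar("Self", bound="Scaler")  # same trick as the PR's PEP-673 stand-in


def _wrap(f: Callable[..., Any]) -> Callable[..., Any]:
    # Placeholder for the PR's decorator; returns the function unchanged,
    # losing the precise signature -- which is why the stub is needed.
    return f


class Scaler:
    def __init__(self, value: float) -> None:
        self.value = value

    def _ipow_impl(self: Self, exponent: Any) -> Self:
        self.value **= exponent  # stand-in for _C._TensorBase.pow_
        return self

    # Stub: the annotation type checkers and users will see.
    def __ipow__(self: Self, exponent: Any) -> Self:
        ...

    # Rebinding: the behavior actually used at runtime. flake8/ruff flag
    # this as F811; depending on mypy settings, a no-redef suppression may
    # also be needed.
    __ipow__ = _wrap(_ipow_impl)  # noqa: F811


s = Scaler(2.0)
s **= 3  # dispatches to __ipow__; s.value is now 8.0
```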

```diff
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmod__(self, other):
+    def __rmod__(self, other: Any) -> "Tensor":
         return torch.remainder(other, self)

     def __format__(self, format_spec):
@@ -877,28 +902,28 @@ def __format__(self, format_spec):
         return object.__format__(self, format_spec)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rpow__(self, other):
+    def __rpow__(self, other: Any) -> "Tensor":
         dtype = torch.result_type(other, self)
         return torch.tensor(other, dtype=dtype, device=self.device) ** self

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __floordiv__(self, other):
+    def __floordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(self, other)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rfloordiv__(self, other):
+    def __rfloordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rlshift__(self, other):
+    def __rlshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_left_shift(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rrshift__(self, other):
+    def __rrshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_right_shift(other, self)

     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmatmul__(self, other):
+    def __rmatmul__(self, other: Any) -> "Tensor":
         return torch.matmul(other, self)

     __pos__ = _C._TensorBase.positive
```