diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index 28568811fcbd..02b2ed3022ba 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -70,7 +70,7 @@ def philox_rand_offset(
     device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
     blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
     grid_size = (numel + block_size - 1) // block_size
-    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
+    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)  # type: ignore[call-overload]
     offset = (
         (numel - 1) // (block_size * grid_size * unroll) + 1
     ) * curand4_engine_calls
diff --git a/torch/_tensor.py b/torch/_tensor.py
index b4e2e84bc3dd..c8faa3a4ce85 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -5,7 +5,9 @@
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import ParamSpec  # in `typing` since Python 3.10
 
 import torch
 import torch._C as _C
@@ -27,12 +29,18 @@
 )
 from torch.utils.dlpack import DLDeviceType
 
+T = TypeVar("T")
+P = ParamSpec("P")
+Self = TypeVar("Self", bound="Tensor")  # stand-in for PEP-673 Self
+
 
-def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
+def _handle_torch_function_and_wrap_type_error_to_not_implemented(
+    f: Callable[P, T]
+) -> Callable[P, T]:
     assigned = functools.WRAPPER_ASSIGNMENTS
 
     @functools.wraps(f, assigned=assigned)
-    def wrapped(*args, **kwargs):
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
         try:
             # See https://github.com/pytorch/pytorch/issues/75462
             if has_torch_function(args):
@@ -848,25 +856,42 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None
         )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rsub__(self, other):
+    def __rsub__(self, other: Any) -> "Tensor":
         return _C._VariableFunctions.rsub(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rdiv__(self, other):
+    def __rdiv__(self, other: Any) -> "Tensor":
         return self.reciprocal() * other
 
     __rtruediv__ = __rdiv__
     __itruediv__ = _C._TensorBase.__idiv__
 
-    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow
+    # TODO: There is currently no standard way to combine ParamSpec with overloads;
+    # mypy threads only the first overload through ParamSpec, which would make
+    # `__pow__` and `__ipow__` accept only `exponent: Tensor`. We therefore
+    # annotate them explicitly (and suppress the linter) so that end users see
+    # the correct signatures, using `Any` to avoid a flood of suppressions.
+
+    def __pow__(self, exponent: Any) -> "Tensor":
+        ...
+
+    __pow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(
+            _C._TensorBase.pow
+        )
     )
-    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
-        _C._TensorBase.pow_
+
+    def __ipow__(self: Self, exponent: Any) -> Self:
+        ...
+
+    __ipow__ = (  # noqa: F811
+        _handle_torch_function_and_wrap_type_error_to_not_implemented(  # pass UFMT
+            _C._TensorBase.pow_  # type: ignore[arg-type]
+        )
     )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmod__(self, other):
+    def __rmod__(self, other: Any) -> "Tensor":
         return torch.remainder(other, self)
 
     def __format__(self, format_spec):
@@ -877,28 +902,28 @@ def __format__(self, format_spec):
         return object.__format__(self, format_spec)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rpow__(self, other):
+    def __rpow__(self, other: Any) -> "Tensor":
         dtype = torch.result_type(other, self)
         return torch.tensor(other, dtype=dtype, device=self.device) ** self
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __floordiv__(self, other):
+    def __floordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rfloordiv__(self, other):
+    def __rfloordiv__(self, other: Any) -> "Tensor":
         return torch.floor_divide(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rlshift__(self, other):
+    def __rlshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_left_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rrshift__(self, other):
+    def __rrshift__(self, other: Any) -> "Tensor":
         return torch.bitwise_right_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmatmul__(self, other):
+    def __rmatmul__(self, other: Any) -> "Tensor":
         return torch.matmul(other, self)
 
     __pos__ = _C._TensorBase.positive
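
For readers unfamiliar with the `ParamSpec` pattern this patch applies to `_handle_torch_function_and_wrap_type_error_to_not_implemented`, here is a minimal standalone sketch (not part of the patch; `log_calls` and `scale` are hypothetical names used only for illustration). Typing the decorator as `Callable[P, T] -> Callable[P, T]` lets mypy carry the wrapped function's exact signature through the wrapper instead of collapsing it to `(*args: Any, **kwargs: Any)`:

```python
# Minimal sketch of the ParamSpec-typed decorator pattern; names are hypothetical.
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # in `typing` since Python 3.10

T = TypeVar("T")
P = ParamSpec("P")


def log_calls(f: Callable[P, T]) -> Callable[P, T]:
    """Decorator that preserves the exact signature of ``f`` for type checkers."""

    @functools.wraps(f)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
        print(f"calling {f.__name__}")
        return f(*args, **kwargs)

    return wrapped


@log_calls
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


# mypy still sees `scale` as (x: float, factor: float = 2.0) -> float, so a
# bad call such as scale("a") is rejected statically, while this type-checks:
print(scale(3.0))
```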
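
The stub-then-reassign trick used above for `__pow__` and `__ipow__` can also be shown in isolation. In the sketch below (assuming a hypothetical `Vec` class, again not part of the patch), the annotated stub is what type checkers and IDEs report, while the assignment that follows replaces it at class-creation time with the real implementation, which is why the patch needs the `# noqa: F811` redefinition suppression:

```python
# Hedged sketch of the stub-then-reassign pattern; `Vec` is hypothetical.
from typing import Any


def _impl_add(self: "Vec", other: Any) -> "Vec":
    # Real implementation bound as a method via the class-body assignment below.
    return Vec(self.x + (other.x if isinstance(other, Vec) else other))


class Vec:
    def __init__(self, x: float) -> None:
        self.x = x

    # Stub seen by type checkers / IDEs:
    def __add__(self, other: Any) -> "Vec":
        ...

    # Runtime binding overwrites the stub, triggering flake8's F811:
    __add__ = _impl_add  # noqa: F811


print((Vec(1.0) + Vec(2.0)).x)  # 3.0
```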