Add type annotations to torch.overrides (#50824) · pytorch/pytorch@9dfbfe9 · GitHub
Commit 9dfbfe9
Add type annotations to torch.overrides (#50824)
Summary: This is a follow-up PR to #48493. Fixes #48492.

Pull Request resolved: #50824
Reviewed By: bdhirsh
Differential Revision: D26050736
Pulled By: ezyang
fbshipit-source-id: 049605fd271cff28c8b6e300c163e9df3b3ea23b
1 parent 75cba9d commit 9dfbfe9
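
For background: torch.overrides implements the __torch_function__ protocol whose helpers this commit annotates. A minimal sketch of a tensor-like class that hooks into that protocol (illustrative only, not part of this commit):

import torch

class LoggingTensor:
    """Wraps a tensor and logs every intercepted torch function call."""
    def __init__(self, data):
        self.data = torch.as_tensor(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted {func.__name__}")
        # Unwrap arguments, run the real kernel, and rewrap the result.
        unwrapped = [a.data if isinstance(a, LoggingTensor) else a for a in args]
        return cls(func(*unwrapped, **kwargs))

t = LoggingTensor([1.0, 2.0])
s = torch.sin(t)  # prints "intercepted sin"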

File tree

5 files changed (+57, -13 lines)

tools/pyi/gen_pyi.py — 33 additions, 1 deletion

@@ -125,7 +125,7 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
     'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow',  # reverse arithmetic
     'and', 'or', 'xor',  # logic
     'iadd', 'iand', 'idiv', 'ilshift', 'imul',
-    'ior', 'irshift', 'isub', 'ixor',  # inplace ops
+    'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod',  # inplace ops
 )
 comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
 unary_ops = ('neg', 'abs', 'invert')
@@ -324,6 +324,32 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
         'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
         'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
                     'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
+        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
+                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
+                                             'reduce: Optional[bool] = None, reduction: str = ..., '
+                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
+        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
+                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
+                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
+                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
+        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
+                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
+                                 'reduction: str = ...) -> Tensor: ...'],
+        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
+                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
+        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
+                                ' margin: float = ..., size_average: Optional[bool] = ..., '
+                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
+                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
+                                'size_average: Optional[bool] = ..., '
+                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
+        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
+        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
+        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
+                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
+        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
     })
     for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
         unsorted_function_hints[binop].append(
@@ -382,10 +408,12 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
             'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
         ],
         'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
+        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
         # clamp has no default values in the Declarations
         'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                   " *, out: Optional[Tensor]=None) -> Tensor: ..."],
         'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
+        '__get__': ["def __get__(self, instance, owner=None) -> Tensor: ..."],
         '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
         '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                         " -> None: ...".format(INDICES)],
@@ -402,13 +430,17 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
         'numpy': ['def numpy(self) -> Any: ...'],
         'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
         'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
+        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
         'storage': ['def storage(self) -> Storage: ...'],
+        'storage_type': ['def storage_type(self) -> Storage: ...'],
         'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
                  'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
         ],
         'get_device': ['def get_device(self) -> _int: ...'],
         'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
+        'has_names': ['def has_names(self) -> _bool: ...'],
         'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
+        '_is_view': ['def _is_view(self) -> _bool: ...'],
         'is_cuda': ['is_cuda: _bool'],
         'is_leaf': ['is_leaf: _bool'],
         'is_sparse': ['is_sparse: _bool'],
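
Each hint registered above is emitted into the generated stub file, so the kl_div entry, for example, ends up in the .pyi output roughly as follows (a sketch of the generated declaration, not its verbatim layout):

def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ...,
           reduce: Optional[bool] = ..., reduction: str = ...,
           log_target: bool = ...) -> Tensor: ...

Static checkers then validate call sites such as torch.kl_div(a, b, log_target=True) against this signature instead of treating the function as untyped.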

torch/_C/__init__.pyi.in — 10 additions, 0 deletions

@@ -27,6 +27,8 @@ class device:
     type: str  # THPDevice_type
     index: _int  # THPDevice_index
 
+    def __get__(self, instance, owner=None) -> device: ...
+
     # THPDevice_pynew
     @overload
     def __init__(self, device: Union[_device, _int, str]) -> None: ...
@@ -249,6 +251,9 @@
 def _jit_is_script_object(obj: Any) -> _bool: ...
 def _last_executed_optimized_graph() -> Graph: ...
 def parse_type_comment(comment: str) -> Decl: ...
 def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ...
+def parse_ir(input: str) -> Graph: ...
+def parse_schema(schema: str) -> FunctionSchema: ...
+def get_device(input: Tensor) -> _int: ...
 def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallback) -> JitType: ...
 def _create_module_with_type(ty: JitType) -> ScriptModule: ...
 def _run_emit_module_hook(m: ScriptModule): ...
@@ -506,13 +511,15 @@
 def _get_qengine() -> _int: ...  # THPModule_qEngine
 def _set_qengine(qegine: _int) -> None: ...  # THPModule_setQEngine
 def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
 def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
+def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
 def _has_torch_function(args: Iterable[Any]) -> _bool: ...  # THPModule_has_torch_function
 def _has_torch_function_unary(Any) -> _bool: ...  # THPModule_has_torch_function_unary
 def _has_torch_function_variadic(*args: Any) -> _bool: ...  # THPModule_has_torch_function_variadic
 def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
 def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
 def _log_api_usage_once(str) -> None: ...  # LogAPIUsageOnceFromPython
 def _demangle(str) -> str: ...  # c10::demangle
+def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ...  # THPModule_disable_torch_function
 
 # Defined in `valgrind.h` and `callgrind.h` respecitively.
 def _valgrind_supported_platform() -> _bool: ...  # NVALGRIND
@@ -644,9 +651,12 @@ class _TensorBase(object):
     imag: Tensor
     T: Tensor
     ndim: _int
+    output_nr: _int
     _version: _int
     _base: Optional[Tensor]
+    _cdata: _int
     grad_fn: Any
+    _grad_fn: Any
     _grad: Optional[Tensor]
     _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
     ${tensor_method_hints}
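
The _TensorBase additions give static names to attributes that previously existed only at runtime. A small sketch of what a checker can now resolve without falling back to Any:

import torch

t = torch.randn(2, 2, requires_grad=True)
v = t.relu()

print(t._version)   # _int: bumped whenever t is mutated in place
print(t._base)      # Optional[Tensor]: None here, since t is not a view
print(v.output_nr)  # _int: which output of its grad_fn this tensor is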

torch/nn/modules/conv.py — 2 additions, 1 deletion

@@ -1020,7 +1020,7 @@ def reset_parameters(self) -> None:
         # has_uninitialized_params is defined in parent class and it is using a protocol on self
         if not self.has_uninitialized_params() and self.in_channels != 0:  # type: ignore[misc]
             # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined
-            # super class. Turns out that it is defined in _ConvND which is inherited by any class
+            # in super class. Turns out that it is defined in _ConvND which is inherited by any class
             # that also inherits _LazyConvXdMixin
             super().reset_parameters()  # type: ignore[misc]
@@ -1031,6 +1031,7 @@ def initialize_parameters(self, input) -> None:  # type: ignore[override]
         self.in_channels = input.shape[1]
         if self.in_channels % self.groups != 0:
             raise ValueError('in_channels must be divisible by groups')
+        assert isinstance(self.weight, UninitializedParameter)
         if self.transposed:
             self.weight.materialize((
                 self.in_channels, self.out_channels // self.groups, *self.kernel_size))
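
The new assert is as much for mypy as for runtime: it narrows self.weight to UninitializedParameter, the class that defines materialize. The same narrowing pattern in isolation (class names here are illustrative, not PyTorch's):

from typing import Union

class Param:
    """Stand-in for an ordinary parameter type with no materialize()."""

class UninitializedParam(Param):
    def materialize(self, shape: tuple) -> None:
        print(f"allocating storage for shape {shape}")

def initialize(weight: Union[Param, UninitializedParam]) -> None:
    # Without the assert, mypy rejects .materialize() because Param lacks
    # that method; isinstance() narrows the union for the call below.
    assert isinstance(weight, UninitializedParam)
    weight.materialize((4, 3))

initialize(UninitializedParam())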

torch/nn/modules/linear.py — 1 addition, 1 deletion

@@ -223,6 +223,6 @@ def initialize_parameters(self, input) -> None:  # type: ignore
         if self.has_uninitialized_params():
             with torch.no_grad():
                 self.in_features = input.shape[-1]
-                self.weight.materialize((self.out_features, self.in_features))
+                self.weight.materialize((self.out_features, self.in_features))  # type: ignore
                 self.reset_parameters()
 # TODO: PartialLinear - maybe in sparse?
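
For context, initialize_parameters is the hook that lets lazy modules infer their input size on first use; a brief usage sketch of the public module built on it:

import torch
import torch.nn as nn

layer = nn.LazyLinear(out_features=8)  # in_features unknown until first forward
x = torch.randn(4, 20)
y = layer(x)                           # materializes weight with shape (8, 20)
print(layer.in_features, y.shape)      # 20 torch.Size([4, 8])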

torch/overrides.py — 11 additions, 10 deletions

@@ -25,7 +25,7 @@
 import collections
 import functools
 import types
-from typing import Dict, Set, List, Any, Callable, Iterable
+from typing import Dict, Set, List, Any, Callable, Iterable, Type
 
 import torch
 from torch._C import (
@@ -50,7 +50,7 @@ def get_ignored_functions() -> Set[Callable]:
 
     Returns
     -------
-    Tuple[Callable]
+    Set[Callable]
         A tuple of functions that are publicly available in the torch API but cannot
         be overridden with ``__torch_function__``. Mostly this is because none of the
         arguments of these functions are tensors or tensor-likes.
@@ -102,7 +102,6 @@ def get_ignored_functions() -> Set[Callable]:
         torch.has_cuda,
         torch.has_cudnn,
         torch.has_lapack,
-        torch.cpp,
         torch.device,
         torch.dtype,
         torch.finfo,
@@ -163,8 +162,8 @@
         torch.triu_indices,
         torch.vander,
         torch.zeros,
+        torch._jit_internal.boolean_dispatch,
         torch.nn.functional.assert_int_or_pair,
-        torch.nn.functional.boolean_dispatch,
         torch.nn.functional.upsample,
         torch.nn.functional.upsample_bilinear,
         torch.nn.functional.upsample_nearest,
@@ -175,6 +174,8 @@
         torch.nn.functional.sigmoid,
         torch.nn.functional.hardsigmoid,
         torch.nn.functional.tanh,
+        has_torch_function,
+        handle_torch_function,
         torch.set_autocast_enabled,
         torch.is_autocast_enabled,
         torch.clear_autocast_cache,
@@ -242,7 +243,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     # function signatures for native kernels that can be consumed by inspect.
     # See Issue #28233.
     Tensor = torch.Tensor
-    ret = {
+    ret: Dict[Callable, Callable] = {
         torch.abs: lambda input, out=None: -1,
         torch.absolute: lambda input, out=None: -1,
         torch.adaptive_avg_pool1d: lambda input, output_size: -1,
@@ -356,7 +357,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.deg2rad: lambda input, out=None: -1,
         torch.dequantize: lambda input: -1,
         torch.det: lambda input: -1,
-        torch.linalg.det: lambda input: -1,  # alias for torch.det
+        torch.linalg.det: lambda input: -1,  # alias for torch.det  # type: ignore[attr-defined]
         torch.detach: lambda input: -1,
         torch.diag: lambda input, diagonal=0, out=None: -1,
         torch.diag_embed: lambda input, diagonal=0, out=None: -1,
@@ -517,7 +518,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.less: lambda input, other, out=None: -1,
         torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
         torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
-        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
+        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,  # type: ignore[attr-defined]  # noqa: B950
         torch.masked_fill: lambda input, mask, value: -1,
         torch.masked_scatter: lambda input, mask, source: -1,
         torch.masked_select: lambda input, mask, out=None: -1,
@@ -837,6 +838,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
         torch.tril: lambda input, diagonal=0, out=None: -1,
         torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
+
                                     size_average=None, reduce=None, reduction='mean': -1),
         torch.triu: lambda input, diagonal=0, out=None: -1,
         torch.true_divide: lambda input, other: -1,
@@ -1123,8 +1125,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
     https://numpy.org/neps/nep-0018-array-function-protocol.html
     """
     # Runtime is O(num_arguments * num_unique_types)
-    overloaded_types = set()
-    overloaded_args = []
+    overloaded_types: Set[Type] = set()
+    overloaded_args: List[Any] = []
     for arg in relevant_args:
         arg_type = type(arg)
         # We only collect arguments if they have a unique type, which ensures
@@ -1147,7 +1149,6 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
         else:
             overloaded_types = {arg_type}
             overloaded_args = [arg]
-
     return overloaded_args
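
The annotated module is the public override machinery itself; a short sketch of how these entry points are typically consumed, using only documented functions:

import torch
from torch.overrides import get_testing_overrides, get_overridable_functions

# Maps each overridable callable to a dummy lambda with a matching signature;
# test suites use these to exercise __torch_function__ dispatch coverage.
overrides = get_testing_overrides()
dummy_add = overrides[torch.add]
print(dummy_add(1, 2))  # -1: every dummy simply returns -1

# Overridable functions, grouped by the namespace they live in.
funcs = get_overridable_functions()
print(len(funcs[torch]))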