8000 Fix torch.nonzero type annotation (#51635) · pytorch/pytorch@649e683 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 649e683

Browse files
rgommersfacebook-github-bot
authored andcommitted
Fix torch.nonzero type annotation (#51635)
Summary: The overloads are a little tricky here. It's important that the overloads are such that it's unambiguous what `torch.nonzero(x)` will resolve to - so just specify defaults for one of the overloads. Also, `out` is left out of the second overload because a non-None value for `out` is not valid in combination with `as_tuple=True`. Closes gh-51434 Pull Request resolved: #51635 Reviewed By: zhangguanheng66 Differential Revision: D26279203 Pulled By: walterddr fbshipit-source-id: 8459c04fc9fbf7fc5f31b3f631aaac2f98b17ea6
1 parent 0dd1d60 commit 649e683

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

tools/pyi/gen_pyi.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
322322
' layout: _layout=strided, {}) -> Tensor: ...'
323323
.format(FACTORY_PARAMS)],
324324
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
325-
'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
326-
'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
325+
'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
326+
'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
327327
'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
328328
'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
329329
'reduce: Optional[bool] = None, reduction: str = ..., '
@@ -424,7 +424,8 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
424424
'element_size': ['def element_size(self) -> _int: ...'],
425425
'data_ptr': ['def data_ptr(self) -> _int: ...'],
426426
'dim': ['def dim(self) -> _int: ...'],
427-
'nonzero': ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
427+
'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
428+
'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
428429
'numel': ['def numel(self) -> _int: ...'],
429430
'ndimension': ['def ndimension(self) -> _int: ...'],
430431
'nelement': ['def nelement(self) -> _int: ...'],

torch/_C/_VariableFunctions.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
44
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
5+
from typing_extensions import Literal
56
from torch._six import inf
67

78
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import (
88
Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
99
NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
1010
Generic, Set, AnyStr)
11+
from typing_extensions import Literal
1112
from torch._six import inf
1213

1314
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage

0 commit comments

Comments
 (0)
0