Update on "Move all torch.nn.modules type annotations inline" · pytorch/pytorch@019c8d0 · GitHub

Commit 019c8d0

Update on "Move all torch.nn.modules type annotations inline"
Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors, and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which had added ignore annotations that weren't actually relevant anymore). Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later.

Moving these annotations inline was really hairy because of interactions with JIT, and also because Python type erasure is a lie (inheriting from Generic *does* change the behavior of your object). Here is the list of things I had to fix and/or work around (sketches of several of these follow the list):

- The quantization translation passes previously barfed if the weight/bias arguments were inferred to be Optional. Previously, TorchScript type inference would have inferred that these arguments were non-Optional (because type inference happens after module construction), but accurate type annotations on these parameters override that inference, causing the arguments to become Optional. I fixed this by making the quantized operator signatures line up exactly with the non-quantized signatures, so we never change the types of the arguments. This change mostly involved making a bunch of quantized kernels take optional arguments and then error if they were passed nullopt. (You can have any color you like, as long as it's non-null.)

- I removed Generic support for Module and ModuleList. The intentions behind this were admirable, but making Module inherit from Generic ended up being more headache than it was worth. First, in Python 3.6 and earlier, Generic has a nontrivial metaclass, which means all subsequent metaclass shenanigans (e.g., ScriptModule) need to line up with the right metaclass. Second, Generic defines `__new__` specially, which means that `inspect.signature` doesn't work (see https://bugs.python.org/issue40897), and I found a case of people using precisely this in the wild. Between these two problems, and the general problem that the parametrization here is an incomplete fix (it helps with output typing, but it doesn't solve problems with input typing, which with mypy as it stands is unfixable; see python/mypy#3028), I decided to just eliminate Module generics entirely. We still apply the Callable trick so that subclasses of Module don't cause mypy to complain, but otherwise you are on your own for getting accurate type information out of Modules.

- The `Callable` trick on `forward` caused TorchScript to stop performing inference on the forward body, which is bad because in general we can only figure out the most accurate type by doing TorchScript inference. I added a special case to `infer_type` to ensure we always do inference for `Module.forward`, even if it is annotated (which it is), and another special case to make sure we recursively ignore references to Callable (which we shouldn't process).

- When `__annotations__` is set on a class (as is the case when you add type annotations), JIT will incorrectly add further annotations to the parent class. This PR fixes #39463 by testing whether `__annotations__` is defined on the specific class, excluding parent classes from the test.
- Added a missing fake source range to the invocation of `get_signature`.

- In some cases, we cannot provide accurate typing for parameters on modules. This usually occurs when you have an `Optional[Tensor]` parameter whose optional-ness is determined at `__init__` time. Without the annotation, TorchScript will infer the correct refined type depending on the arguments to the constructor, but with the annotation, it will never do a refinement at `__init__` time, and you'll end up with the wrong type. I ended up just deleting the type annotations in all of these cases. A more robust fix might be some way to force TorchScript to do inference even when there is an explicit annotation, in case of refinement.
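To make the quantization point concrete, here is a minimal sketch of the signature-alignment pattern, using a hypothetical Python-level `quantized_linear` as a stand-in (the real kernels are C++ and take `c10::optional`):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def quantized_linear(input: torch.Tensor, weight: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
    # The signature accepts Optional so it lines up exactly with the
    # non-quantized op; a None bias is rejected at runtime rather than
    # ruled out by the type.
    if bias is None:
        raise RuntimeError("quantized_linear: bias must be non-None")
    return F.linear(input, weight, bias)  # stand-in for the quantized kernel
```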
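The `inspect.signature` breakage from the Generic bullet can be reproduced standalone; on the Python versions affected by https://bugs.python.org/issue40897, `Generic.__new__` masks the `__init__` signature (the class names here are made up for the repro):

```python
import inspect
from typing import Generic, TypeVar

T = TypeVar('T')


class Plain:
    def __init__(self, x: int) -> None:
        self.x = x


class Parametrized(Generic[T]):
    def __init__(self, x: int) -> None:
        self.x = x


print(inspect.signature(Plain))        # (x: int) -> None
# On affected Python versions this prints (*args, **kwds) instead of
# the __init__ signature, because Generic defines __new__ specially.
print(inspect.signature(Parametrized))
```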
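The "Callable trick" itself, in simplified form (a sketch of the idea, not the actual `torch.nn.Module` source): typing `forward` as a bare `Callable` attribute means mypy does not flag subclass overrides with narrower signatures as incompatible:

```python
from typing import Any, Callable


class Module:
    # Declared as a Callable attribute rather than a method with a fixed
    # signature, so subclasses may redefine forward with any argument
    # types without mypy reporting an incompatible override.
    forward: Callable[..., Any]


class Doubler(Module):
    def forward(self, x: float) -> float:  # fine: compatible with Callable[..., Any]
        return 2.0 * x
```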
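The #39463 fix turns on a Python subtlety worth spelling out: plain attribute access on a class falls through to its bases, so reading (or worse, mutating) `cls.__annotations__` on a class that defines no annotations of its own actually hits the parent's dict, on the Python versions contemporary with this commit (3.10 later changed the lazy-creation behavior). Checking the class's own `__dict__` avoids this:

```python
class Parent:
    x: int  # gives Parent its own __annotations__ dict


class Child(Parent):
    pass    # defines no annotations of its own


# Attribute lookup falls through to the parent: on Python <= 3.9 this
# prints {'x': <class 'int'>} -- Parent's dict! Adding entries to it
# would "annotate" Parent, which is exactly the bug.
print(Child.__annotations__)

# Testing the specific class's __dict__ excludes inherited annotations:
print('__annotations__' in Child.__dict__)   # False
print('__annotations__' in Parent.__dict__)  # True
```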
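And the `__init__`-time refinement problem from the last bullet, as a runnable sketch (the module is hypothetical): when the attribute is unannotated, TorchScript can refine `self.bias` to `Tensor` or `None` per instance, but an explicit class-level `Optional[Tensor]` annotation pins it to Optional:

```python
from typing import Optional

import torch
from torch import Tensor


class MaybeBias(torch.nn.Module):
    # Writing `bias: Optional[Tensor]` here would force TorchScript to
    # treat self.bias as Optional even for instances constructed with
    # use_bias=True; leaving the annotation off lets TorchScript refine
    # the type from what __init__ actually assigns.
    def __init__(self, use_bias: bool = True) -> None:
        super().__init__()
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(4))
        else:
            self.register_parameter('bias', None)
```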
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D21497397](https://our.internmc.facebook.com/intern/diff/D21497397)

[ghstack-poisoned]

2 parents 9320e50 + a451dd8 · commit 019c8d0

File tree

3 files changed: 9 additions, 21 deletions


torch/nn/parallel/data_parallel.pyi

Lines changed: 3 additions & 7 deletions
@@ -1,22 +1,18 @@
-from typing import Any, Optional, TypeVar
+from typing import Any, Optional
 from .common_types import _devices_t, _device_t
 from ..modules import Module
 from ... import device, Tensor
 
-T_co = TypeVar('T_co', covariant=True)
-class DataParallel(Module[T_co]):
+class DataParallel(Module):
     module: Module = ...
     device_ids: _devices_t = ...
     dim: int = ...
     output_device: _device_t = ...
     src_device_obj: device = ...
 
-    def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ..., output_device: Optional[_device_t] = ...,
+    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ..., output_device: Optional[_device_t] = ...,
                  dim: int = ...) -> None: ...
 
-    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
-    def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ...
-
 
 def data_parallel(module: Module, inputs: Any, device_ids: Optional[_devices_t] = ...,
                   output_device: Optional[_device_t] = ..., dim: int = ...,

torch/nn/parallel/distributed.pyi

Lines changed: 4 additions & 10 deletions
@@ -1,14 +1,12 @@
 from ..modules import Module
-from typing import Any, Optional, TypeVar
+from typing import Any, Optional
 from .common_types import _devices_t, _device_t
 
-T_co = TypeVar('T_co', covariant=True)
 
-
-class DistributedDataParallel(Module[T_co]):
+class DistributedDataParallel(Module):
     process_group: Any = ...
     dim: int = ...
-    module: Module[T_co] = ...
+    module: Module = ...
     device_ids: _devices_t = ...
     output_device: _device_t = ...
     broadcast_buffers: bool = ...
@@ -17,11 +15,7 @@ class DistributedDataParallel(Module[T_co]):
     bucket_bytes_cap: float = ...
 
     # TODO type process_group once `distributed` module is stubbed
-    def __init__(self, module: Module[T_co], device_ids: Optional[_devices_t] = ...,
+    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
                  output_device: Optional[_device_t] = ..., dim: int = ...,
                  broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                  check_reduction: bool = ...) -> None: ...
-
-    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
-
-    def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ...

torch/nn/parallel/replicate.pyi

Lines changed: 2 additions & 4 deletions
@@ -2,8 +2,6 @@ from typing import List, Union, Sequence, TypeVar
 from ..modules import Module
 from .common_types import _devices_t
 
-T = TypeVar('T')
 
-
-def replicate(network: Module[T], devices: Union[_devices_t, Sequence[_devices_t]], detach: bool = ...) -> List[
-    Module[T]]: ...
+def replicate(network: Module, devices: Union[_devices_t, Sequence[_devices_t]], detach: bool = ...) -> List[
+    Module]: ...
