Move all torch.nn.modules type annotations inline by ezyang · Pull Request #38211 · pytorch/pytorch · GitHub
Move all torch.nn.modules type annotations inline #38211


Closed
wants to merge 22 commits into from
Changes from 1 commit
Commits
22 commits
f352827
Move all torch.nn.modules type annotations inline
ezyang May 10, 2020
9d573b4
Update on "Move all torch.nn.modules type annotations inline"
ezyang May 11, 2020
07753d1
Update on "Move all torch.nn.modules type annotations inline"
ezyang May 11, 2020
61f853e
Update on "Move all torch.nn.modules type annotations inline"
ezyang May 11, 2020
70d706d
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 3, 2020
c8ce832
Update on "Move 8000 all torch.nn.modules type annotations inline"
ezyang Jun 3, 2020
abf86ce
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 3, 2020
e4ef38e
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 3, 2020
cbfef83
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 4, 2020
efca616
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 4, 2020
98f9001
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 4, 2020
6e248ab
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 4, 2020
4c41494
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 4, 2020
37042d1
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 6, 2020
87004ff
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 7, 2020
db976a6
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 7, 2020
6082278
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 7, 2020
3382863
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 7, 2020
7c34006
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 8, 2020
9320e50
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 8, 2020
019c8d0
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 9, 2020
4f26b2b
Update on "Move all torch.nn.modules type annotations inline"
ezyang Jun 9, 2020
Update on "Move all torch.nn.modules type annotations inline"
Just because the annotations are inline doesn't mean the files type
check; most of the newly annotated files have type errors, and I
added exclusions for them in mypy.ini.  The payoff of moving
all of these modules inline is that I can delete the relevant code
generation logic for the pyi files (which was added to ignore
annotations that weren't actually relevant anymore).

Because we aren't actually typechecking these modules in most
cases, it is inevitable that some of these type annotations are wrong.
I slavishly copied the old annotations from the pyi files unless there
was an obvious correction I could make.  These annotations will probably
need fixing up later.
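
For reference, one standard way to spell this kind of per-module exclusion in mypy.ini looks like the snippet below (the section name is illustrative, not necessarily one of the exact entries added here):

```ini
# Illustrative per-module override: silence type errors for one newly
# annotated module until its annotations are cleaned up.
[mypy-torch.nn.modules.normalization]
ignore_errors = True
```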

Moving these annotations inline was really hairy because of interactions
with JIT, and also the fact that Python type erasure is a lie (inheriting
from Generic *does* change the behavior of your object). Here is
the list of things I had to fix and/or work around:

- The quantization translation passes previously barfed if the weight/bias arguments were inferred to be Optional. Previously, TorchScript type inference would have inferred that these arguments were non-Optional (because type inference happens after module construction), but accurate type annotations on these parameters override this inference process, causing the arguments to be optional. I fixed this by making the quantized operator signatures line up exactly with the non-quantized signatures, so we never change the types of the arguments. This change involved mostly making a bunch of quantized kernels take optional, and then error if they were passed nullopt. (You can have any color you like, as long as it's non-null.)
- I removed Generic support for Module and ModuleList. The intentions behind this were admirable, but making Module inherit from Generic ended up being more headache than it was worth. First, in Python 3.6 and earlier, Generic has a nontrivial metaclass, which means all subsequent metaclass shenanigans (e.g., ScriptModule) need to line up with the right metaclass. Second, Generic defines `__new__` specially, which means that `inspect.signature` doesn't work (see https://bugs.python.org/issue40897), and I found a case of people using precisely this in the wild. Between these two problems, and also the general problem that the parametrization here is an incomplete fix (parametrization helps with output typing, but it doesn't solve problems with input typing, and with mypy as it stands this is unfixable, see python/mypy#3028), I decided to just eliminate Module generics entirely. We still apply the Callable trick so that subclasses of Module don't cause mypy to complain, but otherwise you are on your own for getting accurate type information out of Modules.
- The `Callable` trick on `forward` caused TorchScript to stop performing inference on the forward body, which is bad because in general we can only figure out the most accurate type by doing TorchScript inference. I added a special case to `infer_type` to ensure we always do inference for `Module.forward`, even if it is annotated (which it is), and another special case to make sure we recursively ignore references to Callable (which we shouldn't process). A simplified sketch of the Callable trick itself appears after this list.
- When `__annotations__` is set on a class (as is the case when you add type annotations), JIT will incorrectly add further annotations to the parent class. This PR fixes #39463 by testing if `__annotations__` is defined on the specific class, excluding parent classes from the test.
- Added a missing fake source range to the invocation of `get_signature`
- In some cases, we cannot provide accurate typing for parameters on modules. This usually occurs when you have an `Optional[Tensor]` parameter whose optional-ness is determined at `__init__` time. Without the annotation, TorchScript will infer the correct refined type depending on arguments to the constructor, but with the annotation, it will never do a refinement at `__init__` time, and you'll end up with the wrong type. I ended up just straight up deleting type annotations in all of these cases; see the first sketch after this list. A more robust fix might be to add some way to force TorchScript to do inference even if there is an explicit annotation, in case of refinement.
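
A minimal sketch of the refinement problem in the last bullet, using a hypothetical toy module (not one of the modules touched by this PR): with the class-level `Optional[Tensor]` annotation, TorchScript treats the attribute as Optional regardless of how the instance was constructed, whereas deleting the annotation lets inference refine it per instance at `__init__` time.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class AddBias(nn.Module):
    # With this annotation, a scripted instance always sees Optional[Tensor];
    # without it, TorchScript would refine `bias` based on the constructor
    # arguments (e.g., to Tensor for a bias=True instance).
    bias: Optional[Tensor]

    def __init__(self, num_features: int, bias: bool = True) -> None:
        super().__init__()
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x: Tensor) -> Tensor:
        if self.bias is not None:  # explicit None check, since the type is Optional
            return x + self.bias
        return x


print(AddBias(4, bias=True)(torch.randn(2, 4)).shape)   # torch.Size([2, 4])
print(AddBias(4, bias=False)(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```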

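A simplified sketch of the `Callable` trick referenced above, using a stripped-down stand-in class rather than the real `torch/nn/modules/module.py` code: typing `forward` as a `Callable` class attribute lets subclasses override it with narrower signatures without mypy reporting an incompatible override.

```python
from typing import Any, Callable


def _forward_unimplemented(self, *args: Any) -> Any:
    raise NotImplementedError


class ModuleLike:
    # Declaring `forward` as a Callable attribute (rather than a method with a
    # fixed signature) is what keeps mypy from flagging subclass overrides.
    forward: Callable[..., Any] = _forward_unimplemented

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)


class Add(ModuleLike):
    def forward(self, x: int, y: int) -> int:
        return x + y


print(Add()(2, 3))  # 5
```
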
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D21497397](https://our.internmc.facebook.com/intern/diff/D21497397)

[ghstack-poisoned]
ezyang committed Jun 9, 2020
commit 4f26b2bb33f43535fdd7279543e3f79ffe0f12df
3 changes: 3 additions & 0 deletions test/type_hint_tests/module_list.py
@@ -1,6 +1,9 @@
import torch

# ModuleList with elements of type Module
class FooModule(torch.nn.Module):
    pass

class BarModule(torch.nn.Module):
    pass
15 changes: 12 additions & 3 deletions torch/nn/modules/conv.py
@@ -34,9 +34,18 @@ class _ConvNd(Module):
     weight: Tensor
     bias: Optional[Tensor]

-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation, transposed, output_padding,
-                 groups, bias, padding_mode):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t,
+                 padding: _size_1_t,
+                 dilation: _size_1_t,
+                 transposed: bool,
+                 output_padding: _size_1_t,
+                 groups: int,
+                 bias: Optional[Tensor],
+                 padding_mode: str) -> None:
         super(_ConvNd, self).__init__()
         if in_channels % groups != 0:
             raise ValueError('in_channels must be divisible by groups')
4 changes: 2 additions & 2 deletions torch/nn/modules/distance.py
@@ -32,7 +32,7 @@ class PairwiseDistance(Module):
     eps: float
     keepdim: bool

-    def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False):
+    def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None:
         super(PairwiseDistance, self).__init__()
         self.norm = p
         self.eps = eps
@@ -66,7 +66,7 @@ class CosineSimilarity(Module):
     dim: int
     eps: float

-    def __init__(self, dim: int = 1, eps: float = 1e-8):
+    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
         super(CosineSimilarity, self).__init__()
         self.dim = dim
         self.eps = eps
2 changes: 1 addition & 1 deletion torch/nn/modules/dropout.py
@@ -9,7 +9,7 @@ class _DropoutNd(Module):
     p: float
     inplace: bool

-    def __init__(self, p: float = 0.5, inplace: bool = False):
+    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
         super(_DropoutNd, self).__init__()
         if p < 0 or p > 1:
             raise ValueError("dropout probability has to be between 0 and 1, "
2 changes: 1 addition & 1 deletion torch/nn/modules/flatten.py
@@ -24,7 +24,7 @@ class Flatten(Module):
     start_dim: int
     end_dim: int

-    def __init__(self, start_dim: int = 1, end_dim: int = -1):
+    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
         super(Flatten, self).__init__()
         self.start_dim = start_dim
         self.end_dim = end_dim
4 changes: 2 additions & 2 deletions torch/nn/modules/fold.py
@@ -132,7 +132,7 @@ def __init__(
         dilation: _size_any_t = 1,
         padding: _size_any_t = 0,
         stride: _size_any_t = 1
-    ):
+    ) -> None:
         super(Fold, self).__init__()
         self.output_size = output_size
         self.kernel_size = kernel_size
@@ -283,7 +283,7 @@ def __init__(
         dilation: _size_any_t = 1,
         padding: _size_any_t = 0,
         stride: _size_any_t = 1
-    ):
+    ) -> None:
         super(Unfold, self).__init__()
         self.kernel_size = kernel_size
         self.dilation = dilation
2 changes: 1 addition & 1 deletion torch/nn/modules/instancenorm.py
@@ -12,7 +12,7 @@ def __init__(
         momentum: float = 0.1,
         affine: bool = False,
         track_running_stats: bool = False
-    ):
+    ) -> None:
         super(_InstanceNorm, self).__init__(
             num_features, eps, momentum, affine, track_running_stats)

8 changes: 4 additions & 4 deletions torch/nn/modules/linear.py
@@ -69,7 +69,7 @@ class Linear(Module):
     out_features: int
     weight: Tensor

-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
         super(Linear, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -101,7 +101,7 @@ def extra_repr(self) -> str:
 class _LinearWithBias(Linear):
     bias: Tensor

-    def __init__(self, in_features: int, out_features: int):
+    def __init__(self, in_features: int, out_features: int) -> None:
         super().__init__(in_features, out_features, bias=True)


@@ -149,7 +149,7 @@ class Bilinear(Module):
     out_features: int
     weight: Tensor

-    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True):
+    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True) -> None:
         super(Bilinear, self).__init__()
         self.in1_features = in1_features
         self.in2_features = in2_features
@@ -168,7 +168,7 @@ def reset_parameters(self) -> None:
         if self.bias is not None:
             init.uniform_(self.bias, -bound, bound)

-    def forward(self, input1: Tensor, input2: Tensor):
+    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
         return F.bilinear(input1, input2, self.weight, self.bias)

     def extra_repr(self) -> str:
8 changes: 4 additions & 4 deletions torch/nn/modules/normalization.py
@@ -44,7 +44,7 @@ class LocalResponseNorm(Module):
     beta: float
     k: float

-    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.):
+    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
         super(LocalResponseNorm, self).__init__()
         self.size = size
         self.alpha = alpha
@@ -65,7 +65,7 @@ class CrossMapLRN2d(Module):
     beta: float
     k: float

-    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1):
+    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
         super(CrossMapLRN2d, self).__init__()
         self.size = size
         self.alpha = alpha
@@ -143,7 +143,7 @@ class LayerNorm(Module):
     eps: float
     elementwise_affine: bool

-    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True):
+    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
         super(LayerNorm, self).__init__()
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
@@ -218,7 +218,7 @@ class GroupNorm(Module):
     eps: float
     affine: bool

-    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
         super(GroupNorm, self).__init__()
         self.num_groups = num_groups
         self.num_channels = num_channels
2 changes: 1 addition & 1 deletion torch/nn/modules/pixelshuffle.py
@@ -38,7 +38,7 @@ class PixelShuffle(Module):
     __constants__ = ['upscale_factor']
     upscale_factor: int

-    def __init__(self, upscale_factor: int):
+    def __init__(self, upscale_factor: int) -> None:
         super(PixelShuffle, self).__init__()
         self.upscale_factor = upscale_factor

6 changes: 3 additions & 3 deletions torch/nn/modules/pooling.py
@@ -305,7 +305,7 @@ def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, p
         self.stride = _single(stride or kernel_size)
         self.padding = _single(padding)

-    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None):
+    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
         return F.max_unpool1d(input, indices, self.kernel_size, self.stride,
                               self.padding, output_size)

@@ -382,7 +382,7 @@ def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, p
         self.stride = _pair(stride or kernel_size)
         self.padding = _pair(padding)

-    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None):
+    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
         return F.max_unpool2d(input, indices, self.kernel_size, self.stride,
                               self.padding, output_size)

@@ -448,7 +448,7 @@ def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, p
         self.stride = _triple(stride or kernel_size)
         self.padding = _triple(padding)

-    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None):
+    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
         return F.max_unpool3d(input, indices, self.kernel_size, self.stride,
                               self.padding, output_size)

6 changes: 3 additions & 3 deletions torch/nn/modules/rnn.py
@@ -864,7 +864,7 @@ class RNNCell(RNNCellBase):
     __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
     nonlinearity: str

-    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh"):
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh") -> None:
         super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
         self.nonlinearity = nonlinearity

@@ -953,7 +953,7 @@ class LSTMCell(RNNCellBase):
         output.append(hx)
     """

-    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
         super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

     def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
@@ -1032,7 +1032,7 @@ class GRUCell(RNNCellBase):
         output.append(hx)
     """

-    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
         super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)

     def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
2 changes: 1 addition & 1 deletion torch/nn/modules/transformer.py
@@ -211,7 +211,7 @@ def __init__(self, decoder_layer, num_layers, norm=None):

     def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
-                memory_key_padding_mask: Optional[Tensor] = None):
+                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
         r"""Pass the inputs (and mask) through the decoder layer in turn.

         Args:
2 changes: 1 addition & 1 deletion torch/nn/modules/upsampling.py
@@ -126,7 +126,7 @@ class Upsample(Module):
     align_corners: bool

     def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
-                 mode: str = 'nearest', align_corners: Optional[bool] = None):
+                 mode: str = 'nearest', align_corners: Optional[bool] = None) -> None:
         super(Upsample, self).__init__()
         self.name = type(self).__name__
         self.size = size
You are viewing a condensed version of this merge commit.