Move all torch.nn.modules type annotations inline · pytorch/pytorch@e682eb6 · GitHub
Move all torch.nn.modules type annotations inline
Pull Request resolved: #38211

Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors, and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which was added to insert ignore annotations that are no longer relevant).

For the most part the translation was completely mechanical, but there were two hairy issues. First, I needed to work around a bug in Python 3.6 and earlier where Generic has a nontrivial metaclass; this fix is in torch/jit/__init__.py. Second, in module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient to get mypy to not contravariantly check input arguments.

Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D21497397](https://our.internmc.facebook.com/intern/diff/D21497397/)
ghstack-source-id: 77bf67e
1 parent 6d13b58 commit e682eb6
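
For readers less familiar with the contravariance issue mentioned in the message, here is a minimal sketch, not taken from this commit, of why declaring forward as a variable of a Callable type lets subclasses narrow its argument types without mypy complaining. The class names are hypothetical.

import torch
from typing import Any, Callable
from torch import Tensor


class Base:
    # Declared as an attribute of Callable type instead of as a method; mypy then
    # only requires each override to be *some* callable and skips the
    # contravariant check on parameter types.
    forward: Callable[..., Any]


class Child(Base):
    # A narrower signature; accepted because Base.forward is a variable.
    def forward(self, input: Tensor) -> Tensor:
        return input * 2


print(Child().forward(torch.ones(2)))  # tensor([2., 2.])

With a real `def forward(self, *input: Any) -> Any` on the base class, mypy would flag Child.forward as violating Liskov substitutability, which is exactly the check the old pyi codegen suppressed with `# type: ignore` comments.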


55 files changed: +1020 −2119 lines

aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp

Lines changed: 20 additions & 5 deletions
@@ -42,14 +42,19 @@ void compute_fused_params(
 template <bool ReluFused>
 Tensor q_batch_norm2d_impl(
     Tensor qx,
-    Tensor weight,
-    Tensor bias,
+    c10::optional<Tensor> mb_weight,
+    c10::optional<Tensor> mb_bias,
     Tensor mean,
     Tensor var,
     double eps,
     double output_scale,
     int64_t output_zero_point) {
 
+  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
+  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
+  const auto& weight = *mb_weight;
+  const auto& bias = *mb_bias;
+
   if (qx.numel() == 0) {
     auto out = qx.clone();
     return out;
@@ -131,14 +136,20 @@ Tensor q_batch_norm2d_impl(
 template <bool ReluFused>
 Tensor q_batch_norm3d_impl(
     Tensor qx,
-    Tensor weight,
-    Tensor bias,
+    c10::optional<Tensor> mb_weight,
+    c10::optional<Tensor> mb_bias,
     Tensor mean,
     Tensor var,
     double eps,
     double output_scale,
     int64_t output_zero_point) {
 
+  TORCH_CHECK(mb_weight.has_value(), "Weight must be provided")
+  TORCH_CHECK(mb_bias.has_value(), "Bias must be provided")
+
+  const auto& weight = *mb_weight;
+  const auto& bias = *mb_bias;
+
   if (qx.numel() == 0) {
     auto out = qx.clone();
     return out;
@@ -231,8 +242,12 @@ Tensor quantized_batch_norm(
     double output_scale,
     int64_t output_zero_point) {
   Tensor qy;
+  // TODO: this should arguably support 3d as well
   qy = q_batch_norm2d_impl<false>(
-      qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+      qx,
+      weight.defined() ? c10::make_optional(weight) : c10::nullopt,
+      bias.defined() ? c10::make_optional(bias) : c10::nullopt,
+      mean, var, eps, output_scale, output_zero_point);
   return qy;
 }

aten/src/ATen/native/quantized/cpu/qnormalization.cpp

Lines changed: 19 additions & 9 deletions
@@ -123,33 +123,43 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("layer_norm", [](
       Tensor input,
       std::vector<int64_t> normalized_shape,  // because IntArrayRef doesn't work
-      Tensor weight /* optional */,
-      Tensor bias /* optional */,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias /* optional */,
       double eps,
       double output_scale,
       int64_t output_zero_point) {
-    return quantized_layer_norm_impl(input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
+    return quantized_layer_norm_impl(
+        input, normalized_shape,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
   m.impl("group_norm", [](
       Tensor qx,
       int64_t num_groups,
-      Tensor weight,
-      Tensor bias,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias,
      double eps,
       double output_scale,
       int64_t output_zero_point) {
     return quantized_group_norm_impl(
-        qx, num_groups, weight, bias, eps, output_scale, output_zero_point);
+        qx, num_groups,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
   m.impl("instance_norm", [](
       Tensor qx,
-      Tensor weight,
-      Tensor bias,
+      c10::optional<Tensor> weight,
+      c10::optional<Tensor> bias,
       double eps,
       double output_scale,
       int64_t output_zero_point) {
     return quantized_instance_norm_impl(
-        qx, weight, bias, eps, output_scale, output_zero_point);
+        qx,
+        weight.has_value() ? *weight : Tensor(),
+        bias.has_value() ? *bias : Tensor(),
+        eps, output_scale, output_zero_point);
   });
 }

aten/src/ATen/native/quantized/library.cpp

Lines changed: 7 additions & 7 deletions
@@ -25,10 +25,10 @@ TORCH_LIBRARY(quantized, m) {
   m.def("add_scalar_relu(Tensor qa, Scalar b) -> Tensor qc");
   m.def("add_scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out");
   m.def("add_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out");
-  m.def("batch_norm2d(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("batch_norm2d_relu(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("batch_norm3d(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("batch_norm3d_relu(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("batch_norm2d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("batch_norm3d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
   m.def("clamp(Tensor qx, Scalar? min, Scalar? max) -> Tensor qy");
   m.def("threshold(Tensor qx, Scalar threshold, Scalar value) -> Tensor qy");
   m.def("cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor");
@@ -64,9 +64,9 @@ TORCH_LIBRARY(quantized, m) {
   m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
   m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
   m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
-  m.def("group_norm(Tensor input, int num_groups, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("instance_norm(Tensor input, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
+  m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
   m.def(
       "linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
   m.def(
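
The schema change from `Tensor` to `Tensor?` makes weight and bias optional at the operator-schema level, while the 2d/3d kernels above still require them via TORCH_CHECK. A rough Python-side usage sketch, not part of the commit, assuming a CPU build with a quantized engine (e.g. fbgemm) where the op is exposed under this name:

import torch

# Quantize a small NCHW tensor so it can be fed to the quantized op.
x = torch.randn(1, 3, 8, 8)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# weight/bias are `Tensor?` in the schema now, but the kernel still checks that
# they are provided, so we pass real per-channel tensors here.
weight = torch.ones(3)
bias = torch.zeros(3)
mean = torch.zeros(3)
var = torch.ones(3)

qy = torch.ops.quantized.batch_norm2d(qx, weight, bias, mean, var, 1e-5, 0.1, 0)
print(qy.dtype)  # torch.quint8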

mypy.ini

Lines changed: 45 additions & 0 deletions
@@ -146,6 +146,51 @@ ignore_errors = True
 [mypy-torch._tensor_str]
 ignore_errors = True
 
+[mypy-torch.nn.modules.activation]
+ignore_errors = True
+
+[mypy-torch.nn.modules.batchnorm]
+ignore_errors = True
+
+[mypy-torch.nn.modules.container]
+ignore_errors = True
+
+[mypy-torch.nn.modules.conv]
+ignore_errors = True
+
+[mypy-torch.nn.modules.fold]
+ignore_errors = True
+
+[mypy-torch.nn.modules.instancenorm]
+ignore_errors = True
+
+[mypy-torch.nn.modules.linear]
+ignore_errors = True
+
+[mypy-torch.nn.modules.loss]
+ignore_errors = True
+
+[mypy-torch.nn.modules.module]
+ignore_errors = True
+
+[mypy-torch.nn.modules.normalization]
+ignore_errors = True
+
+[mypy-torch.nn.modules.padding]
+ignore_errors = True
+
+[mypy-torch.nn.modules.pooling]
+ignore_errors = True
+
+[mypy-torch.nn.modules.rnn]
+ignore_errors = True
+
+[mypy-torch.nn.modules.sparse]
+ignore_errors = True
+
+[mypy-torch.nn.modules.upsampling]
+ignore_errors = True
+
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
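
For context on what these sections do: `ignore_errors = True` makes mypy still analyze the module (so its annotations remain visible to callers) but suppress any errors reported inside it. A minimal illustration with a hypothetical module name, not part of the commit:

# example_module.py (hypothetical), paired with a mypy.ini section:
#   [mypy-example_module]
#   ignore_errors = True
def double(x: int) -> int:
    return x * 2

# mypy would normally flag this call (str passed where int is expected), but
# ignore_errors suppresses the report for this module; misuses of `double` in
# *other* modules are still checked against its annotations.
double("oops")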

test/type_hint_tests/module_list.py

Lines changed: 0 additions & 19 deletions
@@ -1,24 +1,5 @@
-from typing import Iterable
-
 import torch
 
-# ModuleList with elements of a specific type
-class FooModule(torch.nn.Module):
-    def ten(self) -> int:
-        return 10
-
-class FooCollector(torch.nn.Module):
-    def __init__(self, ml: Iterable[FooModule]) -> None:
-        super(FooCollector, self).__init__()
-        self.ml: torch.nn.ModuleList[FooModule] = torch.nn.ModuleList(ml)
-
-    def foo_sum(self) -> int:
-        return sum(foo.ten() for foo in self.ml)
-
-collector = FooCollector([FooModule(), FooModule()])
-twenty = collector.foo_sum()
-twenty == 20
-
 # ModuleList with elements of type Module
 class BarModule(torch.nn.Module):
     pass

tools/pyi/gen_pyi.py

Lines changed: 0 additions & 22 deletions
@@ -1,7 +1,6 @@
 from __future__ import print_function
 import os
 import collections
-import glob
 import yaml
 import re
 import argparse
@@ -342,26 +341,6 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
 
     return type_hints
 
-def gen_nn_modules(out):
-    def replace_forward(m):
-        # We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
-        # would otherwise correctly point out that Module's descendants' `forward` declarations
-        # conflict with `Module`s. Specifically, `Module` defines `forward(self, *args)` while the
-        # descandantes define more specific forms, such as `forward(self, input: Tensor)`, which
-        # violates Liskov substitutability. The 'mypy' team recommended this solution for now.
-        forward_def = m.group(0) + " # type: ignore"
-        call_def = re.sub(r'def forward', 'def __call__', forward_def)
-        new_def = "{}\n{}".format(forward_def, call_def)
-        return new_def
-    pattern = re.compile(r'^\s*def forward\(self.*$', re.MULTILINE)
-    for fname in glob.glob("torch/nn/modules/*.pyi.in"):
-        with open(fname, 'r') as f:
-            src = f.read()
-        res = pattern.sub(replace_forward, src)
-        fname_out = fname[:-3]
-        with open(os.path.join(out, fname_out), 'w') as f:
-            f.write(res)
-
 def gen_nn_functional(out):
     # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
     # through an `_add_docstr` call
@@ -421,7 +400,6 @@ def gen_nn_functional(out):
 
 def gen_nn_pyi(out):
     gen_nn_functional(out)
-    gen_nn_modules(out)
 
 def gen_pyi(declarations_path, out):
     """gen_pyi()

torch/CMakeLists.txt

Lines changed: 0 additions & 52 deletions
@@ -222,60 +222,10 @@ endif()
 #     upsampling
 #   )
 #   list(TRANSFORM Modules PREPEND "${TORCH_SRC_DIR}/nn/modules/")
-#   set(ModuleStubIn ${Modules})
-#   set(ModuleStubOut ${Modules})
-#   list(TRANSFORM ModuleStubIn APPEND ".pyi.in")
-#   list(TRANSFORM ModuleStubOut APPEND ".pyi")
-set(ModulesStubIn
-  ${TORCH_SRC_DIR}/nn/modules/__init__.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/activation.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/adaptive.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/container.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/conv.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/distance.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/dropout.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/fold.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/flatten.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/linear.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/loss.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/module.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/normalization.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/padding.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/pooling.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/rnn.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/sparse.pyi.in
-  ${TORCH_SRC_DIR}/nn/modules/upsampling.pyi.in
-)
-set(ModulesStubOut
-  ${TORCH_SRC_DIR}/nn/modules/__init__.pyi
-  ${TORCH_SRC_DIR}/nn/modules/activation.pyi
-  ${TORCH_SRC_DIR}/nn/modules/adaptive.pyi
-  ${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi
-  ${TORCH_SRC_DIR}/nn/modules/container.pyi
-  ${TORCH_SRC_DIR}/nn/modules/conv.pyi
-  ${TORCH_SRC_DIR}/nn/modules/distance.pyi
-  ${TORCH_SRC_DIR}/nn/modules/dropout.pyi
-  ${TORCH_SRC_DIR}/nn/modules/fold.pyi
-  ${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi
-  ${TORCH_SRC_DIR}/nn/modules/linear.pyi
-  ${TORCH_SRC_DIR}/nn/modules/loss.pyi
-  ${TORCH_SRC_DIR}/nn/modules/module.pyi
-  ${TORCH_SRC_DIR}/nn/modules/normalization.pyi
-  ${TORCH_SRC_DIR}/nn/modules/padding.pyi
-  ${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi
-  ${TORCH_SRC_DIR}/nn/modules/pooling.pyi
-  ${TORCH_SRC_DIR}/nn/modules/rnn.pyi
-  ${TORCH_SRC_DIR}/nn/modules/sparse.pyi
-  ${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
-)
 add_custom_target(torch_python_stubs DEPENDS
   "${TORCH_SRC_DIR}/_C/__init__.pyi"
   "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
   "${TORCH_SRC_DIR}/nn/functional.pyi"
-  ${ModuleStubOut}
 )
 # For Declarations.yaml dependency
 add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
@@ -284,7 +234,6 @@ add_custom_command(
   "${TORCH_SRC_DIR}/_C/__init__.pyi"
   "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
   "${TORCH_SRC_DIR}/nn/functional.pyi"
-  ${ModuleStubOut}
   COMMAND
   "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
   --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
@@ -293,7 +242,6 @@ add_custom_command(
   "${TORCH_SRC_DIR}/_C/__init__.pyi.in"
   "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
   "${TORCH_SRC_DIR}/nn/functional.pyi.in"
-  ${ModuleStubIn}
   "${TOOLS_PATH}/pyi/gen_pyi.py"
   WORKING_DIRECTORY
   "${TORCH_ROOT}"

torch/jit/__init__.py

Lines changed: 11 additions & 4 deletions
@@ -29,7 +29,6 @@
 import warnings
 import weakref
 
-
 # These are imported so users can access them from the `torch.jit` module
 from torch._jit_internal import Final, _overload, _overload_method
 from torch._jit_internal import ignore, export, unused
@@ -1394,6 +1393,10 @@ def interface(obj):
     if not _is_new_style_class(obj):
         raise RuntimeError("TorchScript interfaces must inherit from 'object'")
 
+    # Expected MRO is:
+    #   User module
+    #   torch.nn.modules.module.Module
+    #   object
     is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3
 
     if not is_module_interface and len(obj.mro()) > 2:
@@ -1555,7 +1558,7 @@ def __getitem__(self, k):
 # parameters are initialized _before_ the script compiler resolve references to
 # `self.param` or `self.module`.
 class ScriptMeta(type):
-    def __init__(cls, name, bases, attrs):
+    def __init__(cls, name, bases, attrs): # noqa: B902
         # Aggregate all the ScriptMethods and constants from superclasses
         cls._methods = {}
         cls._constants_set = set(getattr(cls, '__constants__', ()))
@@ -1641,8 +1644,12 @@ def __setattr__(self, attr, value):
         # This ensures that if we use the attr again in `__init__`, it
         # will look like the actual value, not an instance of Attribute.
         if isinstance(value, Attribute):
-            if not hasattr(self, "__annotations__"):
-                self.__annotations__ = {}
+            # NB: Ensure that we set __annotations__ on the specific
+            # class in question, and not on a superclass (which would
+            # be wrong wrong wrong!).
+            # See also https://github.com/pytorch/pytorch/issues/39463
+            if "__annotations__" not in self.__class__.__dict__:
+                self.__class__.__annotations__ = {}
             self.__annotations__[attr] = value.type
             value = value.value
         return super(ScriptModule, self).__setattr__(attr, value)
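
The last hunk deserves a note: hasattr(self, "__annotations__") is satisfied by an annotations dict found anywhere in the MRO, so the old guard could skip creating a per-class dict and later writes could mutate a superclass's annotations. A standalone sketch of the pitfall and the fixed guard, with hypothetical class names and behavior as on current CPython:

class Parent:
    x: int  # Parent defines a class-level annotation


class Child(Parent):
    pass


c = Child()

# The old guard is satisfied by the dict it finds on Parent via the MRO ...
print(hasattr(c, "__annotations__"))              # True
# ... even though Child itself has no annotations dict of its own,
print("__annotations__" in c.__class__.__dict__)  # False
# so writes to self.__annotations__[attr] could land in Parent's dict.

# The fixed guard creates a dict on the specific class before writing to it.
if "__annotations__" not in c.__class__.__dict__:
    c.__class__.__annotations__ = {}
c.__class__.__annotations__["y"] = str

print(sorted(Parent.__annotations__))  # ['x'] -- Parent untouched
print(sorted(Child.__annotations__))   # ['y']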
