Move all torch.nn.modules type annotations inline · pytorch/pytorch@fe7a09d · GitHub

Commit fe7a09d

Move all torch.nn.modules type annotations inline
Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors, and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is that I can delete the relevant code generation logic for the pyi files (which existed largely to add ignore annotations that weren't actually relevant anymore).

For the most part the translation was completely mechanical, but there were some hairy issues. First, I needed to work around a Python 3.6-and-earlier bug where Generic has a nontrivial metaclass. This fix is in torch/jit/__init__.py. Second, in module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient to get mypy to not contravariantly check input arguments. Third, JIT appears to not like it when you add type signatures to __call__ on Module, so we have to "hide" the signature (using the same trick as in the second issue).

Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ghstack-source-id: d30476b
Pull Request resolved: #38211
1 parent 9ea36e6 commit fe7a09d
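As an illustration of the second and third points in the message, here is a minimal, hedged sketch (simplified, hypothetical names; not the actual module.py code) of how declaring forward as a Callable-typed variable keeps mypy from contravariantly checking subclass signatures:

from typing import Any, Callable

class Module:
    # Declaring `forward` (and `__call__`) as variables of callable type,
    # rather than as methods, stops mypy from contravariantly checking the
    # narrower argument types that subclasses give their own `forward`.
    forward: Callable[..., Any]
    __call__: Callable[..., Any]

class Linear(Module):
    # Had Module declared `def forward(self, *args: Any) -> Any`, this
    # narrower signature would be flagged as a Liskov violation; checked
    # against a Callable[..., Any] variable, it passes cleanly.
    def forward(self, input: float) -> float:
        return input * 2.0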


48 files changed: +978 -2072 lines changed

mypy.ini

Lines changed: 45 additions & 0 deletions
@@ -133,6 +133,51 @@ ignore_errors = True
 [mypy-torch._tensor_str]
 ignore_errors = True
 
+[mypy-torch.nn.modules.activation]
+ignore_errors = True
+
+[mypy-torch.nn.modules.batchnorm]
+ignore_errors = True
+
+[mypy-torch.nn.modules.container]
+ignore_errors = True
+
+[mypy-torch.nn.modules.conv]
+ignore_errors = True
+
+[mypy-torch.nn.modules.fold]
+ignore_errors = True
+
+[mypy-torch.nn.modules.instancenorm]
+ignore_errors = True
+
+[mypy-torch.nn.modules.linear]
+ignore_errors = True
+
+[mypy-torch.nn.modules.loss]
+ignore_errors = True
+
+[mypy-torch.nn.modules.module]
+ignore_errors = True
+
+[mypy-torch.nn.modules.normalization]
+ignore_errors = True
+
+[mypy-torch.nn.modules.padding]
+ignore_errors = True
+
+[mypy-torch.nn.modules.pooling]
+ignore_errors = True
+
+[mypy-torch.nn.modules.rnn]
+ignore_errors = True
+
+[mypy-torch.nn.modules.sparse]
+ignore_errors = True
+
+[mypy-torch.nn.modules.upsampling]
+ignore_errors = True
+
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
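As a hedged aside (the invocation below is an assumption about local setup, not part of the commit): with these sections in place, mypy still analyzes the excluded modules but suppresses every error they report, which can be confirmed from Python via mypy's public API:

# Run from the repository root so this mypy.ini is picked up.
from mypy import api

# ignore_errors = True silences reports from torch.nn.modules.linear,
# so this should print a clean summary even though the module has errors.
stdout, stderr, exit_code = api.run(["-m", "torch.nn.modules.linear"])
print(stdout)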

tools/pyi/gen_pyi.py

Lines changed: 0 additions & 21 deletions
@@ -341,26 +341,6 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
 
     return type_hints
 
-def gen_nn_modules(out):
-    def replace_forward(m):
-        # We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
-        # would otherwise correctly point out that Module's descendants' `forward` declarations
-        # conflict with `Module`s. Specifically, `Module` defines `forward(self, *args)` while the
-        # descendants define more specific forms, such as `forward(self, input: Tensor)`, which
-        # violates Liskov substitutability. The 'mypy' team recommended this solution for now.
-        forward_def = m.group(0) + " # type: ignore"
-        call_def = re.sub(r'def forward', 'def __call__', forward_def)
-        new_def = "{}\n{}".format(forward_def, call_def)
-        return new_def
-    pattern = re.compile(r'^\s*def forward\(self.*$', re.MULTILINE)
-    for fname in glob.glob("torch/nn/modules/*.pyi.in"):
-        with open(fname, 'r') as f:
-            src = f.read()
-        res = pattern.sub(replace_forward, src)
-        fname_out = fname[:-3]
-        with open(os.path.join(out, fname_out), 'w') as f:
-            f.write(res)
-
 def gen_nn_functional(out):
     # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
     # through an `_add_docstr` call
@@ -420,7 +400,6 @@ def gen_nn_functional(out):
 
 def gen_nn_pyi(out):
     gen_nn_functional(out)
-    gen_nn_modules(out)
 
 def gen_pyi(declarations_path, out):
     """gen_pyi()

torch/CMakeLists.txt

Lines changed: 0 additions & 52 deletions
@@ -218,60 +218,10 @@ endif()
 # upsampling
 # )
 # list(TRANSFORM Modules PREPEND "${TORCH_SRC_DIR}/nn/modules/")
-# set(ModuleStubIn ${Modules})
-# set(ModuleStubOut ${Modules})
-# list(TRANSFORM ModuleStubIn APPEND ".pyi.in")
-# list(TRANSFORM ModuleStubOut APPEND ".pyi")
-set(ModulesStubIn
-    ${TORCH_SRC_DIR}/nn/modules/__init__.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/activation.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/adaptive.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/container.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/conv.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/distance.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/dropout.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/fold.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/flatten.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/linear.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/loss.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/module.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/normalization.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/padding.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/pooling.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/rnn.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/sparse.pyi.in
-    ${TORCH_SRC_DIR}/nn/modules/upsampling.pyi.in
-)
-set(ModulesStubOut
-    ${TORCH_SRC_DIR}/nn/modules/__init__.pyi
-    ${TORCH_SRC_DIR}/nn/modules/activation.pyi
-    ${TORCH_SRC_DIR}/nn/modules/adaptive.pyi
-    ${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi
-    ${TORCH_SRC_DIR}/nn/modules/container.pyi
-    ${TORCH_SRC_DIR}/nn/modules/conv.pyi
-    ${TORCH_SRC_DIR}/nn/modules/distance.pyi
-    ${TORCH_SRC_DIR}/nn/modules/dropout.pyi
-    ${TORCH_SRC_DIR}/nn/modules/fold.pyi
-    ${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi
-    ${TORCH_SRC_DIR}/nn/modules/linear.pyi
-    ${TORCH_SRC_DIR}/nn/modules/loss.pyi
-    ${TORCH_SRC_DIR}/nn/modules/module.pyi
-    ${TORCH_SRC_DIR}/nn/modules/normalization.pyi
-    ${TORCH_SRC_DIR}/nn/modules/padding.pyi
-    ${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi
-    ${TORCH_SRC_DIR}/nn/modules/pooling.pyi
-    ${TORCH_SRC_DIR}/nn/modules/rnn.pyi
-    ${TORCH_SRC_DIR}/nn/modules/sparse.pyi
-    ${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
-)
 add_custom_target(torch_python_stubs DEPENDS
     "${TORCH_SRC_DIR}/_C/__init__.pyi"
     "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
     "${TORCH_SRC_DIR}/nn/functional.pyi"
-    ${ModuleStubOut}
 )
 # For Declarations.yaml dependency
 add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
@@ -280,7 +230,6 @@ add_custom_command(
     "${TORCH_SRC_DIR}/_C/__init__.pyi"
     "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
     "${TORCH_SRC_DIR}/nn/functional.pyi"
-    ${ModuleStubOut}
     COMMAND
     "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
     --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
@@ -289,7 +238,6 @@ add_custom_command(
     "${TORCH_SRC_DIR}/_C/__init__.pyi.in"
     "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
     "${TORCH_SRC_DIR}/nn/functional.pyi.in"
-    ${ModuleStubIn}
     "${TOOLS_PATH}/pyi/gen_pyi.py"
     WORKING_DIRECTORY
     "${TORCH_ROOT}"

torch/jit/__init__.py

Lines changed: 15 additions & 2 deletions
@@ -28,11 +28,22 @@
 import warnings
 import weakref
 
+from typing import TypeVar, Generic
+
+# See https://github.com/python/typing/issues/449
+try:
+    # Python 3.6 and earlier
+    from typing import GenericMeta
+except ImportError:
+    GenericMeta = type
+
 
 # These are imported so users can access them from the `torch.jit` module
 from torch._jit_internal import Final, _overload, _overload_method
 from torch._jit_internal import ignore, export, unused
 
+T_co = TypeVar('T_co', covariant=True)
+
 def _parse_env(name, default, true_message, false_message):
     value = os.environ.get(name)
     if value is None:
@@ -1485,8 +1496,10 @@ def __getitem__(self, k):
 # run. This has to occur after the user-defined __init__ so that submodules and
 # parameters are initialized _before_ the script compiler resolve references to
 # `self.param` or `self.module`.
-class ScriptMeta(type):
-    def __init__(cls, name, bases, attrs):
+# This inherits from GenericMeta because it's used to metaclass Module (which is
+# generic). This is not necessary in Python 3.7 and later.
+class ScriptMeta(GenericMeta):
+    def __init__(cls, name, bases, attrs):  # noqa: B902
         # Aggregate all the ScriptMethods and constants from superclasses
         cls._methods = {}
         cls._constants_set = set(getattr(cls, '__constants__', ()))
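To see why the GenericMeta shim is needed, here is a runnable sketch (hypothetical class names) of the metaclass conflict the diff works around:

from typing import TypeVar, Generic

# Same shim as above: on Python <= 3.6, typing.Generic carries the
# nontrivial metaclass typing.GenericMeta; on 3.7+ it is plain `type`.
try:
    from typing import GenericMeta  # Python 3.6 and earlier
except ImportError:
    GenericMeta = type

T_co = TypeVar('T_co', covariant=True)

class Meta(GenericMeta):
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)

# On Python 3.6, deriving Meta from plain `type` instead of GenericMeta
# would make this class statement raise "TypeError: metaclass conflict",
# because Generic[T_co] already brings GenericMeta as its metaclass.
class MyModule(Generic[T_co], metaclass=Meta):
    pass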
File renamed without changes.

torch/nn/modules/__init__.pyi.in

Lines changed: 0 additions & 49 deletions
This file was deleted.
