add type annotations to torch.nn.quantized.modules.conv (#49702) · pytorch/pytorch@55919a4 · GitHub

Commit 55919a4

add type annotations to torch.nn.quantized.modules.conv (#49702)
Summary: closes gh-49700

No mypy issues were found in the first three entries deleted from `mypy.ini`:

```
[mypy-torch.nn.qat.modules.activations]
ignore_errors = True

[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True
```

Pull Request resolved: #49702
Reviewed By: walterddr, zou3519
Differential Revision: D25767119
Pulled By: ezyang
fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a
1 parent 54ce171 commit 55919a4

File tree

2 files changed: +45 -25 lines


mypy.ini

Lines changed: 1 addition & 10 deletions
@@ -91,16 +91,7 @@ ignore_errors = True
 [mypy-torch.nn.modules.pooling]
 ignore_errors = True
 
-[mypy-torch.nn.qat.modules.activations]
-ignore_errors = True
-
-[mypy-torch.nn.qat.modules.conv]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.dynamic.modules.linear]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.modules.conv]
+[mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
 [mypy-torch._appdirs]
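With these `ignore_errors` sections gone, mypy now type-checks the listed modules. A minimal verification sketch, assuming mypy is installed and the script is run from a PyTorch checkout; the module list and flags below are illustrative and not part of this commit:

```python
# Re-check the modules whose mypy.ini ignores were removed.
# api.run returns (stdout, stderr, exit_status); exit_status 0 means no errors.
from mypy import api

modules = [
    "torch.nn.qat.modules.activations",
    "torch.nn.qat.modules.conv",
    "torch.nn.quantized.dynamic.modules.linear",
    "torch.nn.quantized.modules.conv",
]

for name in modules:
    stdout, stderr, status = api.run(["--config-file", "mypy.ini", "-m", name])
    print(f"{name}: {'clean' if status == 0 else stdout.strip()}")
```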

torch/nn/quantized/modules/conv.py

Lines changed: 44 additions & 15 deletions
@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
 
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@
 
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        # All subclasses have this signature - See PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
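The hunk above is the core of the change: `__init__` now only documents the signature shared by the concrete subclasses and raises `NotImplementedError`, while the real setup (including the extra `transposed` and `output_padding` parameters) moves to `_init`. A minimal sketch of that pattern outside PyTorch; the class names here are illustrative, not the actual quantized modules:

```python
import torch.nn as nn

class _Base(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        # Every concrete subclass exposes this signature; none of them call
        # this __init__, so it exists mainly for the type checker.
        raise NotImplementedError

    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation, transposed, output_padding,
              groups, bias, padding_mode='zeros'):
        # The real initialisation, with the extra transposed/output_padding
        # parameters that the public signature does not expose.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

class Child(_Base):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        # Subclasses call _init rather than __init__, mirroring the diff above.
        super()._init(in_channels, out_channels, kernel_size, stride,
                      padding, dilation, False, 0, groups, bias, padding_mode)

m = Child(3, 8, kernel_size=3)
print(m.in_channels, m.out_channels)  # 3 8
```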
@@ -54,6 +60,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
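Declaring `set_weight_bias`, `bias`, and `_weight_bias` as `NotImplementedError` stubs gives `_ConvNd` the attributes that the classmethods further down invoke on `qconv`, which is typed as the base class. A small illustrative sketch of that idea; the names here are hypothetical:

```python
class _BaseConv:
    def set_weight_bias(self, qweight, bias_float):
        # Declared on the base so calls through a base-typed variable type-check.
        raise NotImplementedError

class PackedConv(_BaseConv):
    def set_weight_bias(self, qweight, bias_float):
        self._packed_params = (qweight, bias_float)

qconv: _BaseConv = PackedConv()
qconv.set_weight_bias("qweight placeholder", None)  # resolves to PackedConv at runtime
```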
@@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
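The `# type: ignore[call-arg]` is scoped to a single error code, so any other problem on that line would still be reported. A generic sketch of why a constructor call through a base-typed `cls` can need it; the classes below are illustrative, not PyTorch's:

```python
from typing import Type

class Base:
    def __init__(self, a: int) -> None:
        self.a = a

class Child(Base):
    def __init__(self, a: int, b: int) -> None:
        super().__init__(a)
        self.b = b

def build(cls: Type[Base]) -> Base:
    # mypy checks this call against Base.__init__, so the extra argument is
    # reported as call-arg even though the concrete cls accepts it.
    return cls(1, 2)  # type: ignore[call-arg]

print(build(Child).a)  # 1
```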
@@ -233,7 +249,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
@@ -319,7 +337,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
@@ -403,7 +423,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
 
@@ -450,15 +472,20 @@ def from_float(cls, mod):
         return cls.get_qconv(mod, activation_post_process)
 
 # === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
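`MOD` is a `TypeVar` bounded by `nn.modules.conv._ConvNd`, and assigning it to `_FLOAT_MODULE` on the base lets each concrete quantized class narrow the attribute to the specific float module it converts from (the next hunk notes that derived classes override `cls._FLOAT_MODULE`). A hypothetical illustration of that narrowing; every name other than the torch ones is made up:

```python
from typing import TypeVar
import torch.nn as nn

MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

class _QuantizedTransposeBase:
    # Placeholder at the base; concrete subclasses override it with a real class.
    _FLOAT_MODULE = MOD

class QConvTranspose2d(_QuantizedTransposeBase):
    _FLOAT_MODULE = nn.ConvTranspose2d

print(QConvTranspose2d._FLOAT_MODULE.__name__)  # ConvTranspose2d
```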
@@ -477,9 +504,10 @@ def from_float(cls, mod):
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
             mod.stride, mod.padding, mod.output_padding, mod.groups,
             mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
