@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
 
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@
 
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        # All subclasses have this signature - see PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
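This hunk is the heart of the change: `_ConvNd.__init__` becomes a stub that only documents the signature shared by all subclasses, while the real initialization (with the extra `transposed`/`output_padding` parameters) moves to `_init`, which subclasses invoke explicitly. A minimal standalone sketch of the same pattern, with toy names rather than the actual PyTorch code:

```python
class _Base:
    def __init__(self, x, y=0):
        # Documents the signature all subclasses share; never called directly.
        raise NotImplementedError

    def _init(self, x, y, extra):
        # The real initialization, invoked explicitly by subclasses.
        self.x, self.y, self.extra = x, y, extra


class Child(_Base):
    def __init__(self, x, y=0):
        # Bypass _Base.__init__ and call the real initializer instead.
        super()._init(x, y, extra='child-specific')


c = Child(1)
print(c.x, c.y, c.extra)  # 1 0 child-specific
```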
@@ -54,6 +60,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
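These three stubs spell out the interface every concrete quantized conv module fills in; the real subclasses keep the weight in a backend-specific packed parameter. As a toy illustration only, a subclass could satisfy the interface with plain attributes standing in for the packed representation:

```python
class _ToyQConv(_ConvNd):
    """Illustration only: real subclasses store a backend-packed parameter."""

    def set_weight_bias(self, qweight, bias_float):
        self._qweight, self._bias = qweight, bias_float

    def bias(self):
        return self._bias

    def _weight_bias(self):
        return self._qweight, self._bias
```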
@@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
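`get_qconv` is what ultimately runs when a float conv is converted; in user code this path is normally exercised through the eager-mode quantization workflow rather than by calling `from_float` directly. Roughly, assuming the eager-mode API of this era (`default_qconfig`, `prepare`, `convert`):

```python
import torch
import torch.nn as nn
from torch.quantization import default_qconfig, prepare, convert

model = nn.Sequential(nn.Conv2d(3, 16, 3)).eval()
model.qconfig = default_qconfig
prepare(model, inplace=True)        # attach observers
model(torch.randn(1, 3, 32, 32))    # calibrate observers on sample data
convert(model, inplace=True)        # swaps in the quantized module via from_float
print(type(model[0]))               # now a quantized Conv2d
```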
@@ -233,7 +249,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
@@ -319,7 +337,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
@@ -403,7 +423,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
@@ -450,15 +472,20 @@ def from_float(cls, mod):
         return cls.get_qconv(mod, activation_post_process)
 
 # === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
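The new `MOD` TypeVar bounds `_FLOAT_MODULE` to float conv modules at the base-class level; each concrete subclass then overrides the attribute with the specific float module it quantizes, which is exactly what the `from_float` assertion in the next hunk checks. Schematically, simplified from the real subclass definitions:

```python
class ConvTranspose2d(_ConvTransposeNd):
    # Overrides the MOD placeholder with the concrete float module.
    _FLOAT_MODULE = nn.ConvTranspose2d
```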
@@ -477,9 +504,10 @@ def from_float(cls, mod):
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)