Fix #152280: add Literal[…] PaddingMode to Conv modules by AnandVishesh1301 · Pull Request #152590 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Fix #152280: add Literal[…] PaddingMode to Conv modules #152590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions torch/ao/nn/qat/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.conv import PaddingMode
from torch.nn.modules.utils import _pair, _single, _triple


Expand All @@ -26,7 +27,7 @@ def __init__(
output_padding: tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
padding_mode: PaddingMode,
qconfig=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -147,7 +148,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
qconfig=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -208,7 +209,7 @@ def __init__(
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
qconfig=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -272,7 +273,7 @@ def __init__(
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
qconfig=None,
device=None,
dtype=None,
Expand Down
3 changes: 2 additions & 1 deletion torch/ao/nn/quantized/dynamic/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch._ops import ops
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
from torch.nn.common_types import _size_1_t
from torch.nn.modules.conv import PaddingMode
from torch.nn.modules.utils import _pair, _single, _triple


Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
reduce_range=True,
Expand Down
3 changes: 2 additions & 1 deletion torch/ao/nn/quantized/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn.functional as F
from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.conv import PaddingMode
from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.utils import fuse_conv_bn_weights

Expand Down Expand Up @@ -401,7 +402,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
):
Expand Down
5 changes: 3 additions & 2 deletions torch/ao/nn/quantized/reference/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.common_types import _size_1_t
from torch.nn.modules.conv import PaddingMode

from .utils import ReferenceQuantizedModule

Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
weight_qparams: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -282,7 +283,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
weight_qparams: Optional[dict[str, Any]] = None,
Expand Down
33 changes: 18 additions & 15 deletions torch/nn/modules/conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
from typing import Literal, Optional, Union
from typing_extensions import deprecated

import torch
Expand All @@ -10,6 +10,9 @@
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.parameter import Parameter, UninitializedParameter


PaddingMode = Literal["zeros", "reflect", "replicate", "circular"]

from .lazy import LazyModuleMixin
from .module import Module
from .utils import _pair, _reverse_repeat_tuple, _single, _triple
Expand Down Expand Up @@ -79,7 +82,7 @@ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -
transposed: bool
output_padding: tuple[int, ...]
groups: int
padding_mode: str
padding_mode: PaddingMode
Copy link
Collaborator
@Skylion007 Skylion007 May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the abstract base classes can still have str and only the leaf classes need PaddingMode for backwards compatibility

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Skylion007 thank you for your feedback and suggestion regarding keeping str in the abstract/base classes for backward compatibility.

After further testing, I found that mypy and lintrunner checks fail if the type is str in the base class and PaddingMode in the leaf classes, due to Python’s type system limitations. To ensure that all type checks pass and the PR can be merged, I have used PaddingMode in both the abstract and concrete classes.

I agree that keeping str in the base class would be ideal for maximum backward compatibility, and perhaps in the future, the mypy configuration or type checking strategy could be updated to allow this pattern.

Please let me know if you’d like me to explore any alternative solutions, or if this approach works for now!

weight: Tensor
bias: Optional[Tensor]

Expand All @@ -95,7 +98,7 @@ def __init__(
output_padding: tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
padding_mode: PaddingMode,
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -328,7 +331,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -509,7 +512,7 @@ def __init__(
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -680,7 +683,7 @@ def __init__(
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -927,7 +930,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1117,7 +1120,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_2_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1310,7 +1313,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_3_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1503,7 +1506,7 @@ def __init__(
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1572,7 +1575,7 @@ def __init__(
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1642,7 +1645,7 @@ def __init__(
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1710,7 +1713,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1779,7 +1782,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: int = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down Expand Up @@ -1848,7 +1851,7 @@ def __init__(
groups: int = 1,
bias: bool = True,
dilation: _size_3_t = 1,
padding_mode: str = "zeros",
padding_mode: PaddingMode = "zeros",
device=None,
dtype=None,
) -> None:
Expand Down