Behavior of kernel_size parameter of torch.nn.functional.avg_pool2d does not match with documentation · Issue #153149 · pytorch/pytorch · GitHub

Open
sahas3 opened this issue May 8, 2025 · 2 comments
Labels
actionable, module: docs, module: nn, module: pooling, oncall: export, oncall: pt2, triaged

Comments

sahas3 commented May 8, 2025

📚 The doc issue

The documentation states that the kernel_size parameter can be a single number or a tuple of two integers (kH, kW). However, a single-element tuple is also accepted, and in my testing it behaves the same as passing a single integer: the same value is used for both the height and width dimensions.

>>> import torch
>>> input = torch.rand([16, 528, 16, 16], dtype=torch.float32)
>>> o = torch.nn.functional.avg_pool2d(input, 2)
>>> print(o.shape)
torch.Size([16, 528, 8, 8])
>>> o = torch.nn.functional.avg_pool2d(input, (2,))
>>> print(o.shape)
torch.Size([16, 528, 8, 8])
>>> o = torch.nn.functional.avg_pool2d(input, (2,2,2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints
>>> o = torch.nn.functional.avg_pool2d(input, (2))
>>> print(o.shape)
torch.Size([16, 528, 8, 8])
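
The output shapes in this session follow the pooling shape formula from the AvgPool2d documentation, with padding=0, ceil_mode=False, and stride defaulting to kernel_size. A minimal sketch (the helper pool2d_out_size is hypothetical, not a PyTorch API) that reproduces 16 -> 8 for a kernel of 2, and also shows that the last call, (2), is a plain int rather than a tuple:

```python
from typing import Optional


def pool2d_out_size(in_size: int, kernel: int, stride: Optional[int] = None,
                    padding: int = 0) -> int:
    """Output length of one spatial dim for avg_pool2d with ceil_mode=False.

    Hypothetical helper applying the AvgPool2d docs formula:
    floor((in_size + 2 * padding - kernel) / stride) + 1.
    """
    if stride is None:
        stride = kernel  # stride defaults to kernel_size when omitted
    return (in_size + 2 * padding - kernel) // stride + 1


# kernel_size=2 and kernel_size=(2,) both give kH = kW = 2, so 16x16 -> 8x8.
print(pool2d_out_size(16, 2), pool2d_out_size(16, 2))  # 8 8

# (2) is only a parenthesized int; a one-element tuple needs the trailing comma.
print(type((2)).__name__, type((2,)).__name__)  # int tuple
```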

I am looking for clarity on whether this is intended behavior, and whether downstream consumers of PyTorch models should support it or error out cleanly. This undocumented behavior triggers an assertion failure in torch-mlir, which is used to represent PyTorch ExportedProgram models in MLIR. For more details, please see llvm/torch-mlir#3885.
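
For a downstream consumer such as torch-mlir, one option is to normalize the list arguments recorded for aten.avg_pool2d before lowering, and raise a clear error for anything else. The sketch below assumes the rules implied by this report and the eager error message above: one element applies to both H and W, and two elements are (H, W). The empty-stride fallback to kernel_size and the padding rule are my assumptions about the eager checks, not something confirmed here, and normalize_avg_pool2d_args is a hypothetical name:

```python
from typing import List, Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]
Pair = Tuple[int, int]


def _as_list(value: IntOrSeq) -> List[int]:
    # Treat a bare int like a one-element list.
    return [value] if isinstance(value, int) else list(value)


def _expand(name: str, value: IntOrSeq) -> Pair:
    # 1 element -> same value for H and W; 2 elements -> (H, W); else error.
    items = _as_list(value)
    if len(items) not in (1, 2):
        raise ValueError(
            f"avg_pool2d: {name} must have 1 or 2 elements, got {len(items)}")
    return items[0], items[-1]


def normalize_avg_pool2d_args(kernel_size: IntOrSeq, stride: IntOrSeq = (),
                              padding: IntOrSeq = 0) -> Tuple[Pair, Pair, Pair]:
    """Hypothetical consumer-side helper: expand to explicit (H, W) pairs."""
    kernel = _expand("kernel_size", kernel_size)
    # Assumption: an empty stride means "use kernel_size", as in eager mode.
    strides = kernel if len(_as_list(stride)) == 0 else _expand("stride", stride)
    pads = _expand("padding", padding)
    return kernel, strides, pads


# The graph for the single-element-tuple module records ([2], [2]).
print(normalize_avg_pool2d_args([2], [2]))  # ((2, 2), (2, 2), (0, 0))
```

Rejecting anything other than one or two elements up front gives the "error out cleanly" behavior instead of an assertion deep in the lowering.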

Versions

Collecting environment information...
PyTorch version: 2.5.1+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14) 12.2.0
Clang version: 14.0.6
CMake version: version 4.0.0
Libc version: glibc-2.36

Python version: 3.11.2 (main, Nov 30 2024, 21:22:50) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-6.1.0-33-amd64-x86_64-with-glibc2.36
Is CUDA available: False
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: Quadro P620
GPU 1: NVIDIA TITAN V

Nvidia driver version: 550.54.14
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 4
CPU(s) scaling MHz: 42%
CPU max MHz: 3900.0000
CPU min MHz: 1200.0000
BogoMIPS: 7200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 6 MiB (6 instances)
L3 cache: 8.3 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable

Versions of relevant libraries:
[pip3] numpy==2.0.0
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.21.1
[pip3] onnxscript==0.2.5
[pip3] torch==2.5.1+cpu
[pip3] torchvision==0.20.1+cpu
[conda] Could not collect

Suggest a potential alternative/fix

If supporting a single-element tuple is intentional, it would be good to update the documentation to match that behavior.

Thanks!

cc @svekars @sekyondaMeta @AlannaBurke @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

jbschlosser added the module: docs, module: nn, triaged, actionable, and module: pooling labels and removed the actionable label May 8, 2025
albanD (Collaborator) commented May 9, 2025

Hey!
Thanks for the report. It is indeed just a documentation issue: a single-element tuple is treated the same as a single int.

sahas3 (Author) commented May 14, 2025

Hi @albanD, thanks for the confirmation.

Is it possible to update the behavior in ExportedProgram?

The single int case

import torch


class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=2,
        )

    def forward(self, x):
        return self.ap2d(x)

module = AvgPool2dFloatModule()
input = torch.rand([16, 528, 16, 16])
prog = torch.export.export(module, (input, ))
print(prog)

produces

class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[16, 528, 16, 16]"):
             # File: /mathworks/devel/sandbox/sayans/geckWorks/torchBugs/pool.py:47 in forward, code: return self.ap2d(x)
            avg_pool2d: "f32[16, 528, 8, 8]" = torch.ops.aten.avg_pool2d.default(x, [2, 2], [2, 2]);  x = None
            return (avg_pool2d,)

but the single-element tuple version

import torch


class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=(2,)
        )

    def forward(self, x):
        return self.ap2d(x)

module = AvgPool2dFloatModule()
input = torch.rand([16, 528, 16, 16])
prog = torch.export.export(module, (input, ))
print(prog)

produces

class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[16, 528, 16, 16]"):
             # File: /mathworks/devel/sandbox/sayans/geckWorks/torchBugs/pool.py:47 in forward, code: return self.ap2d(x)
            avg_pool2d: "f32[16, 528, 8, 8]" = torch.ops.aten.avg_pool2d.default(x, [2], [2]);  x = None
            return (avg_pool2d,)

Note that the kernel_size value is not repeated: the graph records [2] instead of [2, 2], for both kernel_size and stride.
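
One plausible explanation for the difference (my assumption; it is not confirmed in this thread): the aten schema types kernel_size and stride as int[2], so a bare int is broadcast to two copies when the call is recorded, whereas a tuple is recorded exactly as given. Because nn.AvgPool2d uses kernel_size as the stride when stride is not set, both arguments end up as [2]. The hypothetical record_int_list below models only that coercion rule, not PyTorch's actual argument parser:

```python
from typing import List, Sequence, Union


def record_int_list(value: Union[int, Sequence[int]], n: int) -> List[int]:
    """Hypothetical model of how an int[n] argument shows up in the graph."""
    if isinstance(value, int):
        return [value] * n  # a bare int is broadcast to n copies
    return list(value)      # a sequence is kept as-is, even with one element


for kernel_size in (2, (2,)):
    stride = kernel_size  # nn.AvgPool2d's default when stride is None
    print(record_int_list(kernel_size, 2), record_int_list(stride, 2))
# [2, 2] [2, 2]
# [2] [2]
```

If this model is right, the eager kernel still expands the one-element list itself, which is why eager results match while the exported graph exposes the unexpanded [2] to downstream tools.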

I'd be happy to contribute a fix if you can point me to where to look in the PyTorch codebase; I am not familiar with it.

Thanks!
