`torch.batch_norm` shows inconsistent error behavior between CPU and GPU · Issue #153137 · pytorch/pytorch · GitHub
Closed
@SilentTester73

🐛 Describe the bug

Description

When torch.batch_norm is called with one of running_mean or running_var set to a tensor and the other set to None, CUDA-enabled GPUs raise a RuntimeError from an internal assertion: "Expected has_running_mean == has_running_var to be true, but got false". The same call on the CPU completes without raising any error.

Ideally, the behavior should be consistent across devices: either both the CPU and GPU paths raise this error, or neither does.
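Until the backends agree, a caller can enforce the invariant itself. The following is a minimal sketch, assuming a hypothetical wrapper named batch_norm_checked (it is not part of PyTorch). It raises ValueError whenever exactly one running statistic is None, so user code fails the same way on CPU and GPU before any backend-specific check runs.

```python
from typing import Optional

import torch


def batch_norm_checked(
    input: torch.Tensor,
    weight: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    training: bool,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enabled: bool = True,
) -> torch.Tensor:
    """Hypothetical wrapper around torch.batch_norm that rejects a mismatched
    running_mean/running_var pair on every device, before any backend runs."""
    # Same invariant the CUDA path asserts: both running stats are tensors,
    # or both are None.
    if (running_mean is None) != (running_var is None):
        given = "running_mean" if running_var is None else "running_var"
        raise ValueError(
            f"only {given} was provided; pass both running_mean and "
            "running_var, or neither"
        )
    return torch.batch_norm(
        input, weight, bias, running_mean, running_var,
        training, momentum, eps, cudnn_enabled,
    )


# Usage: Scenario 1 from the report now fails the same way on CPU and GPU.
x = torch.randn(3, 4, 5)
try:
    batch_norm_checked(x, None, None, torch.randn(4), None, training=True)
except ValueError as e:
    print(e)
```

Raising ValueError in Python keeps the check independent of which kernel (native CPU, native CUDA, or cuDNN) would otherwise handle the call.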

To Reproduce:

The following code demonstrates the issue. It tests two scenarios:

  1. running_mean is a Tensor, running_var is None.
  2. running_mean is None, running_var is a Tensor.
import torch

print(f"PyTorch Version: {torch.__version__}")

# Common parameters for torch.batch_norm
weight_param = None
bias_param = None
is_training_param = True # Error occurs with True or False
momentum_param = 0.1
eps_param = 1e-5
cudnn_enabled_param = True # Also occurs with False on GPU

# --- Scenario 1: running_mean is Tensor, running_var is None ---
print("\n--- Scenario 1: running_mean is Tensor, running_var is None ---")
# Input tensor
input_tensor_shape = (3, 4, 5) # N, C, D*
num_features = input_tensor_shape[1]

# CPU
print("  CPU (Scenario 1):")
try:
    input_tensor_cpu = torch.randn(input_tensor_shape)
    running_mean_param_cpu = torch.randn(num_features)
    running_var_param_cpu = None
    
    torch.batch_norm(
        input_tensor_cpu,
        weight_param,
        bias_param,
        running_mean_param_cpu,
        running_var_param_cpu,
        is_training_param,
        momentum_param,
        eps_param,
        cudnn_enabled_param
    )
    print("    CPU: Error not triggered.")
except RuntimeError as e:
    print(f"    CPU Error: {e}")
    if "Expected has_running_mean == has_running_var to be true, but got false" in str(e):
        print("    CPU: Successfully triggered the target error (unexpected based on current behavior).")

# GPU
if torch.cuda.is_available():
    print("  GPU (Scenario 1):")
    try:
        input_tensor_gpu = torch.randn(input_tensor_shape).cuda()
        running_mean_param_gpu = torch.randn(num_features).cuda()
        running_var_param_gpu = None
        
        torch.batch_norm(
            input_tensor_gpu,
            weight_param,
            bias_param,
            running_mean_param_gpu,
            running_var_param_gpu,
            is_training_param,
            momentum_param,
            eps_param,
            cudnn_enabled_param
        )
        print("    GPU: Error not triggered (unexpected for this specific error message).")
    except RuntimeError as e:
        print(f"    GPU Error: {e}")
        if "Expected has_running_mean == has_running_var to be true, but got false" in str(e):
            print("    GPU: Successfully triggered the target error.")
else:
    print("  GPU (Scenario 1): CUDA not available, skipping GPU test.")

# --- Scenario 2: running_mean is None, running_var is Tensor ---
print("\n--- Scenario 2: running_mean is None, running_var is Tensor ---")

# CPU
print("  CPU (Scenario 2):")
try:
    input_tensor_cpu = torch.randn(input_tensor_shape)
    running_mean_param_cpu = None
    running_var_param_cpu = torch.randn(num_features)
    
    torch.batch_norm(
        input_tensor_cpu,
        weight_param,
        bias_param,
        running_mean_param_cpu,
        running_var_param_cpu,
        is_training_param,
        momentum_param,
        eps_param,
        cudnn_enabled_param
    )
    print("    CPU: Error not triggered.")
except RuntimeError as e:
    print(f"    CPU Error: {e}")
    if "Expected has_running_mean == has_running_var to be true, but got false" in str(e):
        print("    CPU: Successfully triggered the target error (unexpected based on current behavior).")

# GPU
if torch.cuda.is_available():
    print("  GPU (Scenario 2):")
    try:
        input_tensor_gpu = torch.randn(input_tensor_shape).cuda()
        running_mean_param_gpu = None
        running_var_param_gpu = torch.randn(num_features).cuda()
        
        torch.batch_norm(
            input_tensor_gpu,
            weight_param,
            bias_param,
            running_mean_param_gpu,
            running_var_param_gpu,
            is_training_param,
            momentum_param,
            eps_param,
            cudnn_enabled_param
        )
        print("    GPU: Error not triggered (unexpected for this specific error message).")
    except RuntimeError as e:
        print(f"    GPU Error: {e}")
        if "Expected has_running_mean == has_running_var to be true, but got false" in str(e):
            print("    GPU: Successfully triggered the target error.")
else:
    print("  GPU (Scenario 2): CUDA not available, skipping GPU test.")
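The two scenarios above can be condensed into a single loop over devices and argument combinations. This sketch keeps the script's parameters (training=True, momentum 0.1, eps 1e-5, cudnn_enabled=True, input shape (3, 4, 5)). The helper name run_case is hypothetical, and it catches any exception rather than only RuntimeError, so a different exception type on a given device is still recorded.

```python
import torch

NUM_FEATURES = 4
INPUT_SHAPE = (3, NUM_FEATURES, 5)  # N, C, D*, as in the script above


def run_case(device: str, mean_given: bool, var_given: bool):
    """Call torch.batch_norm once with the given running-stat combination.

    Returns None if the call succeeds, otherwise "<ExceptionType>: <message>".
    """
    x = torch.randn(INPUT_SHAPE, device=device)
    running_mean = torch.randn(NUM_FEATURES, device=device) if mean_given else None
    running_var = torch.randn(NUM_FEATURES, device=device) if var_given else None
    try:
        torch.batch_norm(
            x,
            None,          # weight
            None,          # bias
            running_mean,
            running_var,
            True,          # training
            0.1,           # momentum
            1e-5,          # eps
            True,          # cudnn_enabled
        )
    except Exception as e:  # broader than RuntimeError so any failure is recorded
        return f"{type(e).__name__}: {e}"
    return None


devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
scenarios = {
    "Scenario 1 (mean=Tensor, var=None)": (True, False),
    "Scenario 2 (mean=None, var=Tensor)": (False, True),
}

for name, (mean_given, var_given) in scenarios.items():
    for device in devices:
        outcome = run_case(device, mean_given, var_given)
        print(f"{name} [{device}]: {outcome or 'no error'}")
```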

Expected behavior:

The behavior should match across devices: the error "RuntimeError: Expected has_running_mean == has_running_var to be true, but got false." should be raised on both CPU and GPU, or on neither.
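The expected behavior can be phrased as a device-agreement check: for a given running_mean/running_var combination, every available device should either raise or not raise. This is a minimal sketch with hypothetical helper names raises_on and outcome_is_consistent. It checks only that devices agree, not which outcome is correct, and on a machine without CUDA it compares a single device and therefore passes trivially.

```python
import torch


def raises_on(device: str, mean_given: bool, var_given: bool) -> bool:
    """Return True if torch.batch_norm raises for this running-stat combination."""
    c = 4
    x = torch.randn(3, c, 5, device=device)
    running_mean = torch.zeros(c, device=device) if mean_given else None
    running_var = torch.ones(c, device=device) if var_given else None
    try:
        # Same positional arguments as the reproduction script:
        # weight=None, bias=None, training=True, momentum=0.1, eps=1e-5, cudnn_enabled=True
        torch.batch_norm(x, None, None, running_mean, running_var,
                         True, 0.1, 1e-5, True)
    except Exception:  # any exception counts as "raised"
        return True
    return False


def outcome_is_consistent(mean_given: bool, var_given: bool) -> bool:
    """True if every available device agrees on whether the call raises."""
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    outcomes = {raises_on(d, mean_given, var_given) for d in devices}
    return len(outcomes) == 1


if __name__ == "__main__":
    # The two mismatched cases from the report; fails if CPU and GPU disagree.
    for case in [(True, False), (False, True)]:
        assert outcome_is_consistent(*case), f"CPU/GPU disagree for {case}"
```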

Actual behavior:

PyTorch Version: 2.6.0+cu124

--- Scenario 1: running_mean is Tensor, running_var is None ---
CPU (Scenario 1):
CPU: Error not triggered.
GPU (Scenario 1):
GPU Error: Expected has_running_mean == has_running_var to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
GPU: Successfully triggered the target error.

--- Scenario 2: running_mean is None, running_var is Tensor ---
CPU (Scenario 2):
CPU: Error not triggered.
GPU (Scenario 2):
GPU Error: Expected has_running_mean == has_running_var to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
GPU: Successfully triggered the target error.

The full code used for testing can be found at:
https://colab.research.google.com/drive/17xWrbcKvMTpDTHcz_XSrnb5esLCUh970?usp=sharing

Versions

Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.8 (++20240731025043+3b5b5c1ec4a3-1~exp1~20240731145144.92)
CMake version: version 4.0.0
Libc version: glibc-2.39

Python version: 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 570.133.20
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        39 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               32
On-line CPU(s) list:                  0-31
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i9-13900F
CPU family:                           6
Model:                                183
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            1
Stepping:                             1
CPU(s) scaling MHz:                   49%
CPU max MHz:                          5600.0000
CPU min MHz:                          800.0000
BogoMIPS:                             3993.60
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            896 KiB (24 instances)
L1i cache:                            1.3 MiB (24 instances)
L2 cache:                             32 MiB (12 instances)
L3 cache:                             36 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-31
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] optree==0.15.0
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] Could not collect
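The report above appears to be the output of PyTorch's environment-collection script (python -m torch.utils.collect_env), which prints "Collecting environment information..." before the report. As a sketch, the same report can be produced from Python via get_pretty_env_info in that module, printed next to the version line the reproduction script emits.

```python
# Print an environment report in the same format as the "Versions" section,
# using the collect_env module that ships with PyTorch.
import torch
from torch.utils.collect_env import get_pretty_env_info

print(f"PyTorch Version: {torch.__version__}")  # same line the repro script prints
report = get_pretty_env_info()
print(report)
```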

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @malfet

Metadata

Assignees: No one assigned
Labels:
    module: edge cases (Adversarial inputs unlikely to occur in practice)
    module: error checking (Bugs related to incorrect/lacking error checking)
    module: nn (Related to torch.nn)
    module: norms and normalization
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Type: No type
Projects: Status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
