[JIT] Compilation-induced discrepancy in F.instance_norm when passing input as running stats #153315


Open
zhouxiaoyaozzz opened this issue May 10, 2025 · 4 comments
Labels
module: correctness (silent) - issue that returns an incorrect result silently
oncall: jit - Add this issue/PR to JIT oncall triage queue

zhouxiaoyaozzz commented May 10, 2025

🐛 Bug Description

When a model calls F.instance_norm on a broadcast-multiplied input and passes the input vector x as both running_mean and running_var, the output of the torch.jit.script version differs from the eager-mode output.
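For reference, the positional call `F.instance_norm(multiplied, x, x)` used in the reproduction binds `x` to `running_mean` and `running_var`, not to `weight` and `bias`. A minimal sketch that checks this binding with the standard-library `inspect` module, without calling the function (the tensors are hypothetical placeholders, not the values from this report):

```python
import inspect

import torch
import torch.nn.functional as F

# Hypothetical placeholder tensors; only the argument binding matters here.
x = torch.ones(3, dtype=torch.float64)
multiplied = torch.ones(1, 3, 3, 3, dtype=torch.float64)

# Bind the same positional arguments the reproduction uses.
bound = inspect.signature(F.instance_norm).bind(multiplied, x, x)
bound.apply_defaults()

# Print each parameter name with the shape (for tensors) or value bound to it.
for name, value in bound.arguments.items():
    shown = tuple(value.shape) if isinstance(value, torch.Tensor) else value
    print(f"{name}: {shown}")
```

With the defaults applied, `use_input_stats` is `True`, so the normalization itself uses statistics computed from `multiplied`, while both running-stat arguments refer to the same tensor `x`.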

🔍 Minimal Reproduction Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlist = nn.ParameterList([
            nn.Parameter(torch.tensor([[
                [[5.0269, 3.4145, 3.8807],
                 [6.3197, 6.5815, 5.3826],
                 [3.4846, 4.3035, 6.9038]],

                [[3.7721, 3.0422, 4.3315],
                 [3.2518, 3.5617, 6.4604],
                 [3.5747, 5.1481, 6.6348]],

                [[3.0173, 3.5291, 6.8552],
                 [5.6125, 6.2321, 6.3142],
                 [4.7299, 4.2638, 5.2731]]
            ]], dtype=torch.float64))
        ])

    def forward(self, x):
        expanded = x.unsqueeze(-1).unsqueeze(-1)  # (3,) -> (3,1,1)
        multiplied = x * self.mlist[0]  # Broadcast multiply
        inst_norm = F.instance_norm(multiplied, x, x)
        log_softmax = F.log_softmax(multiplied, dim=-1)
        bilinear = F.interpolate(log_softmax, scale_factor=1.0, mode='bilinear')  # fix: 1.0 rather than 1

        return {
            'v0_0': expanded,
            'v6_0': inst_norm,
            'v2_0': bilinear
        }

input_data = torch.tensor([494.91649119, 528.01665228, 492.01463052], dtype=torch.float64)
model = NeuralModel()
with torch.no_grad():
    output_eager = model(input_data)

model_scripted = torch.jit.script(model)
with torch.no_grad():
    output_scripted = model_scripted(input_data)

print("NONJIT (v6_0):", output_eager['v6_0'])
print("JIT  (v6_0):", output_scripted['v6_0'])
print("consistency:", torch.allclose(output_eager['v6_0'], output_scripted['v6_0']))

Output:

NONJIT (v6_0): tensor([[[[-0.0792, -1.1555, -0.9882],
          [ 0.9260,  1.4719,  0.1728],
          [-1.2785, -0.4180,  1.3487]],

         [[-0.5777, -0.9980, -0.1514],
          [-0.9931, -0.5555,  1.5384],
          [-0.7353,  0.7958,  1.6768]],

         [[-1.7688, -1.1584,  1.3313],
          [ 0.3497,  1.1958,  0.8923],
          [-0.3708, -0.5185,  0.0474]]]], dtype=torch.float64)
JIT  (v6_0): tensor([[[[ 0.1992, -1.2898, -1.0547],
          [ 1.2781,  1.2202,  0.0836],
          [-1.0879, -0.5852,  1.2365]],

         [[-0.3739, -1.1909, -0.2243],
          [-0.8553, -0.7344,  1.5646],
          [-0.5565,  0.6596,  1.7112]],

         [[-1.6616, -1.3545,  1.2870],
          [ 0.7235,  1.0046,  0.8354],
          [-0.0877, -0.7133, -0.0335]]]], dtype=torch.float64)
consistency: False
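A possible check before attributing the mismatch to the JIT, offered as an assumption rather than a confirmed diagnosis: `F.instance_norm` receives `x` as both `running_mean` and `running_var`, and running statistics are normally updated in place when input statistics are used. If that happens here, the eager call would modify `input_data` before the scripted call sees it. The sketch below uses a hypothetical helper `forward_like_repro` and a random stand-in for the parameter, and compares the two paths on a shared input and on fresh clones:

```python
import torch
import torch.nn.functional as F


def forward_like_repro(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same two steps that produce v6_0 in the report: broadcast multiply,
    # then instance_norm with x passed as running_mean and running_var.
    multiplied = x * weight
    return F.instance_norm(multiplied, x, x)


# Hypothetical stand-in for the (1, 3, 3, 3) float64 parameter in the report.
torch.manual_seed(0)
weight = torch.rand(1, 3, 3, 3, dtype=torch.float64) * 4.0 + 3.0
original = torch.tensor(
    [494.91649119, 528.01665228, 492.01463052], dtype=torch.float64
)

scripted = torch.jit.script(forward_like_repro)

with torch.no_grad():
    # Shared input, as in the report: the second call sees whatever
    # the first call left in the tensor.
    shared = original.clone()
    eager_shared = forward_like_repro(shared, weight)
    input_changed = not torch.equal(shared, original)
    scripted_shared = scripted(shared, weight)

    # Fresh clone per call: both paths start from identical input.
    eager_fresh = forward_like_repro(original.clone(), weight)
    scripted_fresh = scripted(original.clone(), weight)

print("input changed by first call:", input_changed)
print("shared input match:", torch.allclose(eager_shared, scripted_shared))
print("fresh clones match:", torch.allclose(eager_fresh, scripted_fresh))
```

If the fresh-clone comparison matches while the shared-input comparison does not, the difference would come from the input being modified between calls rather than from compilation. Running the eager model twice on the same `input_data` would be another way to test this.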

Versions

Collecting environment information...
PyTorch version: 2.0.1+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Could not collect
GCC version: Could not collect
Clang version: 20.1.2
CMake version: version 4.0.0
Libc version: N/A

Python version: 3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] torch==2.0.1
[conda] Could not collect

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @chauhang @penguinwu

@bigachin

Can I take up this issue?

@zhouxiaoyaozzz
Author

> Can I take up this issue?

Thank you for following this issue! Welcome to the discussion!

zhouxiaoyaozzz changed the title from "JIT (ScriptMode) Inconsistency with F.instance_norm Due to Incorrect Parameter Usage" to "[JIT] Compilation-induced discrepancy in F.instance_norm when passing input as running stats" on May 10, 2025
@bigachin

Can I be assigned to this issue?

@zhouxiaoyaozzz
Author
zhouxiaoyaozzz commented May 10, 2025

> Can I be assigned to this issue?

@bigachin, I am an ordinary user and cannot assign issues directly. I have pinged @pytorchbot on your behalf.

williamwen42 added the "oncall: jit" and "module: correctness (silent)" labels and removed the "oncall: pt2" label on May 13, 2025