[inductor] [cpu] nn.Tanhshrink-atan2 output inconsistent results with eager #148241

Closed
shaoyuyoung opened this issue Mar 1, 2025 · 0 comments
Assignees: leslie-fang-intel
Labels: oncall: cpu inductor (CPU Inductor issues for Intel team to triage), oncall: pt2

Comments

@shaoyuyoung (Contributor) commented Mar 1, 2025

🐛 Describe the bug

Symptom: when `nn.Tanhshrink` and `atan2` are used together, the compiled output is inconsistent with eager.
Device backend: CPP only (the Triton backend matches eager).
Repro:

import torch
import torch.nn as nn
from torch._inductor import config

config.fallback_random = True  # fall back to eager RNG inside compiled code (not strictly needed for this repro)
torch.set_grad_enabled(False)
torch.manual_seed(0)



class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.shrink = nn.Tanhshrink()

    def forward(self, x):
        x = self.shrink(x)
        x = torch.atan2(x, x)
        return x


model = Model()


x = torch.randn(1, 3, 64, 64)

inputs = [x]



def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output


output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')

print(torch.allclose(output, c_output, 1e-3, 1e-3, equal_nan=True))
print(torch.max(torch.abs(output - c_output)))

Error logs

CPP

False
tensor(3.1416)

Triton

True
tensor(0., device='cuda:0')

Versions

nightly 20250225

cc @chauhang @penguinwu

@leslie-fang-intel leslie-fang-intel self-assigned this Mar 1, 2025
@leslie-fang-intel added the oncall: cpu inductor label Mar 1, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this issue Mar 4, 2025
**Summary**
Fixes pytorch#148241. The previous vectorized code generation for `tanh` used a decomposed implementation, leading to numerical differences that were further amplified by `atan2`. For example, in the given test case the eager output after `tanh` at `[0, 0, 11, 47]` was `-5.820766091346741e-10`, while the compiled output was `1.4319084584712982e-08`, resulting in different `atan2` outputs of `-2.3561` and `0.7853`. This is fixed by switching to the Sleef implementation.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_tanh_atan2
```

Pull Request resolved: pytorch#148254
Approved by: https://github.com/malfet, https://github.com/jgong5
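
For context, here is a minimal standalone sketch (not part of the original issue or the fix) showing why a tiny sign difference after the tanh step gets amplified by `atan2`. It reuses the two intermediate values quoted in the commit summary above:

```python
import torch

# Intermediate values at index [0, 0, 11, 47] quoted in the commit summary:
eager_val = torch.tensor(-5.820766091346741e-10)     # eager result (tiny negative)
compiled_val = torch.tensor(1.4319084584712982e-08)  # compiled result (tiny positive)

# atan2(x, x) depends only on the sign of x: +pi/4 for x > 0, -3*pi/4 for x < 0.
# A near-zero value whose sign flips therefore produces a ~pi-sized output difference,
# which matches the max abs error of 3.1416 reported in the error logs.
print(torch.atan2(eager_val, eager_val))        # tensor(-2.3562)
print(torch.atan2(compiled_val, compiled_val))  # tensor(0.7854)
```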