Softmax Decomp Causes Incorrect Gradients when Using `torch.compile` with `F.multi_head_attention_forward` · Issue #152309 · pytorch/pytorch · GitHub




Closed
defaultd661 opened this issue Apr 28, 2025 · 15 comments
Labels: high priority, module: aotdispatch, module: correctness (silent), module: decompositions, module: pt2-dispatcher, module: sdpa, oncall: pt2, triaged, ubn ("unbreak now")

Comments

defaultd661 commented Apr 28, 2025

🐛 Describe the bug

When using torch.compile to compile a model that internally calls torch.nn.functional.multi_head_attention_forward, the computed gradients differ significantly from the ones obtained via eager mode.

To Reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F


class ReproMultihead(nn.Module):

    def __init__(self):
        super().__init__()
        self.embed_dim = 256
        self.num_heads = self.embed_dim // 64
        self.in_proj_weight = nn.Parameter(torch.empty(3 * self.embed_dim,
            self.embed_dim))
        self.in_proj_bias = nn.Parameter(torch.empty(3 * self.embed_dim))
        self.out_proj_weight = nn.Parameter(torch.empty(self.embed_dim,
            self.embed_dim))
        self.out_proj_bias = nn.Parameter(torch.empty(self.embed_dim))
        nn.init.constant_(self.in_proj_weight, 0.1)
        nn.init.constant_(self.in_proj_bias, 0.1)
        nn.init.constant_(self.out_proj_weight, 0.1)
        nn.init.constant_(self.out_proj_bias, 0.1)

    def forward(self, x):
        x_t = x.transpose(0, 1)
        attn_output, _ = F.multi_head_attention_forward(
            query=x_t, key=x_t, value=x_t,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            out_proj_weight=self.out_proj_weight,
            out_proj_bias=self.out_proj_bias,
            training=True,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            use_separate_proj_weight=False,
            q_proj_weight=None, k_proj_weight=None, v_proj_weight=None,
            static_k=None, static_v=None,
            average_attn_weights=True,
            is_causal=False)
        return attn_output.transpose(0, 1)

def test_bug():
    torch.set_default_device('cuda')
    torch.manual_seed(0)

    model = ReproMultihead().cuda()

    compiled_model = ReproMultihead().cuda()
    compiled_model = torch.compile(compiled_model)

    x = torch.randn((1, 512, 256), device='cuda', requires_grad=True)
    x_compiled = x.clone().detach().requires_grad_(True)

    out_eager = model(x)
    out_compiled = compiled_model(x_compiled)

    out_eager.sum().backward()
    out_compiled.sum().backward()

    weight_diff = torch.max(torch.abs(model.in_proj_weight.grad -
        compiled_model.in_proj_weight.grad)).item()
    print('weight_diff =', weight_diff)
    bias_diff = torch.max(torch.abs(model.in_proj_bias.grad -
        compiled_model.in_proj_bias.grad)).item()
    print('bias_diff =', bias_diff)

if __name__ == '__main__':
    test_bug()

Output

weight_diff = 0.130126953125
bias_diff = 0.12890625

Versions

PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @SherlockNoMad @bdhirsh

eellison (Contributor) commented

I couldn't repro this on an A100 or H100.

weight_diff = 0.0
bias_diff = 0.0009765625

Can you try again on main, and reopen if this issue still exists?

defaultd661 (Author) commented

On CPU, I get the correct results (weight_diff = 0.0 and bias_diff = 0.009765625); on CUDA, however, the results are still weight_diff = 0.130126953125 and bias_diff = 0.12890625.

Version

The results above were obtained with a torch nightly build; the environment information is as follows:

Collecting environment information...
PyTorch version: 2.8.0.dev20250428+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-135-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A6000
Nvidia driver version: 570.124.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               64
On-line CPU(s) list:                  0-63
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Gold 6444Y
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   16
Socket(s):                            2
Stepping:                             8
Frequency boost:                      enabled
CPU max MHz:                          3601.0000
CPU min MHz:                          800.0000
BogoMIPS:                             7200.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            1.5 MiB (32 instances)
L1i cache:                            1 MiB (32 instances)
L2 cache:                             64 MiB (32 instances)
L3 cache:                             90 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-15,32-47
NUMA node1 CPU(s):                    16-31,48-63
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.8.0.87
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250428+cu128
[pip3] torchaudio==2.6.0.dev20250428+cu128
[pip3] torchvision==0.22.0.dev20250428+cu128
[conda] numpy                     2.1.2                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.8.3.14                pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.8.57                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.8.61                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.8.57                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.8.0.87                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.3.41                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.9.55                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.2.55                pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.7.53                pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.3                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.26.2                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.8.61                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.8.55                  pypi_0    pypi
[conda] pytorch-triton            3.3.0+git96316ce5          pypi_0    pypi
[conda] torch                     2.8.0.dev20250428+cu128          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250428+cu128          pypi_0    pypi
[conda] torchvision               0.22.0.dev20250428+cu128          pypi_0    pypi

CPU Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class ReproMultihead(nn.Module):

    def __init__(self):
        super().__init__()
        self.embed_dim = 256
        self.num_heads = self.embed_dim // 64
        self.in_proj_weight = nn.Parameter(torch.empty(3 * self.embed_dim,
            self.embed_dim))
        self.in_proj_bias = nn.Parameter(torch.empty(3 * self.embed_dim))
        self.out_proj_weight = nn.Parameter(torch.empty(self.embed_dim,
            self.embed_dim))
        self.out_proj_bias = nn.Parameter(torch.empty(self.embed_dim))
        nn.init.constant_(self.in_proj_weight, 0.1)
        nn.init.constant_(self.in_proj_bias, 0.1)
        nn.init.constant_(self.out_proj_weight, 0.1)
        nn.init.constant_(self.out_proj_bias, 0.1)

    def forward(self, x):
        x_t = x.transpose(0, 1)
        attn_output, _ = F.multi_head_attention_forward(
            query=x_t, key=x_t, value=x_t,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            out_proj_weight=self.out_proj_weight,
            out_proj_bias=self.out_proj_bias,
            training=True,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            use_separate_proj_weight=False,
            q_proj_weight=None, k_proj_weight=None, v_proj_weight=None,
            static_k=None, static_v=None,
            average_attn_weights=True,
            is_causal=False)
        return attn_output.transpose(0, 1)

def test_bug():
    torch.set_default_device('cpu')
    torch.manual_seed(0)

    model = ReproMultihead()

    compiled_model = ReproMultihead()
    compiled_model = torch.compile(compiled_model)

    x = torch.randn((1, 512, 256), requires_grad=True)
    x_compiled = x.clone().detach().requires_grad_(True)

    out_eager = model(x)
    out_compiled = compiled_model(x_compiled)

    out_eager.sum().backward()
    out_compiled.sum().backward()

    weight_diff = torch.max(torch.abs(model.in_proj_weight.grad -
        compiled_model.in_proj_weight.grad)).item()
    print('weight_diff =', weight_diff)
    bias_diff = torch.max(torch.abs(model.in_proj_bias.grad -
        compiled_model.in_proj_bias.grad)).item()
    print('bias_diff =', bias_diff)

if __name__ == '__main__':
    test_bug()

defaultd661 (Author) commented

Replying to "Can you try again on main, and reopen if this issue still exists?": it seems that I don't have permission to reopen the issue.

eellison reopened this Apr 29, 2025
eellison (Contributor) commented

cc @drisspg, can you repro?

eellison changed the title from "Incorrect Gradients when Using torch.compile with F.multi_head_attention_forward" to "Softmax Decomp Causes Incorrect Gradients when Using torch.compile with F.multi_head_attention_forward" Apr 29, 2025
eellison added the module: decompositions, module: aotdispatch, and module: sdpa labels Apr 29, 2025
pytorch-bot added the module: pt2-dispatcher label Apr 29, 2025
eellison added the module: correctness (silent) and high priority labels and removed the module: pt2-dispatcher label Apr 29, 2025
eellison (Contributor) commented

Okay, I can repro this by forcing the math backend, and the error is due to the softmax decomp.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ReproMultihead(nn.Module):

    def __init__(self):
        super().__init__()
        self.embed_dim = 256
        self.num_heads = self.embed_dim // 64
        self.in_proj_weight = nn.Parameter(torch.empty(3 * self.embed_dim,
            self.embed_dim))
        self.in_proj_bias = nn.Parameter(torch.empty(3 * self.embed_dim))
        self.out_proj_weight = nn.Parameter(torch.empty(self.embed_dim,
            self.embed_dim))
        self.out_proj_bias = nn.Parameter(torch.empty(self.embed_dim))
        nn.init.constant_(self.in_proj_weight, 0.1)
        nn.init.constant_(self.in_proj_bias, 0.1)
        nn.init.constant_(self.out_proj_weight, 0.1)
        nn.init.constant_(self.out_proj_bias, 0.1)

    def forward(self, x):
        x_t = x.transpose(0, 1)
        attn_output, _ = F.multi_head_attention_forward(
            query=x_t, key=x_t, value=x_t,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            out_proj_weight=self.out_proj_weight,
            out_proj_bias=self.out_proj_bias,
            training=True,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            use_separate_proj_weight=False,
            q_proj_weight=None, k_proj_weight=None, v_proj_weight=None,
            static_k=None, static_v=None,
            average_attn_weights=True,
            is_causal=False)
        return attn_output.transpose(0, 1)

def test_bug():

    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    # torch.backends.cuda.enable_math_sdp(False)


    def test_fn():
        torch.set_default_device('cuda')
        torch.manual_seed(0)

        model = ReproMultihead().cuda()
        torch._dynamo.reset()
        compiled_model = ReproMultihead().cuda()
        compiled_model = torch.compile(compiled_model)

        x = torch.randn((1, 512, 256), device='cuda', requires_grad=True)
        x_compiled = x.clone().detach().requires_grad_(True)

        out_eager = model(x)
        out_compiled = compiled_model(x_compiled)

        out_eager.sum().backward()
        out_compiled.sum().backward()

        try:
            torch.testing.assert_close(model.in_proj_weight.grad, compiled_model.in_proj_weight.grad)
            return True
        except AssertionError:
            return False
        
        # weight_diff = torch.max(torch.abs(model.in_proj_weight.grad -
        #     compiled_model.in_proj_weight.grad)).item()
        # print('weight_diff =', weight_diff)
        # bias_diff = torch.max(torch.abs(model.in_proj_bias.grad -
        #     compiled_model.in_proj_bias.grad)).item()
        # print('bias_diff =', bias_diff)

    from torch._inductor.compiler_bisector import CompilerBisector

    CompilerBisector.do_bisect(test_fn)



if __name__ == '__main__':
    test_bug()

No bisection status found.
Starting bisection process with system: eager
Moving to the next system: aot_eager
Moving to the next system: aot_eager_decomp_partition
The issue is in the aot_eager_decomp_partition system. Moving to the first subsystem: ConfigChange(name='aot_eager_decomp_partition_cse', config_name='aot_eager_decomp_partition', config_field='cse', config_value=False)
Disabling aot_eager_decomp_partition_cse did not fix the issue.
Moving to the next subsystem: aot_eager_decomp_partition - decomposition
Disabling decomposition fixed the issue.
Starting bisect by getting upper bound.
Upper bound of 76 found for aot_eager_decomp_partition.
Bisecting aot_eager_decomp_partition - decomposition (Range: [0, 76], Midpoint: 38)
Bisecting aot_eager_decomp_partition - decomposition (Range: [0, 38], Midpoint: 19)
Bisecting aot_eager_decomp_partition - decomposition (Range: [0, 19], Midpoint: 9)
Bisecting aot_eager_decomp_partition - decomposition (Range: [10, 19], Midpoint: 14)
Bisecting aot_eager_decomp_partition - decomposition (Range: [10, 14], Midpoint: 12)
Bisecting aot_eager_decomp_partition - decomposition (Range: [10, 12], Midpoint: 11)
Binary search completed for aot_eager_decomp_partition - decomposition. The bisect number is 12. Debug info: <OpOverload(op='aten._softmax', overload='default')>
Bisection status deleted.
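The bisector log above is a binary search over a decomposition counter. A minimal sketch of that search, assuming a monotonic pass/fail predicate and an upper bound already known to fail; this is not CompilerBisector's actual code, and `bisect_first_bad` plus the toy predicate are hypothetical. The predicate is chosen so the probed midpoints match the log.

```python
from typing import Callable, List, Optional


def bisect_first_bad(upper: int, is_bad: Callable[[int], bool],
                     probes: Optional[List[int]] = None) -> int:
    """Return the smallest n in [0, upper] for which is_bad(n) is True.

    Assumes is_bad is monotonic (False below some n, True from there on)
    and that is_bad(upper) is True, as established by the upper-bound step.
    """
    lo, hi = 0, upper
    while lo < hi:
        mid = (lo + hi) // 2
        if probes is not None:
            probes.append(mid)  # record each midpoint, like the log lines
        if is_bad(mid):
            hi = mid            # failure reproduces: culprit is at or below mid
        else:
            lo = mid + 1        # still passes: culprit is above mid
    return lo


# Toy stand-in for "rerun the test with decompositions enabled up to n".
probes: List[int] = []
print(bisect_first_bad(76, lambda n: n >= 12, probes), probes)
# 12 [38, 19, 9, 14, 12, 11]
```

Each probe reruns the whole test, so the search needs only about log2(76) ≈ 7 runs instead of 76.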

@drisspg, any thoughts?

eellison added the triaged label and removed the triage review label Apr 29, 2025
zou3519 added the ubn ("unbreak now") label May 6, 2025
zou3519 self-assigned this May 6, 2025
zou3519 added the module: pt2-dispatcher label May 13, 2025
zou3519 assigned bdhirsh and unassigned zou3519 May 13, 2025

bdhirsh (Contributor) commented May 14, 2025

I have a pretty small repro of the softmax decomp causing noticeable differences compared to the eager kernel for aten._softmax_backward_data:

import torch
from torch._inductor import inductor_prims

# decomp from torch/_inductor/decompositions.py
def _softmax_backward_data_decomp(grad_output: torch.Tensor, output: torch.Tensor, dim: int, input_dtype: torch.dtype) -> torch.Tensor:
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    # non-FMA version is commented out
    # grad_input = new_grad_output - output * sum_new_grad
    grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)

    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    out = grad_input.contiguous()
    return out

grad_output = torch.randn(1, 4, 512, 512, device='cuda') * 1000
output = torch.rand(1, 4, 512, 512, device='cuda')

out_ref = torch.ops.aten._softmax_backward_data.default(grad_output, output, dim=-1, input_dtype=torch.float32)
out_test = _softmax_backward_data_decomp(grad_output, output, dim=-1, input_dtype=torch.float32)

print(torch.max(torch.abs(out_ref - out_test)))

A few things to call out:

(1) The difference in my tiny repro is tensor(0.0078, device='cuda:0'), which is not huge but seems large enough to be problematic. In the full repro above, this small difference causes a larger difference in the final weight/bias grad values:

fw diff = 0.0008087158203125
weight_diff = 0.02783203125
bias_diff = 0.00390625

(2) All of the inputs/outputs are float32, so there don't seem to be any funky intermediate dtype-casting differences going on.

(3) The decomp uses FMA. I tried both the FMA and non-FMA flavors and got similar divergences.

(4) How large the difference is depends on the input data distribution. I used a data distribution similar to the MHA example in my repro above (the output tensor is between 0 and 1, grad_output is larger).

I spent some time staring at the CUDA kernel for _softmax_backward_data (link), but so far I haven't figured out what causes it to diverge from the decomp.

cc @ngimel, in case you know
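For reference, the decomposition quoted above implements the standard softmax backward. Writing y = softmax(x) along `dim` and g for grad_output, the softmax Jacobian gives the following, where the first term is `new_grad_output` and the sum is `sum_new_grad`:

```latex
\frac{\partial y_j}{\partial x_i} = y_j\,(\delta_{ij} - y_i)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x_i}
  = \sum_j g_j\, y_j\,(\delta_{ij} - y_i)
  = \underbrace{g_i\, y_i}_{\texttt{new\_grad\_output}}
  - y_i \underbrace{\sum_j g_j\, y_j}_{\texttt{sum\_new\_grad}}
```

When the two terms are large and nearly equal, the subtraction cancels most of their leading digits, so small rounding differences in the reduction, or in whether the final multiply-subtract is fused, can become visible in the result.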

ngimel (Collaborator) commented May 14, 2025

It might be FMA: the CUDA kernel may or may not produce FMA instructions. Also, torch.sum() and the sum computed inside the kernel might differ slightly. As you said, the error for the tiny repro is not huge, so it's those small things. Softmax, as we know, is particularly sensitive to FMA vs. no FMA.
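The FMA sensitivity can be illustrated without a GPU. The sketch below emulates float32 rounding in plain Python with `struct` (an assumption of this example, not how CUDA or Inductor compute) and picks inputs where `a * b + c` rounded once, as an FMA does, differs from rounding `a * b` first. The decomp's `-output * sum_new_grad + new_grad_output` step has exactly this multiply-add shape.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# All three inputs are exactly representable in float32.
a = f32(1 + 2**-12)
b = f32(1 + 2**-12)
c = f32(-(1 + 2**-11))

# Unfused: the product 1 + 2**-11 + 2**-24 is rounded to float32 first
# (a tie, which rounds to even: 1 + 2**-11), then the sum is rounded again.
unfused = f32(f32(a * b) + c)

# Fused (FMA-like): a * b is exact in binary64 here (24-bit x 24-bit significands),
# and so is a * b + c = 2**-24, so only the final rounding to float32 happens.
fused = f32(a * b + c)

print(unfused, fused == 2**-24)  # 0.0 True
```

The absolute difference is tiny, but relative to the true result (2**-24) the unfused answer is 100% wrong, which is the kind of effect cancellation-heavy formulas like softmax backward can expose.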

bdhirsh (Contributor) commented May 14, 2025

It didn't appear to be FMA, because I tried with FMA on and off and saw similar differences.

Are we OK chalking these differences up to reduction numerics being "different" between eager and compile, and living with the (small) difference?

bdhirsh (Contributor) commented May 14, 2025

One thing I will say is that there are two independent issues here:

(1) The softmax decomp causes slightly different numerics compared to eager (this fully repros with backend="aot_eager_decomp_partition"), which gives modest differences vs. eager in the final weight/bias grads:

weight_diff = 0.02783203125
bias_diff = 0.00390625

(2) There is a second issue I found that only repros with Inductor. I could repro it with the efficient_attention backend (not math, so no softmax decomp), and it causes larger differences:

weight_diff = 0.130126953125
bias_diff = 0.12890625

I can file a separate issue for that one.
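The split into two issues comes from running the same program under progressively more of the compiler stack: `aot_eager` (AOTAutograd, no Inductor), then `aot_eager_decomp_partition` (which additionally applies decompositions, the subsystem the bisector log singled out), then `inductor`. A minimal CPU sketch of that ladder on a toy softmax function; `f` and `grads` are hypothetical. On a toy this small the diffs should sit at float32 noise level, so this shows the workflow, not the bug.

```python
import torch
import torch._dynamo


def f(x, w):
    # Softmax followed by a non-linear reduction, so the softmax backward is exercised
    # (a plain .sum() of softmax outputs would have an all-zero gradient).
    return torch.softmax(x @ w, dim=-1).pow(2).sum()


def grads(fn, x, w):
    # Fresh leaf copies so every run starts from identical inputs.
    x = x.detach().clone().requires_grad_(True)
    w = w.detach().clone().requires_grad_(True)
    fn(x, w).backward()
    return x.grad, w.grad


torch.manual_seed(0)
x, w = torch.randn(8, 32), torch.randn(32, 32)
ref = grads(f, x, w)  # plain eager baseline

max_diffs = {}
for backend in ["aot_eager", "aot_eager_decomp_partition"]:
    torch._dynamo.reset()  # drop compiled code from the previous backend
    got = grads(torch.compile(f, backend=backend), x, w)
    max_diffs[backend] = max((r - g).abs().max().item() for r, g in zip(ref, got))
    print(backend, max_diffs[backend])
```

Adding `"inductor"` to the list is the next rung; on CPU it needs a working C++ toolchain. The first backend whose diff jumps points at the responsible layer.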

ngimel (Collaborator) commented May 14, 2025

@bdhirsh, for your micro repro it's definitely just floating-point error. When I modify the output to

vals, indices = torch.max(torch.abs(out_ref.view(-1) - out_test.view(-1)), dim=0)
print(out_ref.view(-1)[indices], out_test.view(-1)[indices])

the output I'm getting is tensor(-36137.9570, device='cuda:0') tensor(-36137.9453, device='cuda:0'), and the relative error is right where you'd expect it to be (1e-6).

bdhirsh (Contributor) commented May 14, 2025

@ngimel, doing the same calculation for the weight grad (instead of the forward out) gives me a relative error in the ballpark of 3e-5: a bit larger, but not in a way that seems significant.

Do you have a standard set of utils that you recommend for deciding whether a compile-vs-eager delta is reasonable? I've been going off of torch.max(torch.abs(test - ref)), although I agree that this only gives the absolute difference, which doesn't tell the whole story.

I did this after reading your comment:

# prints tensor(-3.0756e-05, device='cuda:0')
vals, indices = torch.max(torch.abs(model.in_proj_weight.grad.view(-1) - compiled_model.in_proj_weight.grad.view(-1)), dim=0)
print(1 - (model.in_proj_weight.grad.view(-1)[indices] / compiled_model.in_proj_weight.grad.view(-1)[indices]))
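A minimal sketch combining the two checks used in this thread (the worst element's absolute error, and the relative error at that same element) with `torch.testing.assert_close`, whose elementwise criterion is `|actual - expected| <= atol + rtol * |expected|`. `worst_element_report` is a hypothetical helper; the example values are the pair from ngimel's comment, run on CPU.

```python
import torch


def worst_element_report(ref: torch.Tensor, test: torch.Tensor):
    """Hypothetical helper: locate the largest absolute difference and
    report the relative error at that same element."""
    diff = (ref - test).abs().reshape(-1)
    idx = int(torch.argmax(diff))
    ref_val = ref.reshape(-1)[idx].item()
    test_val = test.reshape(-1)[idx].item()
    abs_err = diff[idx].item()
    rel_err = abs_err / abs(ref_val) if ref_val != 0 else float("inf")
    return abs_err, rel_err, ref_val, test_val


ref = torch.tensor([1.0, -36137.9570])
test = torch.tensor([1.0, -36137.9453])
abs_err, rel_err, ref_val, test_val = worst_element_report(ref, test)
print(f"abs={abs_err:.4g} rel={rel_err:.2g}")

# With the float32 defaults (rtol=1.3e-6, atol=1e-5) the allowed error here is
# about 0.047, so this pair passes despite an absolute error of about 0.0117.
torch.testing.assert_close(test, ref)
```

The absolute error looks alarming only because the values are large; scaled by magnitude, it is a few float32 ulps.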

bdhirsh (Contributor) commented May 14, 2025

If Natalia / @eellison agree that the relative difference across compile vs eager here is within tolerance, I'll go ahead and close this issue.

ngimel (Collaborator) commented May 14, 2025

@bdhirsh, in the torch.compile TorchBench test suite we compute a reference in fp64 and then check that the difference from this gold standard (in the cosine-similarity sense) is approximately the same for eager and compile, but at times this can hide real issues. Having a test with perfect accuracy and no false positives is hard.
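A minimal sketch of the fp64-reference idea, applied to the softmax-backward micro repro on CPU. Assumptions: the same decomposition evaluated in float64 stands in for the gold standard, and the acceptance rule (`multiplier`, `floor`) is hypothetical and uses relative L2 error, with cosine similarity printed alongside. This is not the TorchBench harness or `torch._dynamo.utils.same`.

```python
import torch
import torch.nn.functional as F


def softmax_backward_decomp(grad_output, output, dim=-1):
    # Non-FMA form of the decomposition quoted earlier in the thread.
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    return new_grad_output - output * sum_new_grad


def rel_err_to(ref64, x):
    # Relative L2 distance of a float32 result from the float64 reference.
    return ((x.double() - ref64).norm() / ref64.norm()).item()


def cos_sim_to(ref64, x):
    # Cosine similarity of the flattened result and the float64 reference.
    return F.cosine_similarity(x.double().reshape(1, -1), ref64.reshape(1, -1)).item()


torch.manual_seed(0)
grad_output = torch.randn(1, 4, 64, 64) * 1000
output = torch.rand(1, 4, 64, 64)

eager = torch.ops.aten._softmax_backward_data.default(grad_output, output, -1, torch.float32)
decomp = softmax_backward_decomp(grad_output, output)
ref64 = softmax_backward_decomp(grad_output.double(), output.double())

# Hypothetical rule: the candidate may drift from the fp64 reference by at most
# `multiplier` times eager's drift, plus a small floor for near-exact cases.
multiplier, floor = 3.0, 1e-6
err_eager, err_decomp = rel_err_to(ref64, eager), rel_err_to(ref64, decomp)
ok = err_decomp <= multiplier * err_eager + floor
print(f"eager={err_eager:.2e} decomp={err_decomp:.2e} ok={ok}")
print(f"cos eager={cos_sim_to(ref64, eager):.8f} cos decomp={cos_sim_to(ref64, decomp):.8f}")
```

Comparing both float32 results against the same float64 baseline answers "is compile worse than eager?" rather than "do they differ?", which is the question the direct eager-vs-compile diff cannot settle.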

bdhirsh (Contributor) commented May 14, 2025

Tentatively closing this issue.

bdhirsh closed this as completed May 14, 2025
eellison (Contributor) commented

Yeah, we should have checked against an fp64 reference here. We do have this note in the torch.compile issue template:

When comparing eager and torch.compile, use a higher precision result as a baseline. torch._dynamo.utils.same with fp64_ref will handle this comparison.

7 participants