Closed
Labels
- module: sdpa (all things related to torch.nn.functional.scaled_dot_product_attention)
- module: xpu (Intel XPU related issues)
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Root cause of:
With:
- PyTorch: 3eb8fa0
On:
- Intel Data Center GPU Max 1550 (PVC)
When running one of the basic ComfyUI workloads (image generation), the generated image tensor contains only NaN values when using the PyTorch XPU backend. The same workload works fine on the CPU backend and produces the expected image. The NaN values first appear in the call to torch.nn.functional.scaled_dot_product_attention. See comfyanonymous/ComfyUI#8228 for details.
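One way to narrow down where non-finite values first show up in a model is to attach forward hooks to every submodule and record the first one whose output contains NaN or Inf. The sketch below is only an illustration of that idea, not the procedure used in this report; `find_first_nonfinite_module` is a hypothetical helper name. Functional calls such as scaled_dot_product_attention are not modules, so this approach only identifies the enclosing module (for example an attention block), which then needs to be inspected by hand.

```python
import torch
from torch import nn


def find_first_nonfinite_module(model: nn.Module, *inputs):
    """Run `model` once and return the name of the first submodule whose
    output contains NaN or Inf, or None if every output is finite.

    Hypothetical debugging helper, not part of PyTorch or ComfyUI.
    """
    first = []    # holds at most one module name
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            if first:
                return  # an earlier module was already flagged
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if isinstance(t, torch.Tensor) and not torch.isfinite(t).all():
                    first.append(name or "<root>")
                    return
        return hook

    # Forward hooks fire when a module finishes, so inner modules report
    # before the modules that contain them.
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return first[0] if first else None
```

Calling the helper with the model and a representative input returns a dotted module name that can be used as a starting point for dumping that module's inputs.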
Here are the saved input tensors and a simple script to reproduce the torch.nn.functional.scaled_dot_product_attention issue on the PyTorch side. Playing with the script, you can see that PyTorch XPU produces an output tensor that differs significantly from the CPU and CUDA results (CUDA tried on an A10), with some values being NaN:
- Input tensors can be found at https://github.com/dvrogozh/ComfyUI/tree/for8228/debug:
- Test script: https://github.com/dvrogozh/ComfyUI/blob/for8228/debug/test.py
import torch

# Switch between the failing XPU backend and the CPU backend.
device = "xpu:0"
# device = "cpu"

# Load the saved attention inputs directly onto the chosen device.
q = torch.load("q.pt", map_location=device)
k = torch.load("k.pt", map_location=device)
v = torch.load("v.pt", map_location=device)
print(q)
print(k)
print(v)

# Report whether any input contains NaN or Inf values.
print(f"q.isnan: {torch.isnan(q).any()}")
print(f"q.isinf: {torch.isinf(q).any()}")
print(f"k.isnan: {torch.isnan(k).any()}")
print(f"k.isinf: {torch.isinf(k).any()}")
print(f"v.isnan: {torch.isnan(v).any()}")
print(f"v.isinf: {torch.isinf(v).any()}")

# Run SDPA and report whether the output contains NaN or Inf values.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(f"out.isnan: {torch.isnan(out).any()}")
print(f"out.isinf: {torch.isinf(out).any()}")
print(out)
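To put a number on "significantly different", the device output can be compared against an unfused float32 reference computed on the CPU. The following is a minimal sketch under stated assumptions: `reference_sdpa` and `compare_with_reference` are hypothetical helper names, and the reference uses the default SDPA scale of 1/sqrt(E), where E is the last dimension of q, with no mask and no dropout, matching the call in the script above.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(q, k, v):
    """Unfused attention on the CPU in float32: softmax(q @ k^T / sqrt(E)) @ v."""
    q32, k32, v32 = (
        t.detach().to(device="cpu", dtype=torch.float32) for t in (q, k, v)
    )
    scale = 1.0 / math.sqrt(q32.size(-1))
    weights = torch.softmax(q32 @ k32.transpose(-2, -1) * scale, dim=-1)
    return weights @ v32


def compare_with_reference(q, k, v, device):
    """Run SDPA on `device` and summarize how it differs from the reference."""
    out = F.scaled_dot_product_attention(
        q.to(device), k.to(device), v.to(device),
        attn_mask=None, dropout_p=0.0, is_causal=False,
    )
    out32 = out.detach().to(device="cpu", dtype=torch.float32)
    ref = reference_sdpa(q, k, v)
    finite = torch.isfinite(out32)
    # Only compare positions where the device output is finite.
    if finite.any():
        max_abs_diff = (out32[finite] - ref[finite]).abs().max().item()
    else:
        max_abs_diff = float("nan")
    return {
        "nan_count": int(torch.isnan(out32).sum()),
        "inf_count": int(torch.isinf(out32).sum()),
        "max_abs_diff": max_abs_diff,
    }
```

With the tensors from the script loaded, `compare_with_reference(q, k, v, "xpu:0")` and `compare_with_reference(q, k, v, "cpu")` can be printed side by side. What counts as an acceptable `max_abs_diff` depends on the dtype of the saved tensors, which is not stated here.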