[Intel GPU] Enable SDPA on XPU #147614
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147614
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Unrelated Failures as of commit 6e0b7d2 with merge base af720cd.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Attention! One of the PyTorch C-stable API files was changed
You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased 5ce9f89 to 2b35316.
@pytorchbot rebase -b main
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here
Successfully rebased 2b35316 to 199187d.
For Intel GPU, this PR can significantly improve performance for key workloads like Stable Diffusion, benefiting Torch users, especially on client devices. In terms of timeline, we are strongly aiming to catch up with the upcoming PyTorch release. We would greatly appreciate it if @albanD, @jansel, and @desertfire could prioritize the review of this PR.
.lintrunner.toml (Outdated)
@@ -1251,6 +1251,7 @@ exclude_patterns = [
    'test/test_testing.py',
    'test/test_torch.py',
    'test/test_transformers.py',
    'test/xpu/test_transformers.py',
Why is this excluded?
Just merged test/xpu/test_transformers.py into test/test_transformers.py
@@ -244,6 +244,7 @@ def convert_return(typ: BaseType, val: str) -> str:
    "_scaled_dot_product_flash_attention",
    "_scaled_dot_product_efficient_attention",
    "_scaled_dot_product_cudnn_attention",
    "_scaled_dot_product_fused_attention_overrideable",
Why add this here?
Per my understanding, AOTI needs to check whether an operation may return an empty tensor and then generate c_shim code accordingly. Currently, I think `_scaled_dot_product_fused_attention_overrideable` should behave similarly to the other `_scaled_dot_product_*` operations, i.e. it may return an empty tensor.
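To illustrate the point, here is a rough sketch of the idea only, not the actual torchgen code; `MAY_RETURN_UNDEFINED_TENSORS` and `emit_return_conversion` are hypothetical names. Ops on this list may return undefined tensors (for example the debug mask when `return_debug_mask=False`), so the generated c_shim has to guard each returned tensor before wrapping it in a handle.

```python
# Hypothetical sketch only: why an op appears on this list.  Ops listed here
# may return undefined tensors, so the generated c_shim must null-check each
# returned tensor before creating a handle for it.
MAY_RETURN_UNDEFINED_TENSORS = {
    "_scaled_dot_product_flash_attention",
    "_scaled_dot_product_efficient_attention",
    "_scaled_dot_product_cudnn_attention",
    "_scaled_dot_product_fused_attention_overrideable",
}


def emit_return_conversion(op_name: str, ret_expr: str, idx: int) -> str:
    """Emit a C statement converting the idx-th returned tensor to a handle."""
    if op_name in MAY_RETURN_UNDEFINED_TENSORS:
        # Guard against undefined tensors before creating a tensor handle.
        return (
            f"ret{idx} = {ret_expr}.defined() "
            f"? new_tensor_handle(std::move({ret_expr})) : nullptr;"
        )
    return f"ret{idx} = new_tensor_handle(std::move({ret_expr}));"
```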
@@ -14866,6 +14867,7 @@
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  dispatch:
    CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
    XPU: _scaled_dot_product_fused_attention_overrideable_xpu
@drisspg is this the right place to override it?
I don't think so, at least not from within core. This is meant as a generic op that different backends can register to through PrivateUse1.
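For context, a minimal sketch of what such an out-of-tree registration could look like, assuming a recent PyTorch build where this op schema exists; `my_backend_sdpa_forward` is a hypothetical placeholder kernel, while `torch.library.Library` and its `impl` method are the standard registration entry points:

```python
# Sketch: an out-of-tree backend registering its own kernel for the
# overrideable SDPA op under the PrivateUse1 dispatch key.
import torch

lib = torch.library.Library("aten", "IMPL")


def my_backend_sdpa_forward(
    query, key, value, attn_bias=None, dropout_p=0.0,
    is_causal=False, return_debug_mask=False, *, scale=None,
):
    # A real backend would call its fused attention kernel here and return
    # the full 9-element tuple described in native_functions.yaml:
    # (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
    #  philox_seed, philox_offset, debug_attn_mask).
    raise NotImplementedError("placeholder for the backend's fused kernel")


# Route calls on PrivateUse1 tensors to the backend's implementation.
lib.impl(
    "_scaled_dot_product_fused_attention_overrideable",
    my_backend_sdpa_forward,
    "PrivateUse1",
)
```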
Confirmed things with Alban: it is okay to add the XPU dispatch to any of the available SDPA APIs. Whichever one most closely aligns with what you need for forward/backward makes the most sense.
In that context, this change seems fine.
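As a usage note (a sketch, not part of this diff): with the XPU dispatch entry in place, the ordinary Python entry point should route to the new kernel for XPU tensors, roughly as follows, assuming an XPU-enabled build and an available device:

```python
# Sketch: calling SDPA on XPU tensors once the XPU dispatch entry exists.
import torch
import torch.nn.functional as F

device = "xpu"
q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)

# Dispatches to the XPU kernel registered in native_functions.yaml.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```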
test/xpu/test_transformers.py (Outdated)
instantiate_device_type_tests(
    TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True
)
Can't we re-use the existing test?
I did try to reuse the existing `test/test_transformers.py`, and I was not sure it would be good, as there are too many things to be generalized.
- For example, `test_onednn_attention_fail_d256` is derived from `test/test_transformers.py:test_cudnn_attention_fail_d128`, and obviously we need to use a different backend.
- Another example is that `test_scaled_dot_product_attention_fused_kernels_packed` tests nested tensors, which are not supported on XPU.
- `test/test_transformers.py::TestTransformers::test_scaled_dot_product_attention` tests `float` and `double` in a loop, while XPU does not currently support double.
- In addition, all cases with grad are not applicable on XPU, as training will be supported later.

Actually, there are only 2 cases in `test/test_transformers.py::TestSDPA`, but 8 cases in `TestSDPACpuOnly` and 28 cases in `TestSDPACudaOnly`. I think it would be a mess after merging the XPU tests into them.
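For a sense of what an XPU-only case looks like, here is a simplified sketch in the style of the new test file; the class name matches the diff above, but the test body and tolerances are illustrative, not the actual tests added in this PR:

```python
# Simplified sketch of an XPU-only SDPA test; the real tests in
# test/xpu/test_transformers.py are more thorough.
import torch
import torch.nn.functional as F
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class TestSDPAXpuOnly(TestCase):
    def test_sdpa_matches_math_reference(self, device):
        q, k, v = (torch.randn(2, 4, 64, 32, device=device) for _ in range(3))
        actual = F.scaled_dot_product_attention(q, k, v)
        # Reference result computed on CPU.
        expected = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
        self.assertEqual(actual.cpu(), expected, atol=1e-3, rtol=1e-3)


instantiate_device_type_tests(
    TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```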
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here
Successfully rebased 06da205 to 6e0b7d2.
@albanD, may I know if we have addressed your comments?
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 5 checks: pull / linux-focal-cpu-py3.10-gcc9-bazel-test / filter, xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 1, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 2, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 3, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 4, 4, linux.idc.xpu)
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
@albanD, I will merge this PR first, as we are preparing an internal full validation for the upcoming branch cut. We will continue to refine the changes if we have not addressed the comments well.
Motivation
This PR is part of the plan of OneDNN Upstreaming, as #114848 (comment) stated. The support of SDPA is via the overrideable variant on the XPU backend. Besides the added `Attention.cpp` file, `Graph.h` is added to hold utils for OneDNN graph, including those for kernel/compiled-graph caching. In addition, a selection of test cases in `test/test_transformers.py` are copied into the new `test/xpu/test_transformers.py` and modified accordingly to provide additional tests beyond `./third_party/torch-xpu-ops/test/xpu/test_ops_xpu.py`.

Depends on OneDNN version v3.7 upgrade in #147498
Depends on BUILD_GRAPH switch in #147608
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
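As an aside, a rough Python-level sketch of the kernel/compiled-graph caching idea that `Graph.h` is described as providing; this is illustrative only and does not reflect the actual C++ implementation or the OneDNN Graph API:

```python
# Illustrative-only sketch of compiled-graph caching keyed on the attributes
# that determine the kernel; the real cache in Graph.h is C++ and operates on
# OneDNN Graph partitions, not Python objects.
from functools import lru_cache


@lru_cache(maxsize=None)
def get_compiled_sdpa_graph(dtype, num_heads, head_dim, is_causal, has_attn_bias):
    # In the real code this step would build and compile a graph partition
    # for the given configuration; recompilation is skipped when the same
    # key is seen again.
    return ("compiled-sdpa-graph", dtype, num_heads, head_dim, is_causal,
            has_attn_bias)


# Repeated calls with the same configuration reuse the cached entry.
g1 = get_compiled_sdpa_graph("float16", 8, 64, True, False)
g2 = get_compiled_sdpa_graph("float16", 8, 64, True, False)
assert g1 is g2
```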