[Intel GPU] Enable SDPA on XPU by DDEle · Pull Request #147614 · pytorch/pytorch · GitHub

[Intel GPU] Enable SDPA on XPU #147614


Closed

DDEle wants to merge 5 commits into main from onednn_graph_sdpa-integration

Conversation

DDEle
Contributor
@DDEle DDEle commented Feb 21, 2025

Motivation

This PR is part of the OneDNN upstreaming plan, as #114848 (comment) stated. SDPA support is implemented via the overridable variant on the XPU backend. Besides the added Attention.cpp file, Graph.h is added to hold utilities for the oneDNN graph API, including those for kernel/compiled-graph caching. In addition, a selection of test cases from test/test_transformers.py are copied into the new test/xpu/test_transformers.py and modified accordingly to provide additional coverage beyond ./third_party/torch-xpu-ops/test/xpu/test_ops_xpu.py.

Depends on OneDNN version v3.7 upgrade in #147498
Depends on BUILD_GRAPH switch in #147608
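
For reviewers who want a quick repro, here is a minimal usage sketch (not part of the diff; it assumes an XPU-enabled build and falls back to CPU otherwise):

```python
import torch
import torch.nn.functional as F

# With this PR, F.scaled_dot_product_attention on XPU tensors routes to the
# fused oneDNN-graph kernel instead of the unfused math fallback.
device = "xpu" if torch.xpu.is_available() else "cpu"
q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```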

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Feb 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147614

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Unrelated Failures

As of commit 6e0b7d2 with merge base af720cd:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@DDEle
Contributor Author
DDEle commented Feb 21, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Feb 21, 2025
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

Contributor

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.


Caused by:

@EikanWang EikanWang marked this pull request as draft February 21, 2025 13:25
@EikanWang
Collaborator

@pytorchbot rebase

@EikanWang EikanWang added this to the 2.7.0 milestone Feb 24, 2025
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased onednn_graph_sdpa-integration onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout onednn_graph_sdpa-integration && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the onednn_graph_sdpa-integration branch from 5ce9f89 to 2b35316 February 24, 2025 16:57
@EikanWang
Collaborator

@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased onednn_graph_sdpa-integration onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout onednn_graph_sdpa-integration && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the onednn_graph_sdpa-integration branch from 2b35316 to 199187d February 24, 2025 16:59
@EikanWang EikanWang added the ciflow/xpu, ciflow/trunk, and keep-going labels Feb 24, 2025
@pytorch-bot pytorch-bot bot added the module: cpu label Feb 25, 2025
@EikanWang EikanWang marked this pull request as ready for review February 25, 2025 04:24
@EikanWang
Collaborator

For Intel GPU, this PR can significantly improve performance for key workloads like Stable Diffusion, benefiting Torch users, especially on client devices. In terms of timeline, we are aiming to make the upcoming PyTorch release. We would greatly appreciate it if @albanD, @jansel, and @desertfire could prioritize the review of this PR.

.lintrunner.toml Outdated
@@ -1251,6 +1251,7 @@ exclude_patterns = [
'test/test_testing.py',
'test/test_torch.py',
'test/test_transformers.py',
'test/xpu/test_transformers.py',
Collaborator

Why is this excluded?

Contributor Author
@DDEle DDEle Feb 26, 2025

Just merged test/xpu/test_transformers.py into test/test_transformers.py.

@@ -244,6 +244,7 @@ def convert_return(typ: BaseType, val: str) -> str:
"_scaled_dot_product_flash_attention",
"_scaled_dot_product_efficient_attention",
"_scaled_dot_product_cudnn_attention",
"_scaled_dot_product_fused_attention_overrideable",
Collaborator

Why add this here?

Collaborator

Per my understanding, AOTI needs to check whether an operation may return an empty tensor and then generate c_shim code accordingly. I think _scaled_dot_product_fused_attention_overrideable should behave like the other _scaled_dot_product operations here, i.e., it may return an empty tensor.
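
To make that concrete, here is a simplified and partly hypothetical sketch of the allowlist idea (the real logic lives in torchgen's AOTI shim generation and differs in detail):

```python
# Hypothetical simplification: ops whose returned tensors may be undefined
# (empty), so the generated c_shim must treat those returns as optional.
MAY_RETURN_UNDEFINED = {
    "_scaled_dot_product_flash_attention",
    "_scaled_dot_product_efficient_attention",
    "_scaled_dot_product_cudnn_attention",
    "_scaled_dot_product_fused_attention_overrideable",  # added by this PR
}

def returns_need_optional_handling(op_name: str) -> bool:
    """Decide whether shim codegen should guard returned tensors."""
    return op_name in MAY_RETURN_UNDEFINED
```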

@@ -14866,6 +14867,7 @@
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
XPU: _scaled_dot_product_fused_attention_overrideable_xpu
Collaborator

@drisspg is this the right place to override it?

Contributor

I don't think so - at least not from within core. This is meant as a generic op that different backends can register against through privateuse1.

Collaborator

@albanD, @drisspg, the original idea was to avoid adding a new op. Would you be okay with adding an operation like _scaled_dot_product_mkldnn_attention, similar to _scaled_dot_product_cudnn_attention?

Contributor

Confirmed things with Alban. It is okay to add the XPU dispatch to any of the available SDPA APIs; whichever one most closely aligns with what you need for forward/backward makes the most sense.

In that context, this change seems fine.
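
For reference, the out-of-tree path drisspg mentions looks roughly like this (a hedged sketch; the kernel body is a placeholder, and in this PR the XPU kernel is instead registered in-tree via native_functions.yaml):

```python
import torch

# A third-party backend can claim the overrideable op on the PrivateUse1
# dispatch key and supply its own fused SDPA kernel.
@torch.library.impl("aten::_scaled_dot_product_fused_attention_overrideable", "PrivateUse1")
def custom_backend_sdpa(query, key, value, attn_bias=None, dropout_p=0.0,
                        is_causal=False, return_debug_mask=False, *, scale=None):
    raise NotImplementedError("backend-specific fused kernel goes here")
```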



instantiate_device_type_tests(
    TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True
)
Collaborator

Can't we re-use the existing test?

Contributor Author
@DDEle DDEle Feb 26, 2025

I did try to reuse the existing test/test_transformers.py, but I was not sure it was a good fit, as too many things would need to be generalized:

  1. For example, test_onednn_attention_fail_d256 is derived from test/test_transformers.py:test_cudnn_attention_fail_d128, and it obviously needs a different backend.
  2. test_scaled_dot_product_attention_fused_kernels_packed tests nested tensors, which are not supported on XPU.
  3. test/test_transformers.py::TestTransformers::test_scaled_dot_product_attention tests float and double in a loop, while XPU does not currently support double (see the sketch after this list).
  4. All cases involving grad are not applicable on XPU, as training will be supported later.

Also, there are only 2 cases in test/test_transformers.py::TestSDPA but 8 in TestSDPACpuOnly and 28 in TestSDPACudaOnly; I think merging the XPU tests into them would be a mess.
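
For illustration, a hedged sketch of how such an XPU-only case diverges (the class name matches the PR; the test body is hypothetical):

```python
import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import TestCase

class TestSDPAXpuOnly(TestCase):
    def test_scaled_dot_product_attention(self, device="xpu"):
        # torch.float64 deliberately omitted: double is not supported on XPU
        for dtype in (torch.float32, torch.float16):
            q = torch.randn(2, 4, 32, 16, device=device, dtype=dtype)
            out = F.scaled_dot_product_attention(q, q, q)
            self.assertEqual(out.shape, q.shape)
```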

@EikanWang EikanWang requested a review from drisspg February 26, 2025 16:06
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased onednn_graph_sdpa-integration onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout onednn_graph_sdpa-integration && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the onednn_graph_sdpa-integration branch from 06da205 to 6e0b7d2 February 28, 2025 08:30
@pytorch-bot pytorch-bot bot removed the ciflow/trunk and ciflow/xpu labels Feb 28, 2025
@EikanWang EikanWang added the ciflow/xpu and ciflow/trunk labels Feb 28, 2025
@EikanWang EikanWang moved this from In Progress to Approved in PyTorch Intel Feb 28, 2025
@EikanWang
Collaborator
EikanWang commented Feb 28, 2025

The test_graph_partition related failures are irrelevant to this PR; they were caused by a previously landed PR that introduced CUDA-specific code into the test cases. There is already a PR to fix these failures - #148121.

@EikanWang
Collaborator

Regarding the inductor/test_kernel_benchmark.py::TestKernelBenchmark::test_matmul_triton_kernel_benchmark failure: it was caused by #147620 enabling force_shape_pad for the Triton kernel benchmark, while Intel GPU supports this scenario. I submitted another PR to fix the failure - #148237.
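
For context, the flag in question is a real inductor config knob; the surrounding code is illustrative only:

```python
import torch
import torch._inductor.config as inductor_config

# Force inductor to pad matmul operand shapes, the behavior that #147620
# turned on for the kernel benchmark path.
inductor_config.force_shape_pad = True

@torch.compile
def mm(a, b):
    return a @ b
```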

@EikanWang EikanWang requested a review from albanD March 1, 2025 15:09
@EikanWang
Collaborator

@albanD, may I know whether we have addressed your comments?

@EikanWang
Collaborator

@pytorchbot merge -i


@EikanWang
Collaborator

@albanD, I will merge this PR first, as we are preparing an internal full validation for the upcoming branch cut. We will continue refining the changes if any comments have not been addressed well.

@github-project-automation github-project-automation bot moved this from Approved to Done in PyTorch Intel Mar 4, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Motivation
===

This PR is part of the OneDNN upstreaming plan, as pytorch#114848 [(comment)](pytorch#114848 (comment)) stated. SDPA support is implemented via the overridable variant on the XPU backend. Besides the added `Attention.cpp` file, `Graph.h` is added to hold utilities for the oneDNN graph API, including those for kernel/compiled-graph caching. In addition, a selection of test cases from `test/test_transformers.py` are copied into the new `test/xpu/test_transformers.py` and modified accordingly to provide additional coverage beyond `./third_party/torch-xpu-ops/test/xpu/test_ops_xpu.py`.

Depends on OneDNN version v3.7 upgrade in pytorch#147498
Depends on BUILD_GRAPH switch in pytorch#147608

Pull Request resolved: pytorch#147614
Approved by: https://github.com/jansel, https://github.com/EikanWang
pytorchmergebot pushed a commit to min-jean-cho/pytorch that referenced this pull request Mar 5, 2025
Labels
ciflow/trunk (Trigger trunk jobs on your pull request)
ciflow/xpu (Run XPU CI tasks)
keep-going (Don't stop on first failure, keep running tests until the end)
Merged
module: cpu (CPU specific problem (e.g., perf, algorithm))
module: inductor
open source
topic: not user facing (topic category)
Projects
Status: Done

7 participants