[SDPA] Add testing to ensure stride order exactly matches by drisspg · Pull Request #152894 · pytorch/pytorch


Open
wants to merge 11 commits into base: gh/drisspg/149/base

Conversation

drisspg (Contributor) commented May 6, 2025

[ghstack-poisoned]
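Per the PR title, the new tests ensure that the SDPA outputs' stride order exactly matches. A minimal sketch of one such check, comparing the output's stride order against the query's (an assumption about what is being compared; shapes, dtypes, and the stride_order helper are illustrative and assume a CUDA build):

import torch
import torch.nn.functional as F

# Inputs laid out like typical SDPA inputs: projected as [B, S, H, D], then
# transposed to [B, H, S, D], so they are dense but not contiguous.
q, k, v = (torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16).transpose(1, 2)
           for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v)

def stride_order(t):
    # Rank dimensions from largest to smallest stride.
    return tuple(sorted(range(t.dim()), key=t.stride, reverse=True))

# The property under test: the output's stride order matches the query's stride
# order. Whether this holds depends on the backend; making it hold (and testing
# it exactly) is what this stack is about.
assert stride_order(out) == stride_order(q)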
pytorch-bot bot commented May 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152894

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the topic: not user facing label May 6, 2025
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
drisspg (Contributor, Author) commented May 6, 2025

Update mem eff striding

[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
@@ -2469,6 +2469,73 @@ def test_cudnn_attention_different_dk_dv(self, device):

        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("backend", PLATFORM_SPECIFIC_SDPA, name_fn=lambda x: x.name)
    @parametrize("compile_mode", ["eager", "inductor"])
drisspg (Contributor, Author) commented on this diff:

So locally all tests pass, but that doesn't seem possible, since I didn't even update the meta for mem eff yet.

I wonder if there are too many recompiles and it then just falls back to eager.
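One way to rule out a silent fallback, sketched here (not part of this PR), is to make Dynamo raise instead of recompiling past the cache limit and falling back to eager. error_on_recompile and reset() are existing torch._dynamo APIs; the run_sdpa below is a stand-in for the test's local helper:

import torch
import torch._dynamo as dynamo

def run_sdpa(q, k, v):
    # stand-in for the test's run_sdpa helper
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Raise on any recompilation instead of silently recompiling until the cache
# limit is hit and execution falls back to eager.
dynamo.config.error_on_recompile = True
dynamo.reset()  # drop previously cached graphs before the run

run_sdpa_compiled = torch.compile(run_sdpa, backend="inductor", fullgraph=True)
# fullgraph=True (already used in the test) also prevents graph breaks from
# silently running parts of the function in eager mode.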

        if compile_mode == "inductor":
            run_sdpa = torch.compile(run_sdpa, backend="inductor", fullgraph=True)

        with sdpa_kernel(backends=[backend]):
A contributor commented:

If compile_mode == "eager", can you enable CrossRefFakeMode?
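A sketch of how that suggestion might look in the test body (assuming torch._subclasses.CrossRefFakeMode, which also runs each dispatched op through its fake/meta implementation and cross-checks the real output's metadata; the tensors and run_sdpa below stand in for the test's own inputs and helper):

import contextlib
import torch
from torch._subclasses import CrossRefFakeMode
from torch.nn.attention import sdpa_kernel, SDPBackend

def run_sdpa(q, k, v):
    # stand-in for the test's run_sdpa helper
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

compile_mode = "eager"                     # parametrized in the real test
backend = SDPBackend.EFFICIENT_ATTENTION   # parametrized in the real test
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Under eager, wrap the call in CrossRefFakeMode so real and fake/meta results
# (including output metadata such as strides) are compared op by op; under
# inductor, compile as before.
if compile_mode == "inductor":
    run_sdpa = torch.compile(run_sdpa, backend="inductor", fullgraph=True)
fake_xref = CrossRefFakeMode() if compile_mode == "eager" else contextlib.nullcontext()

with fake_xref, sdpa_kernel(backends=[backend]):
    out = run_sdpa(q, k, v)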

[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 6, 2025
-    grad_q = at::empty(query.sizes(), query.options());
-    grad_k = at::empty(key.sizes(), key.options());
-    grad_v = at::empty(value.sizes(), value.options());
+    grad_q = at::empty_like(query);
A collaborator commented:

Pretty sure flash attention used to have the same bug; I guess it was copied and pasted from here and never fixed here.
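The distinction behind the fix: at::empty(query.sizes(), query.options()) always allocates a contiguous buffer, while at::empty_like(query) reproduces the (dense, non-overlapping) stride layout of query, so the gradient keeps the input's stride order. A small Python sketch of the same behavior, using torch.empty / torch.empty_like as stand-ins for the ATen calls:

import torch

# A query laid out like a typical SDPA input: projected as [B, S, H, D] and then
# transposed to [B, H, S, D], so it is dense but not contiguous.
query = torch.randn(2, 128, 8, 64).transpose(1, 2)  # shape (2, 8, 128, 64)
print(query.stride())                                # (65536, 64, 512, 1)

# What the old code did: sizes + options -> always a contiguous allocation.
grad_q_old = torch.empty(query.size())
print(grad_q_old.stride())                           # (65536, 8192, 64, 1), contiguous

# What the fix does: empty_like preserves the dense input layout, so grad_q's
# stride order matches query's.
grad_q_new = torch.empty_like(query)
print(grad_q_new.stride())                           # (65536, 64, 512, 1), same as query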

[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 14, 2025
drisspg added a commit that referenced this pull request May 14, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request May 14, 2025