[SDPA] Add testing to ensure stride order exactly matches · pytorch/pytorch@07600f7 · GitHub

Commit 07600f7

[SDPA] Add testing to ensure stride order exactly matches

ghstack-source-id: 721d10e
Pull Request resolved: #152894

1 parent ac792a0 commit 07600f7

5 files changed: +188 −21 lines changed

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 38 additions & 10 deletions
@@ -118,6 +118,37 @@ at::cuda::philox::unpack_cudnn<<<1, 1, 0, stream>>>(arg, seed_ptr, offset_ptr);
 namespace native {
 
 namespace {
+// Create output tensor with strides matching query layout
+at::Tensor create_output_with_matching_layout(
+    const at::Tensor& query,
+    at::IntArrayRef output_shape,
+    at::TensorOptions options
+) {
+  // Get the "fill order" - an argsort on the strides of the query tensor
+  const int dims = query.dim();
+  std::vector<int64_t> fill_order(dims);
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+
+  const auto query_strides = query.strides();
+  std::stable_sort(
+      fill_order.begin(),
+      fill_order.end(),
+      [&query_strides](int64_t idx1, int64_t idx2) {
+        return query_strides[idx1] < query_strides[idx2];
+      });
+
+  // Construct new strides that preserve the same layout ordering
+  std::vector<int64_t> new_strides(dims);
+  int64_t current_stride = 1;
+  for (const int64_t dim_idx : fill_order) {
+    new_strides[dim_idx] = current_stride;
+    current_stride *= output_shape[dim_idx];
+  }
+
+  // Create tensor with the constructed strides
+  return at::empty(output_shape, options)
+      .as_strided(output_shape, new_strides, 0);
+}
 
 
 static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;

@@ -1433,11 +1464,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   }
   kernel_launched = true;
 
-  res = at::empty(
-      {B, M, num_heads, Kv},
-      query.options().dtype(
-          CutlassToAtenDtype<typename Kernel::output_t>::atScalarType()));
-
+  auto opts = query.options().dtype(CutlassToAtenDtype<typename Kernel::output_t>::atScalarType());
+  res = create_output_with_matching_layout(query, {B, M, num_heads, Kv}, opts);
   // NOTE: Should be aligned (by padding) in case M is
   // not a good number for loading during backward
   constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE;

@@ -1455,11 +1483,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
       : nullptr;
   at::Tensor output_accum;
   if (Kernel::kNeedsOutputAccumulatorBuffer) {
-    output_accum = at::empty(
-        {B, M, num_heads, Kv},
-        query.options().dtype(
-            CutlassToAtenDtype<
-                typename Kernel::output_accum_t>::atScalarType()));
+    auto opts = query.options().dtype(CutlassToAtenDtype<typename Kernel::output_t>::atScalarType());
+    output_accum = create_output_with_matching_layout(query, {B, M, num_heads, Kv}, opts);
     p.output_accum_ptr =
         (typename Kernel::output_accum_t*)output_accum.data_ptr();
   } else {

@@ -1494,12 +1519,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
   ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
   ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0));
+
   ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1));
   ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1));
   ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1));
+
   ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2));
   ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
   ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
+
   ASSIGN_CHECK_OVERFLOW(p.o_strideM, res.stride(1));
 
   if (bias.has_value()) {
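
The helper's logic is easier to see outside of C++. Below is a minimal Python sketch of the same fill-order computation (not part of the commit; output_strides_matching is a name invented for illustration): argsort the query strides ascending, then hand out strides innermost-first so the output shares the query's memory-layout order.

import torch

def output_strides_matching(query, output_shape):
    # Argsort of the query strides, smallest (innermost) stride first
    fill_order = sorted(range(query.dim()), key=lambda d: query.stride(d))
    strides = [0] * len(output_shape)
    current = 1
    for dim in fill_order:
        strides[dim] = current
        current *= output_shape[dim]
    return strides

# A (B, H, M, K) tensor permuted to (B, M, H, K):
q = torch.empty(2, 4, 128, 16).permute(0, 2, 1, 3)  # strides (8192, 16, 2048, 1)
print(output_strides_matching(q, list(q.shape)))    # [8192, 16, 2048, 1]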

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h

Lines changed: 6 additions & 5 deletions
@@ -237,10 +237,12 @@ struct AttentionKernel {
     query_ptr += batch_id * q_strideB;
     key_ptr += batch_id * k_strideB;
     value_ptr += batch_id * v_strideB;
-    output_ptr += int64_t(batch_id * num_queries) * o_strideM;
+    output_ptr += batch_id * q_strideB;
+
+    // Reuse q_strides since we want to guarantee exact match w/ input
     if (output_accum_ptr != nullptr) {
       output_accum_ptr +=
-          int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
+          int64_t(batch_id * q_strideB);
     }
     q_start = 0;
     k_start = 0;

@@ -252,15 +254,14 @@
 
     value_ptr += k_start * v_strideM + head_id * v_strideH;
     output_ptr +=
-        int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
+        int64_t(q_start + query_start) * o_strideM + head_id * q_strideH;
 
     if (kSupportsBias && attn_bias_ptr != nullptr) {
       attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
     }
     if (output_accum_ptr != nullptr) {
       output_accum_ptr +=
-          int64_t(q_start + query_start) * (head_dim_value * num_heads) +
-          head_id * head_dim_value;
+          int64_t(q_start + query_start) * q_strideM + head_id * q_strideH;
     } else {
       // Accumulate directly in the destination buffer (eg for f32)
       output_accum_ptr = (accum_t*)output_ptr;
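
To make the pointer arithmetic above concrete, here is a small Python rendering of the offsets the kernel now computes (editorial; the function name is invented): the batch, query-row, and head offsets into the output are all taken from the query's strides, which is what guarantees the exact layout match.

def output_element_offset(batch_id, query_idx, head_id,
                          q_strideB, o_strideM, q_strideH):
    # Mirrors: output_ptr += batch_id * q_strideB, then
    #          output_ptr += (q_start + query_start) * o_strideM + head_id * q_strideH
    return batch_id * q_strideB + query_idx * o_strideM + head_id * q_strideH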

test/inductor/test_flex_attention.py

Lines changed: 7 additions & 4 deletions
@@ -2937,12 +2937,15 @@ def test_flex_attention_backward_stride_ordering(
         )
         out = func(query, key, value)
         grad_output = torch.randn_like(out)
-        out.backward(grad_output)
+
+        grad_query, grad_key, grad_value = torch.autograd.grad(
+            out, [query, key, value], grad_output
+        )
 
         for leaf, grad, name in [
-            (query, query.grad, "query"),
-            (key, key.grad, "key"),
-            (value, value.grad, "value"),
+            (query, grad_query, "query"),
+            (key, grad_key, "key"),
+            (value, grad_value, "value"),
         ]:
             input_stride_order = get_stride_order(grad.stride())
             orig_stride_order = get_stride_order(leaf.stride())
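
A likely motivation for this switch (an editorial note, not stated in the commit): torch.autograd.grad returns the gradient tensors the engine just produced, without accumulating into .grad, so the test inspects exactly the strides autograd emitted. A minimal illustration:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = (x * 2).sum()
(gx,) = torch.autograd.grad(y, [x])  # gradient returned directly as a tensor
assert x.grad is None                # .grad is not populated, unlike y.backward()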

test/test_transformers.py

Lines changed: 86 additions & 0 deletions
@@ -21,6 +21,7 @@
 from typing import Optional
 import torch.utils.cpp_extension
 from torch.testing._internal.common_nn import NNTestCase
+from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM,
     skipIfRocm,

@@ -2469,6 +2470,7 @@ def test_cudnn_attention_different_dk_dv(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
+
     @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_gqa(self, device):

@@ -4285,6 +4287,89 @@ def test_is_causal_and_mask_fails(self, device):
         with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
             scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
 
+
+class TestSDPACompile(InductorTestCase):
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
+    @parametrize("backend", PLATFORM_SPECIFIC_SDPA, name_fn=lambda x: x.name)
+    @parametrize("compile_mode", ["eager", "inductor"])
+    @parametrize(
+        "permute_order",
+        [perm + (3,) for perm in itertools.permutations([0, 1, 2])],
+    )
+    @parametrize("shape", [(2, 4, 128, 16), (4, 2, 64, 32)])
+    def test_sdpa_stride_ordering_and_backward(self, device, backend, compile_mode, permute_order, shape):
+        from torch._inductor.ir import get_stride_order
+        make_tensor = partial(
+            torch.randn,
+            shape,
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+
+        # Create and permute tensors
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+        query = query.permute(permute_order)
+        key = key.permute(permute_order)
+        value = value.permute(permute_order)
+
+        # Create leaves
+        query.requires_grad_()
+        key.requires_grad_()
+        value.requires_grad_()
+
+        def run_sdpa(q, k, v):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+        if compile_mode == "inductor":
+            run_sdpa = torch.compile(run_sdpa, backend="inductor", fullgraph=True)
+        else:
+            original_run_sdpa = run_sdpa
+
+            def run_sdpa(q, k, v):
+                with torch._subclasses.CrossRefFakeMode():
+                    return original_run_sdpa(q, k, v)
+
+        with sdpa_kernel(backends=[backend]):
+            out = run_sdpa(query, key, value)
+
+        # Check out and query
+        out_stride_order = get_stride_order(out.stride())
+        query_stride_order = get_stride_order(query.stride())
+
+        self.assertEqual(
+            out_stride_order,
+            query_stride_order,
+            f"Compile mode: {compile_mode}, Backend: {backend}, "
+            f"Forward: out {out_stride_order}, query {query_stride_order}",
+        )
+
+        grad_output = torch.randn_like(out)
+        cm = torch._subclasses.CrossRefFakeMode() if compile_mode == "eager" else contextlib.nullcontext()
+
+        with cm:
+            grad_query, grad_key, grad_value = torch.autograd.grad(out, [query, key, value], grad_output)
+
+        # Check that gradient stride orders match input stride orders
+        for leaf, grad, name in [
+            (query, grad_query, "query"),
+            (key, grad_key, "key"),
+            (value, grad_value, "value"),
+        ]:
+            grad_stride_order = get_stride_order(grad.stride())
+            input_stride_order = get_stride_order(leaf.stride())
+            self.assertEqual(
+                grad_stride_order,
+                input_stride_order,
+                f"Compile mode: {compile_mode}, Backend: {backend}, "
+                f"Backward for {name}: grad {grad_stride_order}, input {input_stride_order}",
+            )
+
+
 if NOTEST_CPU:
     device_types = ("cuda", )
 else:

@@ -4297,6 +4382,7 @@ def test_is_causal_and_mask_fails(self, device):
 instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
 instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
 instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestSDPACompile, globals(), only_for=("cuda"))
 
 if __name__ == '__main__':
     run_tests()
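
A side note on the parametrization (editorial): permute_order keeps the last (head-dim) axis fixed and permutes the first three, so each backend/compile-mode combination is exercised over six distinct memory layouts:

import itertools
orders = [perm + (3,) for perm in itertools.permutations([0, 1, 2])]
print(orders)
# [(0, 1, 2, 3), (0, 2, 1, 3), (1, 0, 2, 3),
#  (1, 2, 0, 3), (2, 0, 1, 3), (2, 1, 0, 3)]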

torch/_meta_registrations.py

Lines changed: 51 additions & 2 deletions
@@ -101,6 +101,51 @@ def check_inplace_broadcast(self_shape, *args_shape):
     )
 
 
+def _construct_strides(
+    sizes: Sequence[int],
+    fill_order: Sequence[int],
+) -> Sequence[int]:
+    """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
+    # Initialize strides
+    assert len(sizes) == len(fill_order), (
+        "Length of sizes must match the length of the fill order"
+    )
+    strides = [0] * len(sizes)
+
+    # Start with stride 1 for the innermost dimension
+    current_stride = 1
+
+    # Iterate through the fill order populating strides
+    for dim in fill_order:
+        strides[dim] = current_stride
+        current_stride *= sizes[dim]
+
+    return strides
+
+
+def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor:
+    """
+    Create a new tensor with the same data and shape as the input,
+    but with strides permuted based on the input tensor's stride order.
+
+    Args:
+        out (torch.Tensor): The output tensor of attention.
+        query_strides (List[int]): The stride order of the input query tensor
+
+    Returns:
+        torch.Tensor: A new tensor with same shape and data as the input,
+        but with strides permuted based on the query tensor's stride order.
+    """
+    from torch._inductor.ir import get_fill_order
+
+    fill_order = get_fill_order(query_strides)
+    assert out.storage_offset() == 0, "Only support storage_offset == 0"
+    out_strides = _construct_strides(out.shape, fill_order)
+    new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
+    new_out.copy_(out)
+    return new_out
+
+
 @register_meta([aten.linspace, aten.logspace])
 @out_wrapper()
 def meta_linspace_logspace(

@@ -5878,7 +5923,9 @@ def meta__scaled_dot_product_efficient_attention(
     num_heads = query.size(-2)
     Kv = value.size(-1)
 
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
+    out_shape = (B, M, num_heads, Kv)
+    res = query.new_empty(out_shape)
+    res = _permute_strides(res, query.stride())
 
     if torch.version.hip and torch.cuda.is_available():
         """Please see: https://github.com/pytorch/pytorch/issues/146848

@@ -6131,7 +6178,9 @@ def meta__efficient_attention_forward(
     num_heads = query.size(-2)
     Kv = value.size(-1)
 
-    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
+    out_shape = (B, M, num_heads, Kv)
+    res = query.new_empty(out_shape)
+    res = _permute_strides(res, query.stride())
 
     logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
     actual_max_seqlen_q = M
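
To see the two meta helpers in action, a short worked example (editorial, assuming _construct_strides as defined in the diff above, and assuming get_fill_order returns an ascending argsort of the strides, innermost dimension first):

import torch

# A (B, num_heads, M, K) query transposed to (B, M, num_heads, K):
query = torch.empty(2, 4, 128, 16, dtype=torch.float16).transpose(1, 2)
print(query.stride())  # (8192, 16, 2048, 1)

# Under the assumption above, get_fill_order(query.stride()) gives (3, 1, 2, 0).
fill_order = (3, 1, 2, 0)
print(_construct_strides((2, 128, 4, 16), fill_order))  # [8192, 16, 2048, 1]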

0 commit comments