[SDPA] Add testing to ensure stride order exactly matches by drisspg · Pull Request #152894 · pytorch/pytorch · GitHub
[SDPA] Add testing to ensure stride order exactly matches #152894

Open: drisspg wants to merge 11 commits into gh/drisspg/149/base
17 changes: 7 additions & 10 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1423,11 +1423,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
}
kernel_launched = true;

res = at::empty(
{B, M, num_heads, Kv},
query.options().dtype(
CutlassToAtenDtype<typename Kernel::output_t>::atScalarType()));

auto opts = query.options().dtype(CutlassToAtenDtype<typename Kernel::output_t>::atScalarType());
res = sdp::create_output_with_matching_layout(query, {B, M, num_heads, Kv}, opts);
// NOTE: Should be aligned (by padding) in case M is
// not a good number for loading during backward
constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE;
@@ -1445,11 +1442,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
: nullptr;
at::Tensor output_accum;
if (Kernel::kNeedsOutputAccumulatorBuffer) {
output_accum = at::empty(
{B, M, num_heads, Kv},
query.options().dtype(
CutlassToAtenDtype<
typename Kernel::output_accum_t>::atScalarType()));
auto opts = query.options().dtype(CutlassToAtenDtype<typename Kernel::output_accum_t>::atScalarType());
output_accum = sdp::create_output_with_matching_layout(query, {B, M, num_heads, Kv}, opts);
p.output_accum_ptr =
(typename Kernel::output_accum_t*)output_accum.data_ptr();
} else {
@@ -1484,12 +1478,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0));

ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1));
ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1));
ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1));

ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2));
ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));

ASSIGN_CHECK_OVERFLOW(p.o_strideM, res.stride(1));

if (bias.has_value()) {
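For context on what the forward-pass changes above are meant to achieve, here is a minimal Python sketch of the user-visible property, assuming a CUDA device, that the memory-efficient backend is actually selected, and that this PR's changes are applied; `stride_order` is an ad-hoc helper, not a torch API.

```python
# Sketch only: assumes CUDA, the EFFICIENT_ATTENTION backend, and this PR applied.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def stride_order(t):
    # dimensions ordered from largest to smallest stride
    return sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True)

B, H, M, K = 2, 4, 128, 64
# (B, H, M, K) views over buffers that are physically laid out as (B, M, H, K)
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.float16).transpose(1, 2)
    for _ in range(3)
)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

# The output should follow the query's stride ordering instead of being
# forced into a contiguous (B, H, M, K) layout.
assert stride_order(out) == stride_order(q)
```
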
15 changes: 8 additions & 7 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -442,9 +442,9 @@ _efficient_attention_backward(
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
} else {
grad_q = at::empty(query.sizes(), query.options());
grad_k = at::empty(key.sizes(), key.options());
grad_v = at::empty(value.sizes(), value.options());
grad_q = at::empty_like(query);
Collaborator review comment: Pretty sure flash attention used to have the same bug; I guess it was copied and pasted from here and never fixed here.

grad_k = at::empty_like(key);
grad_v = at::empty_like(value);
}

if (bias_requires_grad) {
@@ -730,10 +730,11 @@ _efficient_attention_backward(
ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2));
ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2));
ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2));
p.gQKV_strideM_multiplier = shared_storage_dqdkdv ? 3 : 1;
TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1));
TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1));
TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1));

p.gQKV_strideM_multiplier = 1;
ASSIGN_CHECK_OVERFLOW(p.gQ_strideM, grad_q.stride(1));
ASSIGN_CHECK_OVERFLOW(p.gK_strideM, grad_k.stride(1));
ASSIGN_CHECK_OVERFLOW(p.gV_strideM, grad_v.stride(1));

ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
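The switch from at::empty(sizes, options) to at::empty_like is the key layout change here: empty_like (with the default preserve_format) keeps the input's strides for dense, non-overlapping tensors, while a plain empty always allocates a contiguous buffer. Below is a small Python sketch of the same semantics; the printed stride values assume exactly these shapes.

```python
import torch

# A (B, H, M, K) view over a buffer laid out as (B, M, H, K)
q = torch.randn(2, 128, 4, 64).transpose(1, 2)

contiguous_grad = torch.empty(q.shape)  # old behavior: always contiguous
matching_grad = torch.empty_like(q)     # new behavior: preserves q's strides

print(q.stride())                # (32768, 64, 256, 1)
print(contiguous_grad.stride())  # (32768, 8192, 64, 1)
print(matching_grad.stride())    # (32768, 64, 256, 1)
```
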
@@ -680,19 +680,6 @@ struct AttentionBackwardKernel {
unsigned long long dropout_batch_head_rng_offset = 0;
float dropout_prob = 0.0f;

CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
}
CUTLASS_HOST_DEVICE int32_t gQ_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim;
}
CUTLASS_HOST_DEVICE int32_t gK_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim;
}
CUTLASS_HOST_DEVICE int32_t gV_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim_value;
}

// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int64_t o_strideH = -1;
Expand Down Expand Up @@ -723,6 +710,11 @@ struct AttentionBackwardKernel {
int64_t gV_strideH = 0;
int64_t gB_strideH = 0;

int32_t o_strideM = 0;
int32_t gQ_strideM = 0;
int32_t gK_strideM = 0;
int32_t gV_strideM = 0;

CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const {
#ifdef __CUDA_ARCH__
return kEnableSplitKeys ? gridDim.x : 1;
@@ -789,13 +781,13 @@ struct AttentionBackwardKernel {
value_ptr += k_start * v_strideM;
assert(bias_ptr == nullptr);
assert(grad_bias_ptr == nullptr);
output_ptr += q_start * o_strideM();
output_ptr += q_start * o_strideM;
grad_output_ptr += q_start * gO_strideM;
delta_ptr += q_start;

grad_query_ptr += q_start * gQ_strideM();
grad_key_ptr += k_start * gK_strideM();
grad_value_ptr += k_start * gV_strideM();
grad_query_ptr += q_start * gQ_strideM;
grad_key_ptr += k_start * gK_strideM;
grad_value_ptr += k_start * gV_strideM;
}

query_ptr += batch_id * q_strideB + head_id * q_strideH;
@@ -1422,8 +1414,8 @@ struct AttentionBackwardKernel {
if (!skipBoundsChecks && key >= p.num_keys) {
continue;
}
auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM;
auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM;

for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
gv_ptr[k] = scalar_t(0);
@@ -1788,8 +1780,8 @@ struct AttentionBackwardKernel {
num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
auto createEpilogueIter = [&]() {
return typename MatmulGradV::OutputTileIterator(
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
p.grad_value_ptr + key_start * p.gV_strideM() + col,
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM},
p.grad_value_ptr + key_start * p.gV_strideM + col,
{num_keys_in_block, p.head_dim_value - col},
thread_id);
};
@@ -2126,8 +2118,8 @@ struct AttentionBackwardKernel {
// NOTE: We're not releasing the lock because no one is expected
// to come after us (we're the last one to write)
typename MatmulGradQ::OutputTileIterator output_it(
typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
p.grad_query_ptr + query_start * p.gQ_strideM() + col,
typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM},
p.grad_query_ptr + query_start * p.gQ_strideM + col,
{problem_size.m(), problem_size.n()},
thread_id);
// if `direct_store` is True, we store to gmem (`*gmem = accum`)
@@ -2168,8 +2160,8 @@ struct AttentionBackwardKernel {
num_queries_in_block);
auto createEpilogueIter = [&]() {
return typename MatmulGradK::OutputTileIterator(
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
p.grad_key_ptr + key_start * p.gK_strideM() + col,
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM},
p.grad_key_ptr + key_start * p.gK_strideM + col,
{num_keys_in_block,
false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col},
thread_id);
@@ -2424,8 +2416,8 @@ struct AttentionBackwardKernel {
: cutlass::fast_min(
(int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
typename MatmulGradV::OutputTileIterator outputV_it(
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
p.grad_value_ptr + key_start * p.gV_strideM(),
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM},
p.grad_value_ptr + key_start * p.gV_strideM,
{num_keys_in_block, p.head_dim_value},
thread_id);
accumulateInGmem<MatmulGradV>(
@@ -2437,8 +2429,8 @@
lane_id);

typename MatmulGradK::OutputTileIterator outputK_it(
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
p.grad_key_ptr + key_start * p.gK_strideM(),
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM},
p.grad_key_ptr + key_start * p.gK_strideM,
{num_keys_in_block,
false ? MatmulGradK::ThreadblockShape::kN : p.head_dim},
thread_id);
@@ -2522,7 +2514,7 @@ struct AttentionBackwardKernel {
laneFirstCol);
const AccessType* __restrict__ output_ptr =
reinterpret_cast<const AccessType*>(
p.output_ptr + (query_start + laneRow) * p.o_strideM() +
p.output_ptr + (query_start + laneRow) * p.o_strideM +
laneFirstCol);

static constexpr int64_t kMaxIters =
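The deleted gQ_strideM()/gK_strideM()/gV_strideM() accessors computed the row stride as num_heads * head_dim, which is only valid for a densely packed (B, M, H, K) buffer; the new fields carry whatever stride the gradient tensor actually has. A hedged Python sketch of why the hard-coded formula breaks for a permuted layout:

```python
import torch

B, M, H, K = 2, 128, 4, 64

packed = torch.randn(B, M, H, K)                        # densely packed (B, M, H, K)
permuted = torch.randn(B, H, M, K).permute(0, 2, 1, 3)  # same shape, different layout

# Stride along M (the row stride the kernel steps by):
print(packed.stride(1), H * K)    # 256 256  -> the old formula is correct here
print(permuted.stride(1), H * K)  # 64 256   -> the old formula would over-step rows
```
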
@@ -237,10 +237,12 @@ struct AttentionKernel {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
output_ptr += batch_id * q_strideB;

// Reuse q_strides since we want to guarantee exact match w/ input
if (output_accum_ptr != nullptr) {
output_accum_ptr +=
int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
int64_t(batch_id * q_strideB);
}
q_start = 0;
k_start = 0;
@@ -252,15 +254,14 @@

value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr +=
int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
int64_t(q_start + query_start) * o_strideM + head_id * q_strideH;

if (kSupportsBias && attn_bias_ptr != nullptr) {
attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
}
if (output_accum_ptr != nullptr) {
output_accum_ptr +=
int64_t(q_start + query_start) * (head_dim_value * num_heads) +
head_id * head_dim_value;
int64_t(q_start + query_start) * q_strideM + head_id * q_strideH;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
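Since the forward output is now allocated with the query's layout, the kernel can advance output_ptr using the query's batch and head strides instead of recomputing offsets from head_dim_value * num_heads. A small Python sketch of that stride-based addressing; the dimension names are illustrative, and the flat-offset arithmetic mirrors what the CUDA code does with raw pointers.

```python
import torch

q = torch.randn(2, 128, 4, 64).transpose(1, 2)  # (B, H, M, K) view, permuted strides

b, h, m, k = 1, 3, 17, 5
offset = b * q.stride(0) + h * q.stride(1) + m * q.stride(2) + k * q.stride(3)

# Indexing the underlying storage with the stride-derived offset lands on the same
# logical element; an output sharing q's strides can be addressed with the same offsets.
flat = q.as_strided((q.numel(),), (1,))
assert flat[offset] == q[b, h, m, k]
```
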
39 changes: 39 additions & 0 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -17,6 +17,13 @@
#include <c10/util/Array.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#endif

#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#endif
@@ -917,4 +924,36 @@ bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug) {
return true;
}

// Create output tensor with strides matching query layout
at::Tensor create_output_with_matching_layout(
const at::Tensor& query,
at::IntArrayRef output_shape,
at::TensorOptions options
) {
// Get the "fill order" - an argsort on the strides of the query tensor
const int dims = query.dim();
std::vector<int64_t> fill_order(dims);
std::iota(fill_order.begin(), fill_order.end(), 0);

const auto query_strides = query.strides();
std::stable_sort(
fill_order.begin(),
fill_order.end(),
[&query_strides](int64_t idx1, int64_t idx2) {
return query_strides[idx1] < query_strides[idx2];
});

// Construct new strides that preserve the same layout ordering
std::vector<int64_t> new_strides(dims);
int64_t current_stride = 1;
for (const int64_t dim_idx : fill_order) {
new_strides[dim_idx] = current_stride;
current_stride *= output_shape[dim_idx];
}

// Create tensor with the constructed strides
return at::empty(output_shape, options)
.as_strided(output_shape, new_strides, 0);
}

} // namespace sdp
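
For readers skimming the C++ above, the following is a Python sketch of the same fill-order construction (output_with_matching_layout is an ad-hoc name, not a torch API): argsort the query's strides in ascending order, then hand out strides to the output shape in that order, so the output ends up with the query's stride ordering.

```python
import torch

def output_with_matching_layout(query, output_shape):
    # fill order: dimensions from smallest to largest stride (stable, like the C++ sort)
    fill_order = sorted(range(query.dim()), key=lambda d: query.stride(d))
    new_strides = [0] * len(output_shape)
    current = 1
    for d in fill_order:
        new_strides[d] = current
        current *= output_shape[d]
    # allocate, then reinterpret with the constructed strides (offset 0)
    return torch.empty(output_shape).as_strided(output_shape, new_strides)

q = torch.randn(2, 128, 4, 64).transpose(1, 2)        # (B, H, M, K) view
out = output_with_matching_layout(q, (2, 4, 128, 64))
print(q.stride())    # (32768, 64, 256, 1)
print(out.stride())  # (32768, 64, 256, 1) -- same ordering (and here, same values)
```
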
6 changes: 6 additions & 0 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -14,4 +14,10 @@ C10_EXPORT bool can_use_flash_attention(sdp_params const& params, bool debug);
C10_EXPORT bool can_use_mem_efficient_attention(sdp_params const& params, bool debug);
C10_EXPORT bool can_use_cudnn_attention(sdp_params const& params, bool debug);

// Create output tensor with strides matching query layout
at::Tensor create_output_with_matching_layout(
const at::Tensor& query,
at::IntArrayRef output_shape,
at::TensorOptions options);

} // namespace sdp
11 changes: 7 additions & 4 deletions test/inductor/test_flex_attention.py
@@ -2937,12 +2937,15 @@ def test_flex_attention_backward_stride_ordering(
)
out = func(query, key, value)
grad_output = torch.randn_like(out)
out.backward(grad_output)

grad_query, grad_key, grad_value = torch.autograd.grad(
out, [query, key, value], grad_output
)

for leaf, grad, name in [
(query, query.grad, "query"),
(key, key.grad, "key"),
(value, value.grad, "value"),
(query, grad_query, "query"),
(key, grad_key, "key"),
(value, grad_value, "value"),
]:
input_stride_order = get_stride_order(grad.stride())
orig_stride_order = get_stride_order(leaf.stride())
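As a companion to the test change above, here is a self-contained sketch of the property it asserts, phrased in terms of SDPA rather than flex_attention; stride_order below is a stand-in for the test's get_stride_order helper, and a CUDA device, the memory-efficient backend, and this PR's changes are assumed.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def stride_order(strides):
    # rank of each dimension by its stride; the smallest stride gets rank 0
    return [sorted(strides).index(s) for s in strides]

def permuted_input():
    # (B, H, M, K) leaf view over a (B, M, H, K) buffer
    return torch.randn(2, 128, 4, 64, device="cuda",
                       dtype=torch.float16).transpose(1, 2).requires_grad_(True)

q, k, v = permuted_input(), permuted_input(), permuted_input()

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
    grad_q, grad_k, grad_v = torch.autograd.grad(out.sum(), [q, k, v])

# Each gradient should keep the stride ordering of its corresponding input.
for inp, grad in [(q, grad_q), (k, grad_k), (v, grad_v)]:
    assert stride_order(grad.stride()) == stride_order(inp.stride())
```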