[cuDNN][SDPA] Match `query`'s memory layout ordering for `output` in cuDNN SDPA by eqy · Pull Request #138354 · pytorch/pytorch


Closed · wants to merge 18 commits
109 changes: 88 additions & 21 deletions aten/src/ATen/native/cudnn/MHA.cpp
@@ -292,6 +292,88 @@ auto fixSizeOneDimStrideSDPA(
}
return strides;
}

void alloc_with_matching_layout(
const Tensor& q,
Tensor& output,
const std::vector<int64_t>& shape) {
TORCH_INTERNAL_ASSERT(
shape.size() == q.sizes().size(),
"cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim");

if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) {
output = at::empty_like(q);
return;
}

// get the "fill order," which is just an argsort on the strides
std::vector<int> fill_order(shape.size());
std::iota(fill_order.begin(), fill_order.end(), 0);
const auto q_strides = q.strides();
std::stable_sort(
fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) {
return q_strides[idx1] < q_strides[idx2];
});
std::vector<int64_t> ordered_strides(shape.size());
int64_t current_stride = 1;
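// assign contiguous strides to the requested shape, walking dims in
// fill order (innermost physical dim of q first)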
for (const int dim_idx : fill_order) {
ordered_strides[dim_idx] = current_stride;
current_stride *= shape[dim_idx];
}
output = at::empty(at::IntArrayRef(shape), q.options())
.as_strided(
at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0);
}

void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) {
const int dims = output.sizes().size();
std::vector<int64_t> outer_to_inner(dims);
std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0);
const auto o_strides = output.strides();
std::stable_sort(
outer_to_inner.begin(),
outer_to_inner.end(),
[&o_strides](int idx1, int idx2) {
return o_strides[idx1] > o_strides[idx2];
});
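// build the inverse permutation so the logical dim order can be restored
// after materializing grad_output in output's memory order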
std::vector<int64_t> inverse(dims);
for (int d = 0; d < dims; d++) {
inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) -
outer_to_inner.begin();
}
grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner))
.contiguous()
.permute(at::IntArrayRef(inverse));
}

bool same_strides(const Tensor& t1, const Tensor& t2) {
std::vector<int> t1_strides_no_ones;
std::vector<int> t2_strides_no_ones;
const auto t1strides = t1.strides();
const auto t2strides = t2.strides();
const int dim = t1strides.size();
if (dim != (int)t2strides.size()) {
return false;
}
const auto t1sizes = t1.sizes();
const auto t2sizes = t2.sizes();

// we are going through strides backward here, but if both are backward it's
// comparable
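// strides of size-1 dims are arbitrary and don't affect memory layout, so
// they are skipped on both sides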
for (int i = 0; i < dim; i++) {
if (t1sizes[i] > 1) {
t1_strides_no_ones.push_back(t1strides[i]);
}
if (t2sizes[i] > 1) {
t2_strides_no_ones.push_back(t2strides[i]);
}
}
return std::equal(
t1_strides_no_ones.begin(),
t1_strides_no_ones.end(),
t2_strides_no_ones.begin(),
t2_strides_no_ones.end());
}
} // namespace

auto build_graph_and_tensors(
@@ -553,7 +635,8 @@ void run_cudnn_SDP_fprop(
Tensor& dropoutoffset) {
cudnnHandle_t handle = getCudnnHandle();
if (!o.defined()) {
- o = at::empty({b, h, s_q, d_v}, q.options());
+ // q is passed to us in BHSD dim order
+ alloc_with_matching_layout(q, o, {b, h, s_q, d_v});
}

if (return_softmaxstats && !softmaxstats.defined()) {
@@ -660,30 +743,14 @@ void run_cudnn_SDP_bprop(
}

Tensor dO_ = dO;
- if (!dO.strides()[dO.strides().size() - 1]) {
- TORCH_WARN(
- "cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported."
- " Materializing a contiguous tensor which will increase memory usage...");
- dO_ = dO.contiguous();
- }
- if ( // handle trivial transposed case with a transposed dim of size 1
- // see also: https://github.com/pytorch/pytorch/issues/134001
- !(dO_.is_contiguous() && o.is_contiguous()) &&
- !std::equal(
- o.strides().begin(), o.strides().end(), dO.strides().begin())) {
- TORCH_WARN(
+ if (!same_strides(o, dO)) {
+ TORCH_WARN_ONCE(
"cuDNN SDPA backward got grad_output.strides() != output.strides(), "
"attempting to materialize a grad_output with matching strides...");
- if (o.is_contiguous()) {
- dO_ = dO.contiguous();
- } else {
- dO_ = dO.transpose(1, 2).contiguous().transpose(1, 2);
- }
+ permute_to_matching_layout(o, dO_);
}
TORCH_INTERNAL_ASSERT(
- (dO_.is_contiguous() && o.is_contiguous()) ||
- std::equal(
- dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
+ same_strides(o, dO_),
"cuDNN SDPA expected grad_output.strides() == output.strides(), "
"the previous step probably failed to materialize a grad_output "
"with matching strides...");
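The net effect on the forward pass: the output is now allocated with the same stride ordering as the query, rather than always in packed BHSD layout. A minimal Python sketch of the fill-order idea follows; the helper name mirrors the C++ function, but the sketch is illustrative only and not part of the PR:

import torch

def alloc_with_matching_layout(q: torch.Tensor, shape):
    # argsort q's strides to recover its "fill order" (innermost physical dim
    # first), then hand out contiguous strides for `shape` in that same order
    fill_order = sorted(range(len(shape)), key=lambda d: q.stride(d))
    strides = [0] * len(shape)
    current = 1
    for d in fill_order:
        strides[d] = current
        current *= shape[d]
    return torch.empty_strided(shape, strides, dtype=q.dtype, device=q.device)

# a BHSD-shaped query whose memory is physically laid out as BSHD
q = torch.randn(4, 256, 16, 64).permute(0, 2, 1, 3)  # logical (B=4, H=16, S=256, D=64)
out = alloc_with_matching_layout(q, [4, 16, 256, 64])
assert out.permute(0, 2, 1, 3).is_contiguous()  # same layout ordering as q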
11 changes: 10 additions & 1 deletion aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -56,12 +56,21 @@ namespace {
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
bool check_prefer_cudnn_attention() {
- #if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
+ // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0
+ // see context: https://github.com/pytorch/pytorch/issues/138340
+ // return false;
+ #if defined(CUDNN_VERSION)
+
+ #if CUDNN_VERSION > 90000
auto dprops = at::cuda::getCurrentDeviceProperties();
return dprops->major >= 9;
#else
return false;
#endif
+
+ #else
+ return false;
+ #endif
}

// flash_attention V2 is universally faster than efficient_attention and Math
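Independently of the automatic preference above, the cuDNN backend can still be selected explicitly through the sdpa_kernel context manager; a small usage sketch using the standard public API, shown here only for context:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# restrict SDPA dispatch to the cuDNN backend, bypassing the preference order
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)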
36 changes: 33 additions & 3 deletions test/test_transformers.py
@@ -2529,9 +2529,9 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
def test_cudnn_attention_nonmodulo64seqlen(self, device):
# see also: https://github.com/pytorch/pytorch/issues/137347
mask = torch.randint(0, 2, (2, 1, 157, 6404)).to(device="cuda", dtype=torch.bool)
- q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
- k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
- v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True)
+ q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.float16, requires_grad=True)
+ k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True)
+ v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.float16, requires_grad=True)
q_cpu = q.detach().clone().cpu()
k_cpu = k.detach().clone().cpu()
v_cpu = v.detach().clone().cpu()
@@ -2564,6 +2564,36 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)

@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_preserves_query_layout(self, device):

def test_attention(backend: SDPBackend, permute_order: List[int]):
BHSqD = [4, 16, 256, 64]
BHSkvD = [4, 16, 512, 64]

shape_q = [BHSqD[idx] for idx in permute_order]
shape_kv = [BHSkvD[idx] for idx in permute_order]
reverse = [permute_order.index(idx) for idx in range(4)]
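# permuting by the inverse order yields tensors with the logical BHSD shape
# whose underlying memory is laid out in permute_order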
q = torch.randn(*shape_q, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
k = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
v = torch.randn(*shape_kv, dtype=torch.bfloat16, device='cuda', requires_grad=True).permute(reverse)
self.assertEqual(q.shape, BHSqD)
self.assertEqual(k.shape, BHSkvD)
self.assertEqual(v.shape, BHSkvD)

with sdpa_kernel(backend):
out = F.scaled_dot_product_attention(q, k, v)
self.assertTrue(out.permute(permute_order).is_contiguous())
out.sum().backward()

permute_orders = list()
Contributor comment:

Nit: I think itertools has a permutations func:

from itertools import permutations
permutable = (0, 1, 2)
fixed = [3]
permute_orders = [list(perm) + fixed for perm in permutations(permutable)]

permutable = [0, 1, 2]
permute_orders = itertools.permutations(permutable)

for permute_order in permute_orders:
test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])

@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("mask_dim", [1, 2, 3, 4])
def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):