[ROCm] Bump AOTriton to 0.11b (#161754) · pytorch/pytorch@98efc9e · GitHub

Commit 98efc9e

xinyazhang authored and pytorchmergebot committed
[ROCm] Bump AOTriton to 0.11b (#161754)
Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b: * Invoke AITER Assembly kernels on gfx942/gfx950 when inputs meet requirements - AITER ASM kernels deliver over 500TFLOPS training performance. See [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more details. * Now returns natural based `logsumexp` tensor, matching CUDA's behavior - PR #156903 is reverted in this PR as well since it is not needed anymore. * Enables `CausalVariant.LOWER_RIGHT` The build system changes drastically along with new packaging scheme of AOTriton 0.11 * AOTriton 0.11 packs GPU images separately from AOTriton runtime * `aotriton.cmake` now selectively downloads image packs according to `PYTORCH_ROCM_ARCH` * `aotriton.cmake` now only use pre-compiled runtime library that exactly matches the ROCM in the build environment. For PyTorch builds with ROCm versions not listed in the file, the build process will build AOTriton runtime without GPU images from source - This avoids any further ABI breaks like ROCM 6.4 -> 7.0 - recursive git clone is disabled since building AOTriton runtime does not require submodules. Bug fixes: * Fix a kernel bug introduced when implementing SWA Known Problems: * gfx1100 target (Radeon RX 7000 Series) is moved back to experimental status due to accuracy issues. Triton compiler fixes are needed to restore the support status. * Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCM 7.0. This issue is under investigation. Pull Request resolved: #161754 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
1 parent 994f2a5 commit 98efc9e

File tree

12 files changed: +484 -169 lines changed


aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 52 additions & 5 deletions
@@ -1396,12 +1396,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   at::Tensor v_t = value.transpose(1, 2);
   at::Tensor output_t = res.transpose(1, 2);
   bool is_causal;
-  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
-    is_causal = true;
-  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
+  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
     is_causal = false;
   } else {
-    TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    is_causal = true;
+#if AOTRITON_V3_API == 0
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
+      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    }
+#endif
   }

   at::Tensor atomic_counter;
@@ -1426,7 +1429,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
   auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
   hipError_t err; // TODO: Error handling
-  if (seqstart_q.has_value()) {
+  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
+#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    using aotriton::v3::flash::WindowValue;
+    aotriton::v3::flash::attn_fwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.Sm_scale = softmax_scale;
+    params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
+    params.Out = mk_aotensor(output_t, "Out");
+    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = dropout_p;
+    params.philox_seed_ptr = seed;
+    params.philox_offset1 = offset1;
+    params.philox_offset2 = offset2;
+    params.philox_seed_output = seed_output;
+    params.philox_offset_output = offset_output;
+    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
+    params.persistent_atomic_counter = persistent_counter;
+    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
+      params.window_left = WindowValue::TopLeftAligned;
+      params.window_right = WindowValue::TopLeftAligned;
+    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
+      params.window_left = WindowValue::BottomRightAligned;
+      params.window_right = WindowValue::BottomRightAligned;
+    }
+    if (bias.has_value()) {
+      params.B = mk_aotensor(bias.value(), "bias");
+    }
+    if (seqstart_q.has_value()) {
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
+    } else {
+      params.varlen_type = VarlenType::None;
+    }
+    err = aotriton::v3::flash::attn_fwd(params,
+                                        aotriton::v3::flash::attn_fwd_params::kVersion,
+                                        stream);
+#endif // AOTRITON_V3_API
+  } else if (seqstart_q.has_value()) {
     // varlen aka nested tensor
     err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
                                   mk_aotensor(k_t, "k"),

aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 65 additions & 5 deletions
@@ -24,6 +24,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/zeros.h>
 #include <ATen/ops/zeros_like.h>
 #include <ATen/ops/empty_strided.h>
 #include <ATen/ops/_cudnn_attention_backward.h>
@@ -47,6 +48,7 @@
 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
 #else
+#include <ATen/native/transformers/hip/gemm_kernel_utils.h>
 // MemoryEfficient Attention Specific Imports for ROCM
 #ifndef DISABLE_AOTRITON
 #include <ATen/native/transformers/hip/aotriton_adapter.h>
@@ -544,12 +546,15 @@ _efficient_attention_backward(
   }
   const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
   bool is_causal;
-  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
-    is_causal = true;
-  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
+  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
     is_causal = false;
   } else {
-    TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
+    is_causal = true;
+#if AOTRITON_V3_API == 0
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
+      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
+    }
+#endif
   }
   at::Tensor q_t = query.permute({0,2,1,3});
   at::Tensor k_t = key.permute({0,2,1,3});
@@ -568,7 +573,62 @@ _efficient_attention_backward(
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::cast_dtype;
   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
-  if (cu_seqlens_q.has_value()) {
+  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
+#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    using aotriton::v3::flash::WindowValue;
+    aotriton::v3::flash::attn_bwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4;
+    params.Sm_scale = softmax_scale;
+    params.Out = mk_aotensor(out_t, "out");
+    params.DO = mk_aotensor(dout_t, "dout");
+    params.DK = mk_aotensor(dk_t, "dk");
+    params.DV = mk_aotensor(dv_t, "dv");
+    params.DQ = mk_aotensor(dq_t, "dq");
+    params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4;
+    params.L = mk_aotensor<2>(softmax_lse, "L");
+    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = float(dropout_p);
+    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
+    params.philox_offset1 = mk_aoscalartensor(philox_offset);
+    params.philox_offset2 = 0;
+    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
+      params.window_left = WindowValue::TopLeftAligned;
+      params.window_right = WindowValue::TopLeftAligned;
+    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
+      params.window_left = WindowValue::BottomRightAligned;
+      params.window_right = WindowValue::BottomRightAligned;
+    }
+#if AOTRITON_ALWAYS_V3_API
+    using sdp::aotriton_adapter::mklazy_empty_like;
+    using sdp::aotriton_adapter::mklazy_fp32zeros;
+    using sdp::aotriton_adapter::LazyTensorContext;
+    LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" };
+    LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
+    params.D = mklazy_empty_like<2>(&lazy_delta);
+    params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
+#else
+    at::Tensor delta = at::empty_like(softmax_lse).contiguous();
+    params.D = mk_aotensor<2>(delta, "delta");
+#endif
+    if (cu_seqlens_q.has_value()) {
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k");
+    } else {
+      params.varlen_type = VarlenType::None;
+    }
+    err = aotriton::v3::flash::attn_bwd(params,
+                                        aotriton::v3::flash::attn_bwd_params::kVersion,
+                                        stream);
+#endif // AOTRITON_V3_API
+  } else if (cu_seqlens_q.has_value()) {
     at::Tensor delta = at::empty_like(softmax_lse).contiguous();
     // varlen aka Nested tensor
     err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 21 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/Array.h>
 #include <c10/util/Exception.h>
+#include <c10/util/string_view.h>

 #if AT_CUDNN_ENABLED()
 #include <ATen/cudnn/cudnn-wrapper.h>
@@ -25,9 +26,12 @@

 #if USE_ROCM
 #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
+#include <ATen/native/transformers/hip/aotriton_versions.h>
 #include <aotriton/flash.h>
 #define USE_ROCM_ATTENTION 1
 #endif
+#else
+#define USE_ROCM_ATTENTION 0
 #endif

 // Avoid potential compiler -Wall -Werror complains undefined macro
@@ -129,9 +133,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
 // caller_is_meff is added to make the TORCH_WARN message showing the correct result
 template<bool caller_is_meff = false>
 bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
-#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
+#if USE_ROCM_ATTENTION
   // AOTriton 0.9+ supports head_dim up to 512
-  const auto max_size = c10::SymInt(512);
+  const static auto max_hdim = []() {
+#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
+    // gfx11xx only support hdim <= 256 on AOTriton 0.11
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    const c10::basic_string_view<char> arch(dprops->gcnArchName);
+    if (arch.starts_with("gfx11")) {
+      return 256;
+    }
+#endif // AOTriton 0.11
+#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
+    return 512;
+#else
+    return 256;
+#endif
+  }();
+  const auto max_size = c10::SymInt(max_hdim);
 #else
   // All head_dim sizes must be equal and less than 256
   const auto max_size = c10::SymInt(256);

aten/src/ATen/native/transformers/hip/aotriton_adapter.h

Lines changed: 59 additions & 0 deletions
@@ -2,8 +2,12 @@

 #ifdef USE_ROCM

+// Expect to be included after headers of at::zeros_like and at::empty_like
+
 #include <aotriton/dtypes.h>
 #include <aotriton/util.h>
+#include <aotriton/config.h>
+#include <ATen/native/transformers/hip/aotriton_versions.h>

 ////////////////////////////////////////////////////////////////////////////////
 // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
@@ -111,6 +115,61 @@ inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr)
                                    aotriton::DType::kInt32);
 }

+#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)
+
+struct LazyTensorContext {
+  at::Tensor like_tensor;
+  std::string_view tensor_name;
+  at::Tensor tensor;
+};
+
+template<int kRank, bool kRequireZeros>
+struct LazyTensorFunctions : public LazyTensorContext {
+  static aotriton::TensorView<kRank> acquire(void* cookie) {
+    auto ctx = (LazyTensorContext*)cookie;
+    if (!ctx->tensor.defined()) {
+      auto q = ctx->like_tensor;
+      if constexpr (kRequireZeros) {
+        ctx->tensor = at::zeros(q.sizes(),
+                                q.options().dtype(at::kFloat));
+      } else {
+        ctx->tensor = at::empty_like(q);
+      }
+    }
+    return mk_aotensor<kRank>(ctx->tensor, ctx->tensor_name);
+  }
+
+  static void dispose(void* cookie) {
+  }
+};
+
+template<int kRank, bool kRequireZeros>
+aotriton::LazyTensor<kRank> mklazy_common(LazyTensorContext* cookie)
+{
+  using LTF = LazyTensorFunctions<kRank, kRequireZeros>;
+  return aotriton::LazyTensor<kRank> {
+    .cookie = cookie,
+    .acquire = &LTF::acquire,
+    .dispose = &LTF::dispose
+  };
+}
+
+template<int kRank>
+auto mklazy_empty_like(LazyTensorContext* cookie)
+{
+  return mklazy_common<kRank, false>(cookie);
+}
+
+
+// Note: this will not keep the original strides
+template<int kRank>
+auto mklazy_fp32zeros(LazyTensorContext* cookie)
+{
+  return mklazy_common<kRank, true>(cookie);
+}
+
+#endif // >= 0.11
+
 } // namespace aotriton_adapter

 } // namespace sdp
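
A brief, hypothetical usage sketch of the lazy-tensor helpers added above (the tensor name, shape, and variable names are invented for illustration; in the real call sites only AOTriton's runtime invokes `acquire`):

```cpp
// Sketch assuming USE_ROCM, AOTriton >= 0.11, and aotriton_adapter.h included.
using sdp::aotriton_adapter::LazyTensorContext;
using sdp::aotriton_adapter::mklazy_empty_like;
using sdp::aotriton_adapter::mklazy_fp32zeros;

at::Tensor lse = at::empty({4, 128}, at::dtype(at::kFloat).device(at::kCUDA));

// The context only records which tensor to mimic; nothing is allocated here.
LazyTensorContext lazy_delta { .like_tensor = lse, .tensor_name = "delta" };
LazyTensorContext lazy_acc   { .like_tensor = lse, .tensor_name = "dq_acc" };

// lazy_delta.tensor stays undefined until the selected kernel calls acquire(),
// which then materializes at::empty_like(lse) behind a TensorView<2>.
auto delta_lazy = mklazy_empty_like<2>(&lazy_delta);

// The fp32zeros variant defers a zero-filled float32 buffer instead; note the
// header's comment that the "like" tensor's strides are not preserved.
auto acc_lazy = mklazy_fp32zeros<2>(&lazy_acc);
```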

aten/src/ATen/native/transformers/hip/aotriton_versions.h

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+#pragma once
+
+#ifdef USE_ROCM
+
+#define AOTRITON_VERSION_INT(x, y) (x * 100 + y)
+#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR)
+
+#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)
+#define AOTRITON_ALWAYS_V3_API 1
+#else
+#define AOTRITON_ALWAYS_V3_API 0
+#endif
+
+#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10)
+#define AOTRITON_V3_API 1
+#else
+#define AOTRITON_V3_API 0
+#endif
+
+#endif
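
A hedged sketch of how the attention sources above consume these gates: `AOTRITON_V3_API` hides V3-only symbols from pre-0.10 AOTriton headers at preprocessing time, while `AOTRITON_ALWAYS_V3_API` selects the branch with `if constexpr` so both paths stay readable in one function. The function name and elided bodies below are hypothetical:

```cpp
#include <ATen/native/transformers/hip/aotriton_versions.h>

// For an AOTriton 0.11 build: AOTRITON_VERSION_CURRENT == 0 * 100 + 11 == 11,
// so both AOTRITON_V3_API and AOTRITON_ALWAYS_V3_API expand to 1.
void dispatch_attention_sketch() {
  if constexpr (AOTRITON_ALWAYS_V3_API) {  // compile-time branch selection
#if AOTRITON_V3_API                        // keeps V3-only symbols out of older builds
    // ... fill aotriton::v3::flash::attn_fwd_params and call aotriton::v3::flash::attn_fwd
#endif
  } else {
    // ... legacy attn_fwd / attn_fwd_compact_varlen entry points
  }
}
```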
