ROCm/pytorch · Commit 1b935e2
[release/2.3] [ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#1571)
Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` compiles correctly. This is a cherry-picked version of pytorch#133866.

Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>
1 parent 772df6b commit 1b935e2
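The mechanism behind that one-liner is ordinary preprocessor gating: when the environment variable is 0, the build never defines the `USE_FLASH_ATTENTION` compile definition, and the C++ side derives a `USE_AOTRITON` macro from it (see the sdp_utils.cpp diff below). Here is a minimal sketch of that pattern; it is illustrative only, not PyTorch source, and `probe_backend` is a hypothetical stand-in for a real capability check such as aotriton's `check_gpu`:

```cpp
// Illustrative sketch of the gating pattern, not PyTorch source.
// USE_FLASH_ATTENTION arrives (or not) as a compile definition, and
// USE_AOTRITON is derived from it the same way sdp_utils.cpp does.
#include <cstdio>

#if defined(USE_FLASH_ATTENTION)
#define USE_AOTRITON 1
#endif

// Hypothetical stand-in for a real capability probe such as
// aotriton::v2::flash::check_gpu.
[[maybe_unused]] static bool probe_backend() { return true; }

bool flash_attention_supported() {
#if USE_AOTRITON
  return probe_backend();  // backend compiled in: ask it
#else
  return false;            // backend compiled out: always unsupported
#endif
}

int main() {
  std::printf("flash attention supported: %d\n",
              flash_attention_supported() ? 1 : 0);
}
```

Building with `g++ -DUSE_FLASH_ATTENTION sketch.cpp` takes the first branch; building without the define reports the backend as unsupported, which is how a `USE_FLASH_ATTENTION=0` build behaves.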

File tree: 3 files changed, 22 insertions(+), 1 deletion(-)

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
```diff
@@ -773,6 +773,16 @@ cmake_dependent_option(
   Will be disabled if not supported by the platform" ON
   "USE_CUDA" OFF)
 
+#
+# Cannot be put into Dependencies.cmake due circular dependency:
+# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
+#
+if(USE_ROCM)
+  if(USE_FLASH_ATTENTION)
+    include(cmake/External/aotriton.cmake)
+  endif()
+endif()
+
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
```

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -22,7 +22,10 @@
 #include <functional>
 
 #if USE_ROCM
+#if defined(USE_FLASH_ATTENTION)
 #include <aotriton/flash.h>
+#define USE_AOTRITON 1
+#endif
 #endif
 
 /**
@@ -187,6 +190,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug)
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
+#if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -196,6 +200,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug)
       }
       return false;
   }
+#else
+  return false;
+#endif
 #else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm80, sm90>(dprops)) {
@@ -217,6 +224,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
+#if USE_ROCM
+  return false;
+#else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm50, sm90>(dprops)) {
     if (debug) {
@@ -230,6 +240,8 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
     return false;
   }
   return true;
+#endif
+  return false;
 }
 
 bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
```
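Two details of these hunks are worth spelling out. First, the `#if defined(USE_FLASH_ATTENTION)` guard around the include matters because aotriton is only built at all when the CMakeLists.txt change above pulls in aotriton.cmake, so an unguarded reference to `aotriton::v2::flash::check_gpu` would break a `USE_FLASH_ATTENTION=0` build. Second, on ROCm this branch now reports memory-efficient attention as unsupported unconditionally. A condensed sketch of the resulting control flow, illustrative rather than the real function (the actual code takes `sdp_params`, queries device properties, and warns when `debug` is set):

```cpp
// Condensed, illustrative control flow of check_mem_efficient_hardware_support
// after this patch; not the real PyTorch function.
bool mem_efficient_supported_sketch([[maybe_unused]] bool device_in_sm50_to_sm90) {
#if USE_ROCM
  // ROCm build: reported unsupported before any device query can run.
  return false;
#else
  // CUDA build: supported only for hardware in the [sm_50, sm_90] range.
  if (!device_in_sm50_to_sm90) {
    return false;
  }
  return true;
#endif
  return false;  // unreachable, but keeps every preprocessor path returning
}
```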

cmake/Dependencies.cmake

Lines changed: 0 additions & 1 deletion
```diff
@@ -1348,7 +1348,6 @@ if(USE_ROCM)
     message(STATUS "Disabling Kernel Assert for ROCm")
   endif()
 
-  include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
   if(USE_CUDA)
     caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
   endif()
```
