10000 [release/2.4] [ROCM] Properly disable Flash Attention/Efficient Atten… · ROCm/pytorch@98d727f · GitHub
[go: up one dir, main page]

Skip to content

Commit 98d727f

Browse files
xinyazhangjithunnair-amd
authored andcommitted
[release/2.4] [ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#1570)
Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` can compile correctly. This is cherry-picked version of pytorch#133866
1 parent ebec4e9 commit 98d727f

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,16 @@ cmake_dependent_option(
874874
Will be disabled if not supported by the platform" ON
875875
"USE_CUDA OR USE_ROCM" OFF)
876876

877+
#
878+
# Cannot be put into Dependencies.cmake due circular dependency:
879+
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
880+
#
881+
if(USE_ROCM)
882+
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
883+
include(cmake/External/aotriton.cmake)
884+
endif()
885+
endif()
886+
877887
if(DEBUG_CUDA)
878888
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
879889
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
#include <c10/util/string_view.h>
2121

2222
#if USE_ROCM
23+
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
2324
#include <aotriton/flash.h>
25+
#define USE_AOTRITON 1
26+
#endif
2427
#endif
2528

2629
/**
@@ -185,6 +188,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
185188
using sm80 = SMVersion<8, 0>;
186189
using sm90 = SMVersion<9, 0>;
187190
#if USE_ROCM
191+
#if USE_AOTRITON
188192
auto stream = at::cuda::getCurrentCUDAStream().stream();
189193
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
190194
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -194,6 +198,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
194198
}
195199
return false;
196200
}
201+
#else
202+
return false;
203+
#endif
197204
#else
198205
auto dprops = at::cuda::getCurrentDeviceProperties();
199206
if (!check_sm_version<sm80, sm90>(dprops)) {
@@ -216,6 +223,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
216223
using sm50 = SMVersion<5, 0>;
217224
using sm90 = SMVersion<9, 0>;
218225
#if USE_ROCM
226+
#if USE_AOTRITON
219227
auto stream = at::cuda::getCurrentCUDAStream().stream();
220228
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
221229
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -225,6 +233,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
225233
}
226234
return false;
227235
}
236+
#else
237+
return false;
238+
#endif
228239
#else
229240
auto dprops = at::cuda::getCurrentDeviceProperties();
230241
if (!check_sm_version<sm50, sm90>(dprops)) {
@@ -238,8 +249,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
238249
}
239250
return false;
240251
}
241-
#endif
242252
return true;
253+
#endif
254+
return false;
243255
}
244256

245257
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(

cmake/Dependencies.cmake

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,6 @@ if(USE_ROCM)
10971097
message(STATUS "Disabling Kernel Assert for ROCm")
10981098
endif()
10991099

1100-
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
11011100
if(USE_CUDA)
11021101
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
11031102
endif()

0 commit comments

Comments
 (0)
0