[SDPA-CUDNN] Make CuDNN Attention Opt in · pytorch/pytorch@df8058f · GitHub

Commit df8058f

[SDPA-CUDNN] Make CuDNN Attention Opt in
ghstack-source-id: 58f70e7 Pull Request resolved: #138522
1 parent ebd60f4 commit df8058f

File tree: 1 file changed (+13 / −8 lines)

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 13 additions & 8 deletions
@@ -64,20 +64,25 @@ bool check_prefer_cudnn_attention() {
 #endif
 }
 
+// static const bool prefer_cudnn = check_prefer_cudnn_attention();
+// return prefer_cudnn ? cudnn_order : default_order;
+// return default_order
+// constexpr std::array<SDPBackend, num_backends> cudnn_order{
+//     SDPBackend::cudnn_attention,
+//     SDPBackend::flash_attention,
+//     SDPBackend::efficient_attention,
+//     SDPBackend::math,
+// };
+
 // flash_attention V2 is universally faster than efficient_attention and Math
 std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
   constexpr std::array<SDPBackend, num_backends> default_order{
       SDPBackend::flash_attention,
-      SDPBackend::cudnn_attention,
       SDPBackend::efficient_attention,
-      SDPBackend::math};
-  constexpr std::array<SDPBackend, num_backends> cudnn_order{
+      SDPBackend::math,
       SDPBackend::cudnn_attention,
-      SDPBackend::flash_attention,
-      SDPBackend::efficient_attention,
-      SDPBackend::math};
-  static const bool prefer_cudnn = check_prefer_cudnn_attention();
-  return prefer_cudnn ? cudnn_order : default_order;
+  };
+  return default_order;
 }
 
 bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
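
With this change, the default priority order places cuDNN attention last (after math, which essentially always applies), so in practice it is no longer auto-selected and must be requested explicitly. A minimal sketch of opting in from Python, assuming a CUDA build of a recent PyTorch where torch.nn.attention.sdpa_kernel and SDPBackend.CUDNN_ATTENTION are available; the tensor shapes and dtype below are illustrative only:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative query/key/value tensors on a CUDA device.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict scaled_dot_product_attention to the cuDNN backend inside this context.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

Outside such a context (or an equivalent backend toggle), the dispatcher now follows the default_order above: flash attention, then efficient attention, then math.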
