[ROCm] Ck backend UX refactor (#152951) · pytorch/pytorch@5f5f508

Commit 5f5f508

alugorey, jeffdaily, jithunnair-amd, and janeyx99 authored and committed
[ROCm] Ck backend UX refactor (#152951)
Refactors how the enablement/disablement of CK GEMMs and SDPA works.

- Adds a USE_ROCM_CK_GEMM compile flag for enabling CK GEMMs; USE_ROCM_CK_GEMM is ON by default on Linux.
- Renames USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA; USE_ROCM_CK_SDPA is OFF by default. (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release.)
- Prevents these CK libraries from being used unless PyTorch has been built specifically with the functionality AND is running on a system architecture that supports it.
- The getters for these library backends also perform validity checking in case the user changed the backend via an environment variable; if the selection is invalid (i.e. one of the conditions above is false), the backend is reset to the current non-CK default.

Pull Request resolved: #152951
Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
1 parent da1f608 commit 5f5f508
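
Below is a minimal, hypothetical C++ sketch (not part of this commit) of the user-facing behavior described above. It assumes a ROCm build of PyTorch configured with USE_ROCM_CK_GEMM=ON and USE_ROCM_CK_SDPA=ON running on a supported gfx architecture; on any other configuration the setters throw via TORCH_CHECK and the getters warn and fall back to the non-CK defaults.

// Hedged usage sketch against the ATen Context API touched by this PR.
// Assumes a ROCm build with USE_ROCM_CK_GEMM=ON and USE_ROCM_CK_SDPA=ON
// on a supported architecture (e.g. gfx90a/gfx942); otherwise the setters
// raise and the getter warns and falls back to AOTriton.
#include <ATen/Context.h>

void prefer_ck_backends() {
  auto& ctx = at::globalContext();

  // GEMM path: rejected unless PyTorch was built with CK GEMM support
  // and every visible device is a CK-capable architecture.
  ctx.setBlasPreferredBackend(at::BlasBackend::Ck);

  // SDPA / flash-attention path: same validation for the ROCm FA backend.
  ctx.setROCmFAPreferredBackend(at::ROCmFABackend::Ck);

  // Backends injected via env vars (e.g. TORCH_ROCM_FA_PREFER_CK) are
  // re-validated in the getter and reset to the non-CK default if invalid.
  at::ROCmFABackend fa = ctx.getROCmFAPreferredBackend();
  (void)fa;
}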

File tree

23 files changed (+232, -105 lines)


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -240,6 +240,8 @@ cmake_dependent_option(
 BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
 "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
 cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
+cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)
+option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF)
 option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
 cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
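
These CMake options surface in the C++ sources as compile definitions, which is how the CUDABlas.cpp changes further down gate the CK code paths. A tiny illustrative sketch of the guard pattern (ck_gemm_compiled_in() is a hypothetical helper, not PyTorch code):

// Illustrative only: the preprocessor guard mirrors the one used in
// aten/src/ATen/cuda/CUDABlas.cpp below; the helper name is made up.
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
constexpr bool ck_gemm_compiled_in() { return true; }   // CK GEMM kernels built in
#else
constexpr bool ck_gemm_compiled_in() { return false; }  // fall back to the default hipBLAS/hipBLASLt path
#endif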

aten/src/ATen/CMakeLists.txt

Lines changed: 58 additions & 50 deletions
@@ -180,26 +180,27 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
 file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
 # if USE_FLASH_ATTENTION is set, ensure CK instances get generated
 if(USE_FLASH_ATTENTION)
-if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
-set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
-if(USE_CK_FLASH_ATTENTION STREQUAL "1")
-if(DEFINED ENV{PYTORCH_ROCM_ARCH})
-list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
-if(NUM_ARCHS GREATER 1)
-message(WARNING "Building CK for multiple archs can increase build time considerably!
-Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
-endif()
-endif()
-message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
-message(STATUS "Generating CK kernel instances...")
-add_subdirectory(native/transformers/hip/flash_attn/ck)
-file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
-list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
-# FAv3 Generation
-add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
-file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
-list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
+if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
+message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
+caffe2_update_option(USE_ROCM_CK_SDPA ON)
+endif()
+if(USE_ROCM_CK_SDPA)
+if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
+if(NUM_ARCHS GREATER 1)
+message(WARNING "Building CK for multiple archs can increase build time considerably!
+Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
 endif()
+endif()
+message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
+message(STATUS "Generating CK kernel instances...")
+add_subdirectory(native/transformers/hip/flash_attn/ck)
+file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
+list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
+# FAv3 Generation
+add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
+file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
+list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
 endif()
 file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
 file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
@@ -418,40 +419,42 @@ if(USE_CUDA)
 endif()

 if(USE_ROCM)
-# NOTE: The PyTorch build does not actually add_subdirectory
-# third_party/composable_kernel or use it as a CMake library. What is used
-# is header only, so this should be ok, except that the CMake build generates
-# a ck/config.h. We just do that part here. Without this, the ck.h from the
-# ROCM SDK may get accidentally used instead.
-function(_pytorch_rocm_generate_ck_conf)
-set(CK_ENABLE_INT8 "ON")
-set(CK_ENABLE_FP16 "ON")
-set(CK_ENABLE_FP32 "ON")
-set(CK_ENABLE_FP64 "ON")
-set(CK_ENABLE_BF16 "ON")
-set(CK_ENABLE_FP8 "ON")
-set(CK_ENABLE_BF8 "ON")
-set(CK_USE_XDL "ON")
-set(CK_USE_WMMA "ON")
-configure_file(
-"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
-"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
-)
-endfunction()
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
-_pytorch_rocm_generate_ck_conf()
+if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
+# NOTE: The PyTorch build does not actually add_subdirectory
+# third_party/composable_kernel or use it as a CMake library. What is used
+# is header only, so this should be ok, except that the CMake build generates
+# a ck/config.h. We just do that part here. Without this, the ck.h from the
+# ROCM SDK may get accidentally used instead.
+function(_pytorch_rocm_generate_ck_conf)
+set(CK_ENABLE_INT8 "ON")
+set(CK_ENABLE_FP16 "ON")
+set(CK_ENABLE_FP32 "ON")
+set(CK_ENABLE_FP64 "ON")
+set(CK_ENABLE_BF16 "ON")
+set(CK_ENABLE_FP8 "ON")
+set(CK_ENABLE_BF8 "ON")
+set(CK_USE_XDL "ON")
+set(CK_USE_WMMA "ON")
+configure_file(
+"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
+"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
+)
+endfunction()
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
+_pytorch_rocm_generate_ck_conf()
+endif()

 # Next two lines are needed because TunableOp uses third-party/fmt
 list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
 list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
-if(USE_FLASH_ATTENTION)
-list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
-endif()
+if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
+list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
+endif()
 list(APPEND ATen_HIP_SRCS
 ${ATen_HIP_SRCS}
 ${hip_hip}
@@ -461,12 +464,17 @@ endif()
 ${native_quantized_hip_hip}
 ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
 )
-if(WIN32) # Windows doesn't support Composable Kernels
+if(NOT USE_ROCM_CK_GEMM)
 file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
 file(GLOB native_hip_ck "native/hip/ck*.hip")
 exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
 ${native_hip_bgemm} ${native_hip_ck})
 endif()
+if(WIN32) # Windows doesn't support Composable Kernels and Triton
+exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
+${native_transformers_hip_hip} ${native_transformers_hip_cpp})
+endif()
+
 # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
 list(APPEND all_hip_cpp
 ${native_nested_hip_cpp}

aten/src/ATen/Context.cpp

Lines changed: 62 additions & 26 deletions
@@ -480,6 +480,9 @@ at::BlasBackend Context::blasPreferredBackend() {
 // call site for blasPreferredBackend(), we set it to an actual value.
 if (blas_preferred_backend == at::BlasBackend::Default) {
 blas_preferred_backend = at::BlasBackend::Cublas;
+// This logic sits in the getter because it needs to validate
+// values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT
+// which initialize the backend without calling the setter
 #ifdef USE_ROCM
 // AMD Instinct targets prefer hipblaslt
 static const bool hipblaslt_preferred = []() {
@@ -509,6 +512,10 @@ at::BlasBackend Context::blasPreferredBackend() {
 // hipblaslt support for all archs is not as complete as hipblas
 if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
 static const bool hipblaslt_unsupported = []() {
+if(!hasCuBLASLt())
+{
+return true;
+}
 static const std::vector<std::string> archs = {
 "gfx90a", "gfx942",
 #if ROCM_VERSION >= 60300
@@ -534,6 +541,24 @@ at::BlasBackend Context::blasPreferredBackend() {
 return blas_preferred_backend;
 }

+bool Context::ckSupported() {
+#ifdef USE_ROCM
+static const std::vector<std::string> supported_archs = {
+"gfx90a", "gfx942", "gfx950"
+};
+for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
+if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
+TORCH_WARN_ONCE(
+"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
+return false;
+}
+}
+return true;
+#else
+return false;
+#endif
+}
+
 void Context::setBlasPreferredBackend(at::BlasBackend b) {
 #ifdef _MSC_VER
 TORCH_WARN_ONCE(
@@ -543,8 +568,14 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
 #else
 TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
 "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
-TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
-"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
+#ifdef USE_ROCM
+static const bool ckSupportedFlag = ckSupported();
+static const bool hasCKGEMMFlag = hasCKGEMM();
+TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag),
+"Cannot set preferred blas backend to CK since following conditions are not true: ",
+"architecture supported for CK: ", ckSupportedFlag,
+", PyTorch built with CK GEMM support: ", hasCKGEMMFlag);
+#endif
 if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
 TORCH_WARN_ONCE(
 "torch.backends.cuda.preferred_blas_library is an experimental feature. "
@@ -556,35 +587,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
 #endif
 }

-at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
+at::ROCmFABackend Context::getROCmFAPreferredBackend() {
+#ifdef USE_ROCM
+// Set potential "Default" value so we don't have to interpret at call sites.
+// We use aotriton backend as the default, for now.
+if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) {
+rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
+} else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) {
+// This logic sits in the getter because it needs to validate
+// values set via env vars such as TORCH_ROCM_FA_PREFER_CK
+// which initialize the backend without calling the setter
+// Perform validity checking
+static const bool hasCKSDPAFlag = hasCKSDPA();
+static const bool ckSupportedFlag = ckSupported();
+if(!(hasCKSDPAFlag && ckSupportedFlag)){
+TORCH_WARN_ONCE(
+"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
+"architecture supported for CK: ", ckSupportedFlag,
+", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
+rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
+}
+}
+#endif
+
 return rocm_fa_preferred_backend;
 }

 void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
-
-// TODO: add plumbing for hasCK for validity checking
-TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
-"Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
 #ifdef USE_ROCM
-if(b == at::ROCmFABackend::Ck) {
-static const bool ck_unsupported = []() {
-static const std::vector<std::string> archs = {
-"gfx90a", "gfx942"
-};
-for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
-if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
-TORCH_WARN_ONCE(
-"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
-return true;
-}
-}
-return false;
-}();
-if(!ck_unsupported) rocm_fa_preferred_backend = b;
-}
-else {
-rocm_fa_preferred_backend = b;
-}
+static const bool hasCKSDPAFlag = hasCKSDPA();
+static const bool ckSupportedFlag = ckSupported();
+TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag),
+"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
+"architecture supported for CK: ", ckSupportedFlag,
+", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
 #endif
 rocm_fa_preferred_backend = b;
 }

aten/src/ATen/Context.h

Lines changed: 8 additions & 1 deletion
@@ -132,6 +132,7 @@ class TORCH_API Context {
 static bool hasKleidiAI();
 static bool hasLAPACK();
 static bool hasMKLDNN();
+static bool ckSupported();
 static bool hasMAGMA() {
 return detail::getCUDAHooks().hasMAGMA();
 }
@@ -162,6 +163,12 @@ class TORCH_API Context {
 static bool hasROCM() {
 return detail::getCUDAHooks().hasROCM();
 }
+static bool hasCKSDPA() {
+return detail::getCUDAHooks().hasCKSDPA();
+}
+static bool hasCKGEMM() {
+return detail::getCUDAHooks().hasCKGEMM();
+}
 static bool hasHIP() {
 return detail::getHIPHooks().hasHIP();
 }
@@ -252,7 +259,7 @@ class TORCH_API Context {
 at::BlasBackend blasPreferredBackend();
 void setBlasPreferredBackend(at::BlasBackend);

-at::ROCmFABackend getROCmFAPreferredBackend() const;
+at::ROCmFABackend getROCmFAPreferredBackend();
 void setROCmFAPreferredBackend(at::ROCmFABackend);

 // Note [Enabling Deterministic Operations]

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 5 additions & 5 deletions
@@ -832,7 +832,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
 bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
 }
 }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
 }
@@ -1273,7 +1273,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
 gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
 #endif
 }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
 }
@@ -1289,7 +1289,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
 if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
 }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
 gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
@@ -1341,7 +1341,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
 if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
 }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
 }
@@ -1357,7 +1357,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
 if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
 }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
 }

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 21 additions & 0 deletions
@@ -207,6 +207,27 @@ bool CUDAHooks::hasCuBLASLt() const {
 #endif
 }

+
+bool CUDAHooks::hasCKSDPA() const {
+#if !defined(USE_ROCM)
+return false;
+#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA)
+return true;
+#else
+return false;
+#endif
+}
+
+bool CUDAHooks::hasCKGEMM() const {
+#if !defined(USE_ROCM)
+return false;
+#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
+return true;
+#else
+return false;
+#endif
+}
+
 bool CUDAHooks::hasROCM() const {
 // Currently, this is same as `compiledWithMIOpen`.
 // But in future if there are ROCm builds without MIOpen,

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
 bool hasCuSOLVER() const override;
 bool hasCuBLASLt() const override;
 bool hasROCM() const override;
+bool hasCKSDPA() const override;
+bool hasCKGEMM() const override;
 const at::cuda::NVRTC& nvrtc() const override;
 DeviceIndex current_device() const override;
 bool isBuilt() const override {return true;}
