@@ -180,26 +180,27 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
 file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
 # if USE_FLASH_ATTENTION is set, ensure CK instances get generated
 if(USE_FLASH_ATTENTION)
-  if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
-    set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
-    if(USE_CK_FLASH_ATTENTION STREQUAL "1")
-      if(DEFINED ENV{PYTORCH_ROCM_ARCH})
-        list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
-        if(NUM_ARCHS GREATER 1)
-          message(WARNING "Building CK for multiple archs can increase build time considerably!
-          Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
-        endif()
-      endif()
-      message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
-      message(STATUS "Generating CK kernel instances...")
-      add_subdirectory(native/transformers/hip/flash_attn/ck)
-      file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
-      list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
-      # FAv3 Generation
-      add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
-      file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
-      list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
+  if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
+    message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
+    caffe2_update_option(USE_ROCM_CK_SDPA ON)
+  endif()
+  if(USE_ROCM_CK_SDPA)
+    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+      list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
+      if(NUM_ARCHS GREATER 1)
+        message(WARNING "Building CK for multiple archs can increase build time considerably!
+        Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
       endif()
+    endif()
+    message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
+    message(STATUS "Generating CK kernel instances...")
+    add_subdirectory(native/transformers/hip/flash_attn/ck)
+    file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
+    list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
+    # FAv3 Generation
+    add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
+    file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
+    list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
   endif()
   file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
   file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
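The hunk above keeps honoring the legacy USE_CK_FLASH_ATTENTION environment variable by folding it into the new USE_ROCM_CK_SDPA option before the real check runs. A minimal standalone sketch of that deprecation pattern, using a plain option() plus set(... FORCE) instead of PyTorch's caffe2_update_option helper (the option description and default below are assumptions, not taken from the patch):

# Sketch only: forward a deprecated environment variable to the newer cache option
# so all downstream logic reads a single flag.
option(USE_ROCM_CK_SDPA "Build the composable_kernel SDPA backend" OFF)
if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
  message(STATUS "USE_CK_FLASH_ATTENTION is deprecated; set USE_ROCM_CK_SDPA=ON instead")
  set(USE_ROCM_CK_SDPA ON CACHE BOOL "Build the composable_kernel SDPA backend" FORCE)
endif()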
@@ -418,40 +419,42 @@ if(USE_CUDA)
 endif()
 
 if(USE_ROCM)
-  # NOTE: The PyTorch build does not actually add_subdirectory
-  # third_party/composable_kernel or use it as a CMake library. What is used
-  # is header only, so this should be ok, except that the CMake build generates
-  # a ck/config.h. We just do that part here. Without this, the ck.h from the
-  # ROCM SDK may get accidentally used instead.
-  function(_pytorch_rocm_generate_ck_conf)
-    set(CK_ENABLE_INT8 "ON")
-    set(CK_ENABLE_FP16 "ON")
-    set(CK_ENABLE_FP32 "ON")
-    set(CK_ENABLE_FP64 "ON")
-    set(CK_ENABLE_BF16 "ON")
-    set(CK_ENABLE_FP8 "ON")
-    set(CK_ENABLE_BF8 "ON")
-    set(CK_USE_XDL "ON")
-    set(CK_USE_WMMA "ON")
-    configure_file(
-      "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
-      "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
-    )
-  endfunction()
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
-  _pytorch_rocm_generate_ck_conf()
+  if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
+    # NOTE: The PyTorch build does not actually add_subdirectory
+    # third_party/composable_kernel or use it as a CMake library. What is used
+    # is header only, so this should be ok, except that the CMake build generates
+    # a ck/config.h. We just do that part here. Without this, the ck.h from the
+    # ROCM SDK may get accidentally used instead.
+    function(_pytorch_rocm_generate_ck_conf)
+      set(CK_ENABLE_INT8 "ON")
+      set(CK_ENABLE_FP16 "ON")
+      set(CK_ENABLE_FP32 "ON")
+      set(CK_ENABLE_FP64 "ON")
+      set(CK_ENABLE_BF16 "ON")
+      set(CK_ENABLE_FP8 "ON")
+      set(CK_ENABLE_BF8 "ON")
+      set(CK_USE_XDL "ON")
+      set(CK_USE_WMMA "ON")
+      configure_file(
+        "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
+        "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
+      )
+    endfunction()
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
+    _pytorch_rocm_generate_ck_conf()
+  endif()
 
   # Next two lines are needed because TunableOp uses third-party/fmt
   list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
   list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
-  if(USE_FLASH_ATTENTION)
-    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
-  endif()
+  if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
+  endif()
   list(APPEND ATen_HIP_SRCS
     ${ATen_HIP_SRCS}
     ${hip_hip}
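The NOTE carried over in this hunk explains the motivation: composable_kernel is consumed header-only, so the build generates its ck/config.h from the template itself and adds the build directory to the include path; otherwise the ck.h shipped with the ROCm SDK may get used by accident. A standalone sketch of that configure_file pattern, with a hypothetical template path and project name (not part of the patch):

# Sketch only: generate a header from a .in template and point the include path at it.
cmake_minimum_required(VERSION 3.18)
project(ck_config_demo LANGUAGES NONE)

set(CK_ENABLE_FP16 "ON")  # variables in scope are substituted into the template
configure_file(
  "${CMAKE_CURRENT_SOURCE_DIR}/config.h.in"                    # hypothetical template
  "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
)
# #include <ck/config.h> now resolves to the generated copy as long as this directory
# is searched before the SDK's include directories.
include_directories(BEFORE "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel")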
@@ -461,12 +464,17 @@ endif()
     ${native_quantized_hip_hip}
     ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
   )
-  if(WIN32) # Windows doesn't support Composable Kernels
+  if(NOT USE_ROCM_CK_GEMM)
     file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
     file(GLOB native_hip_ck "native/hip/ck*.hip")
     exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
       ${native_hip_bgemm} ${native_hip_ck})
   endif()
+  if(WIN32) # Windows doesn't support Composable Kernels and Triton
+    exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
+      ${native_transformers_hip_hip} ${native_transformers_hip_cpp})
+  endif()
+
   # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
   list(APPEND all_hip_cpp
     ${native_nested_hip_cpp}
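exclude() in this hunk is a PyTorch build helper rather than a CMake builtin; it filters the listed files out of the given source variable. A helper-free sketch of the same two exclusions using plain list(REMOVE_ITEM), assuming the globbed lists are non-empty (variable names taken from the hunk):

# Sketch only: drop CK GEMM kernels when that backend is off, and drop the
# transformer HIP sources entirely on Windows.
if(NOT USE_ROCM_CK_GEMM)
  file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
  file(GLOB native_hip_ck "native/hip/ck*.hip")
  list(REMOVE_ITEM ATen_HIP_SRCS ${native_hip_bgemm} ${native_hip_ck})
endif()
if(WIN32) # Windows doesn't support Composable Kernels and Triton
  list(REMOVE_ITEM ATen_HIP_SRCS ${native_transformers_hip_hip} ${native_transformers_hip_cpp})
endif()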