diff --git a/.gitmodules b/.gitmodules index d16e9335..dd976a23 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third-party/cutlass"] path = third-party/cutlass - url = https://github.com/NVIDIA/cutlass.git + url = git@github.com:NVIDIA/cutlass.git +[submodule "third-party/fmt"] + path = third-party/fmt + url = git@github.com:fmtlib/fmt.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 658aa7bd..ab20d622 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,44 +1,33 @@ # NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT -# TODO: add CUDA utils' library via CMake cmake_minimum_required(VERSION 3.10) project(deep_gemm LANGUAGES CXX CUDA) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CUDA_STANDARD 20) set(CMAKE_VERBOSE_MAKEFILE ON) -find_package(CUDAToolkit REQUIRED) -find_package(pybind11 REQUIRED) - -file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }") -execute_process( - COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu - RESULT_VARIABLE NVCC_RESULT - OUTPUT_VARIABLE NVCC_OUTPUT - ERROR_VARIABLE NVCC_ERROR_OUTPUT - WORKING_DIRECTORY ${CMAKE_BINARY_DIR} -) +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi") +set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") +list(APPEND CUDA_NVCC_FLAGS "-O3") +list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") -if (NVCC_RESULT EQUAL "0") - set(NVCC_SUPPORTS_SM90 TRUE) - message(STATUS "NVCC supports SM90") -else() - message(STATUS "NVCC does not support SM90") -endif() +set(USE_SYSTEM_NVTX on) +set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") +set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") -if (NVCC_SUPPORTS_SM90) - set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE) - list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a") -endif() +find_package(CUDAToolkit REQUIRED) +find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) -include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include) -include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) -link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CUDA_STANDARD 20) + +include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) +link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG") -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10") +# The main Python API entrance +pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp) +target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda) -cuda_add_library(example_gemm STATIC indexing/main.cu) +# Enable kernel code indexing with CMake-based IDEs 
+cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu) diff --git a/README.md b/README.md index 8df722aa..01e1f540 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,18 @@ # DeepGEMM -DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. +DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (work in progress) for both normal and Mixture-of-Experts (MoE) grouped scenarios. Written in CUDA, the library requires no kernel compilation during installation: all kernels are compiled at runtime by a lightweight Just-In-Time (JIT) module. -Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques. +While DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques. Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. ## News +- 2025.07.20: DeepGEMM now supports both SM90 and SM100, and has been fully refactored around a low-CPU-overhead C++ JIT module. + - NVRTC and post-compilation SASS optimization are both disabled + - NVRTC will be supported later + - As NVCC 12.9 automatically performs FFMA interleaving, post-compilation optimizations will no longer be supported + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. 
@@ -16,57 +21,59 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] More correctness tests for grouped-contiguous layout - [x] Shared memory swizzling for output -- [ ] Larger block size on N (up to 256) - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [x] Skip useless computation on M -- [x] NVRTC as a faster compiler -- [ ] Stolen JIT cache +- [ ] NVRTC as a faster compiler - [ ] Sanitizer for testing - [x] Weight gradient kernels for dense models - [x] Weight gradient kernels for MoE models - [ ] Better `get_best_configs` modeling -- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang)) - [ ] CUDA PDL support -- [ ] More scaling granularity support via templates - [ ] Larger TMA multicast size for some shapes - [x] MMA template refactor with CUTLASS -- [ ] Optimizations for power efficiency - [x] Remove shape limitations on N and K - [ ] BF16 kernels - [ ] Split/stream-k optimizations +- [ ] Ampere kernels +- [ ] Polish docs ## Quick start ### Requirements -- Hopper architecture GPUs, `sm_90a` must be supported -- Python 3.8 or above -- CUDA 12.3 or above - - **But we highly recommend 12.8 or above for the best performance** -- PyTorch 2.1 or above -- CUTLASS 3.6 or above (could be cloned by Git submodule) +- NVIDIA SM90 or SM100 architecture GPU +- Python 3.8 or higher +- Compilers with C++20 support +- CUDA Toolkit: + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 +- PyTorch 2.1 or higher +- CUTLASS 4.0 or higher (could be cloned by Git submodule) +- `{fmt}` library (could be cloned by Git submodule) ### Development ```bash # Submodule must be cloned git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git +cd DeepGEMM -# Make symbolic links for third-party (CUTLASS and CuTe) include directories -python setup.py develop +# Link some essential includes and build the CPP JIT module +cat develop.sh +./develop.sh -# Test JIT compilation -python tests/test_jit.py - -# Test all GEMM implements (normal, contiguous-grouped and masked-grouped) +# Test all GEMM implements +python tests/test_layout.py python tests/test_core.py ``` ### Installation ```bash -python setup.py install +cat install.sh +./install.sh ``` Then, import `deep_gemm` in your Python project, and enjoy! @@ -75,118 +82,61 @@ Then, import `deep_gemm` in your Python project, and enjoy! #### Notices -This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. +This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T` + +For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. 
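To make the suffix convention concrete, here is a minimal PyTorch sketch of the reference semantics behind the four layout suffixes; `reference_gemm` is a hypothetical helper for illustration only (the real kernels operate on FP8 inputs with scaling factors and are far more involved).

```python
import torch

def reference_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, layout: str) -> torch.Tensor:
    # Reference semantics only: each suffix letter says whether A or B is used transposed.
    # "nt" (non-transposed A, transposed B) therefore computes D = C + A @ B.T.
    assert layout in ("nt", "nn", "tn", "tt")
    lhs = a.T if layout[0] == "t" else a
    rhs = b.T if layout[1] == "t" else b
    return c + lhs @ rhs

# The "nt" case: A is (m, k) row-major, B is stored as (n, k), i.e. a col-major RHS
m, n, k = 128, 256, 512
a = torch.randn(m, k)
b = torch.randn(n, k)
c = torch.zeros(m, n)
d = reference_gemm(a, b, c, "nt")  # D = C + A @ B.T
```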
The required data format for the scaling factors also differs between SM90 and SM100: + +- SM90 requires scaling factors in FP32 format. +- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 values into a single `torch.int`. + +Please note that operations like input transposition or FP8 casting must be handled separately by the user; implement them or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions for these steps, they may result in slower performance; our primary focus is on optimizing the GEMM kernels themselves. #### Normal dense GEMMs (non-grouped) -To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation. +To perform a basic non-grouped FP8 GEMM, call one of the `fp8_gemm_{nt, nn, tn, tt}` functions. For more details, please refer to the function documentation. #### Grouped GEMMs (contiguous layout) -Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. +Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation. -For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`). - -For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation. +We also provide a K-axis-grouped API for the MoE weight backward pass (where M and N must remain fixed); please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information. #### Grouped GEMMs (masked layout) During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. -Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. +Use `fp8_m_grouped_gemm_nt_masked` for this purpose and consult the relevant documentation. A typical use case is taking the output of the low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. 
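To spell out what "computes only the valid portions" means, here is a minimal PyTorch reference sketch; the per-group tensor shapes and the `masked_m` argument name are assumptions for illustration, not the exact API.

```python
import torch

def masked_grouped_gemm_reference(a: torch.Tensor, b: torch.Tensor, masked_m: torch.Tensor) -> torch.Tensor:
    # Assumed shapes for illustration: a is (num_groups, m_max, k), b is (num_groups, n, k),
    # i.e. NT layout with one expert per group. masked_m[g] gives the number of valid rows
    # of expert g; rows beyond it are simply skipped.
    num_groups, m_max, _ = a.shape
    n = b.shape[1]
    d = torch.zeros(num_groups, m_max, n)
    for g in range(num_groups):
        valid = int(masked_m[g])
        d[g, :valid] = a[g, :valid] @ b[g].T  # only the valid portion is computed
    return d

# Four experts receiving different token counts (unknown to the CPU at graph-capture time)
a = torch.randn(4, 64, 128)
b = torch.randn(4, 256, 128)
masked_m = torch.tensor([3, 64, 17, 0])
d = masked_grouped_gemm_reference(a, b, masked_m)
```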
#### Utilities The library provides some utility functions besides the above kernels: - `deep_gemm.set_num_sms`: set the maximum SM count to use -- `deep_gemm.get_num_sms`: get the current SM maximum count -- `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout +- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set) +- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size -- `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor +- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout +- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor +- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0) +- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel The library also provides some environment variables, which may be useful: - General - - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default - JIT cache related - - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default + - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - NVCC/NVRTC selections - - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default - - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default - Compiler options - - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default - - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default - - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default - - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default -- Post optimization - - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default - Heuristic selection - - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default -- Testing - - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. -## Optimizations - -We indicate the techniques excluded from CUTLASS with 🐳. 
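Relatedly, here is a rough sketch of the FP32-to-packed-UE8M0 conversion behind the packed-UE8M0 utilities listed above. It shows only the bit packing (the real helpers also produce the MN-major, TMA-aligned layout), and the byte ordering below is an assumption.

```python
import torch

def pack_fp32_scales_to_ue8m0(sf: torch.Tensor) -> torch.Tensor:
    # UE8M0 keeps only the 8-bit biased exponent of an FP32 power-of-two scale;
    # four such bytes are packed into one `torch.int` (int32).
    assert sf.dtype == torch.float32 and sf.shape[-1] % 4 == 0
    exponents = ((sf.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)
    # Reinterpret each group of 4 consecutive exponent bytes as a single int32
    # (native byte order assumed; the real layout may differ).
    return exponents.contiguous().view(torch.int32)

scales = torch.tensor([[1.0, 2.0, 0.5, 4.0]], dtype=torch.float32)
print(pack_fp32_scales_to_ue8m0(scales))  # one int32 per 4 scaling factors
```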
- -#### Persistent warp-specialization - -Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below: - -![design](figures/design.png) - -#### Hopper TMA features - -The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for: - -- TMA load for LHS, LHS scaling factors, and RHS matrices -- TMA store for the output matrix -- TMA multicast (automatically decide LHS or RHS to broadcast) -- TMA descriptor prefetching - -#### Common detail optimizations - -- Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction -- [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups -- Less bank conflicts via 3D TMA or swizzling -- Larger block sizes (up to 256x128 🐳) -- Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳 - -#### A unified and optimized block scheduler - -- [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels -- [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse - -#### Fully JIT design 🐳 - -DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages: - -- GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants - - Saving registers - - Compilers may do more optimizations -- Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size - - But without auto-tuning, the optimal one is deterministically selected -- Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities - - Very important for small shapes - - Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details - -Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler. - -#### Unaligned block sizes 🐳 - -For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains. - -#### FFMA SASS interleaving 🐳 - -We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. 
By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern. -After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/96a9f72baf00f40b9b299653fcef8d3e2b4a3d49/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work). - -To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions. - ## Acknowledgement DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! @@ -194,15 +144,3 @@ DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project ## License This code repository is released under [the MIT License](LICENSE). - -## Citation - -```bibtex -@misc{deepgemm2025, - title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling}, - author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu}, - year={2025}, - publisher = {GitHub}, - howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}}, -} -``` diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu new file mode 100644 index 00000000..a05b59c8 --- /dev/null +++ b/csrc/indexing/main.cu @@ -0,0 +1,13 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace deep_gemm; + +int main() { + return 0; +} diff --git a/csrc/jit/cache.hpp b/csrc/jit/cache.hpp new file mode 100644 index 00000000..fde9aab9 --- /dev/null +++ b/csrc/jit/cache.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "kernel_runtime.hpp" + +namespace deep_gemm { + +class KernelRuntimeCache { + std::unordered_map> cache; + +public: + // TODO: consider cache capacity + KernelRuntimeCache() = default; + + std::shared_ptr get(const std::filesystem::path& dir_path) { + // Hit the runtime cache + if (const auto& iterator = cache.find(dir_path); iterator != cache.end()) + return iterator->second; + + if (KernelRuntime::check_validity(dir_path)) + return cache[dir_path] = std::make_shared(dir_path); + return nullptr; + } +}; + +static auto kernel_runtime_cache = std::make_shared(); + +} // namespace deep_gemm diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp new file mode 100644 index 00000000..4296b358 --- /dev/null +++ b/csrc/jit/compiler.hpp @@ -0,0 +1,172 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/hash.hpp" +#include "../utils/system.hpp" +#include "cache.hpp" +#include "device_runtime.hpp" + +namespace deep_gemm { + +class Compiler { + std::string library_version; + std::filesystem::path library_root_path; + + std::string get_library_version() const { + // Recursively walk through all subdirectories and update hash + std::stringstream ss; + for (const auto& entry: 
std::filesystem::recursive_directory_iterator(library_include_path / "deep_gemm")) { + if (entry.is_regular_file() and entry.path().extension() == ".cuh") { + std::ifstream file(entry.path(), std::ios::binary); + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + ss << content; + } + } + return get_hex_digest(ss.str()); + } + +public: + std::string signature, flags; + std::filesystem::path library_include_path; + std::filesystem::path cache_dir_path; + + explicit Compiler(const std::filesystem::path& library_root_path) { + // Static library paths + this->library_root_path = library_root_path; + this->library_include_path = library_root_path / "include"; + this->library_version = get_library_version(); + + // Cache settings + cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; + if (const auto& env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) + cache_dir_path = env_cache_dir_path; + + // The compiler flags applied to all derived compilers + signature = "unknown-compiler"; + std::string ptxas_flags = "--ptxas-options=--register-usage-level=10"; + if (get_env("DG_JIT_PTXAS_VERBOSE", 0)) + ptxas_flags += ",--verbose"; + flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags); + } + + virtual ~Compiler() = default; + + std::filesystem::path make_tmp_dir() const { + return make_dirs(cache_dir_path / "tmp"); + } + + std::filesystem::path get_tmp_file_path() const { + return make_tmp_dir() / get_uuid(); + } + + void put(const std::filesystem::path& path, const std::string& data) const { + const auto tmp_file_path = get_tmp_file_path(); + + // Write into the temporary file + std::ofstream out(tmp_file_path, std::ios::binary); + DG_HOST_ASSERT(out.write(data.data(), data.size())); + out.close(); + + // Atomically replace + std::filesystem::rename(tmp_file_path, path); + } + + std::shared_ptr build(const std::string& name, const std::string& code) const { + const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code); + const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); + + // Hit the runtime cache + if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) + return runtime; + + // Create the kernel directory + make_dirs(dir_path); + + // Compile into a temporary CUBIN + const auto tmp_cubin_path = get_tmp_file_path(); + compile(code, dir_path, tmp_cubin_path); + + // Replace into the cache directory + make_dirs(dir_path); + std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin"); + + // Put into the runtime cache + const auto& runtime = kernel_runtime_cache->get(dir_path); + DG_HOST_ASSERT(runtime != nullptr); + return runtime; + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0; +}; + +class NVCCCompiler final: public Compiler { + std::filesystem::path nvcc_path; + + std::pair get_nvcc_version() const { + DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); + + // Call the version command + const auto& command = std::string(nvcc_path) + " --version"; + const auto& [return_code, output] = call_external_command(command); + DG_HOST_ASSERT(return_code == 0); + + // The version should be at least 12.3, for the best performance with 12.9 + int major, minor; + std::smatch match; + DG_HOST_ASSERT(std::regex_search(output, match, 
std::regex(R"(release (\d+\.\d+))"))); + std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); + if (major < 12 or (major == 12 and minor < 9)) + printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance"); + return {major, minor}; + } + +public: + NVCCCompiler(const std::filesystem::path& library_root_path, + const std::filesystem::path& cuda_home_path_by_torch): + Compiler(library_root_path) { + // Override the compiler signature + nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc"; + if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) + nvcc_path = env_nvcc_path; + const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); + signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); + + // The override the compiler flags + flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a " + "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " + "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda", + flags, library_include_path.c_str(), device_runtime->get_arch()); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Compile + const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC command: %s", command.c_str()); + const auto& [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("NVCC compilation failed: %s", output.c_str()); + DG_HOST_ASSERT(false and "NVCC compilation failed"); + } + + // Print PTXAS log + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + printf("%s", output.c_str()); + } +}; + +static std::shared_ptr compiler = nullptr; + +} // namespace deep_gemm diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp new file mode 100644 index 00000000..c3237da8 --- /dev/null +++ b/csrc/jit/device_runtime.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include "../utils/exception.hpp" + +namespace deep_gemm { + +class DeviceRuntime { + int num_sms = 0; + std::shared_ptr cached_prop; + +public: + explicit DeviceRuntime() = default; + + std::shared_ptr get_prop() { + if (cached_prop == nullptr) + cached_prop = std::make_shared(*at::cuda::getCurrentDeviceProperties()); + return cached_prop; + } + + std::pair get_arch_pair() { + const auto prop = get_prop(); + return {prop->major, prop->minor}; + } + + int get_arch() { + const auto& [major, minor] = get_arch_pair(); + return major * 10 + minor; + } + + int get_arch_major() { + return get_arch_pair().first; + } + + void set_num_sms(const int& new_num_sms) { + DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount); + num_sms = new_num_sms; + } + + int get_num_sms() { + if (num_sms == 0) + num_sms = get_prop()->multiProcessorCount; + return num_sms; + } +}; + +static auto device_runtime = std::make_shared(); + +} // namespace deep_gemm diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp new file mode 100644 index 00000000..ac95f99a --- /dev/null +++ b/csrc/jit/kernel_runtime.hpp @@ -0,0 +1,139 @@ +#pragma once + +#include +#include + 
+#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/system.hpp" +#include "device_runtime.hpp" + +namespace deep_gemm { + +struct LaunchArgs { + std::pair grid_dim; + int num_threads; + int smem_size; + int cluster_dim; + + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} +}; + +template +concept HasLaunchArgs = requires (const T& t) { + { t.launch_args } -> std::convertible_to; +}; + +class KernelRuntime final { +public: + static std::filesystem::path cuda_home; + + cudaLibrary_t library; + cudaKernel_t kernel; + + explicit KernelRuntime(const std::filesystem::path& dir_path) { + // NOLINT(*-pro-type-member-init) + const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; + const auto& cubin_path = dir_path / "kernel.cubin"; + if (get_env("DG_JIT_DEBUG")) + printf("Loading CUBIN: %s\n", cubin_path.c_str()); + + // Find the only symbol + // TODO: use kernel enumeration for newer drivers + const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; + const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + DG_HOST_ASSERT(exit_code == 0); + std::istringstream iss(symbols); + std::vector symbol_names; + for (std::string line; std::getline(iss, line); ) { + if (line.find("STT_FUNC") == 0 and std::ranges::none_of(illegal_names, [&](const auto& name) { return line.find(name) != std::string::npos; })) { + const auto& last_space = line.rfind(' '); + symbol_names.push_back(line.substr(last_space + 1)); + } + } + if (get_env("DG_JIT_DEBUG")) { + printf("Symbol names: "); + for (const auto& symbol: symbol_names) + printf("%s, ", symbol.c_str()); + printf("\n"); + } + + // Load from the library + DG_HOST_ASSERT(symbol_names.size() == 1); + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, symbol_names[0].c_str())); + } + + static void set_cuda_home(const std::string& cuda_home_path_by_torch) { + cuda_home = cuda_home_path_by_torch; + } + + static bool check_validity(const std::filesystem::path& dir_path) { + return std::filesystem::exists(dir_path / "kernel.cu") and + std::filesystem::exists(dir_path / "kernel.cubin"); + } + + ~KernelRuntime() noexcept(false) { + const auto& error = cudaLibraryUnload(library); + DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); + } +}; + +// Declare after defining +decltype(KernelRuntime::cuda_home) KernelRuntime::cuda_home; + +template +class LaunchRuntime { +public: + template requires HasLaunchArgs + static std::string generate(const Args& args) { + const auto& code = Derived::generate_impl(args); + if (get_env("DG_JIT_DEBUG", 0)) + printf("Generated kernel code: %s\n", code.c_str()); + return code; + } + + template requires HasLaunchArgs + static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { + const auto& kernel = kernel_runtime->kernel; + const auto& stream = at::cuda::getCurrentCUDAStream(); + const LaunchArgs& launch_args = args.launch_args; + + 
// Set dynamic shared memory size + if (launch_args.smem_size > 0) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, launch_args.smem_size)); + + // Launch config + cudaLaunchConfig_t config; + config.gridDim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + config.blockDim = {static_cast(launch_args.num_threads), 1, 1}; + config.dynamicSmemBytes = launch_args.smem_size; + config.stream = stream; + config.numAttrs = 0; + + // Clusters + cudaLaunchAttribute attr; + if (launch_args.cluster_dim > 1) { + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(launch_args.cluster_dim), 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + } + + // Launch in the derived class + if (get_env("DG_JIT_DEBUG")) { + printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", + launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, + launch_args.smem_size, launch_args.cluster_dim, stream.id()); + } + Derived::launch_impl(kernel, config, args); + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp new file mode 100644 index 00000000..b5a8b61c --- /dev/null +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -0,0 +1,298 @@ +#pragma once + +#include "../../utils/math.hpp" + +namespace deep_gemm { + +struct MulticastConfig { + int num_multicast; + bool is_multicast_on_a; + + MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a): + num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) { + DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); + } +}; + +struct SharedMemoryConfig { + int smem_size; + int swizzle_a_mode; + int swizzle_b_mode; + int swizzle_cd_mode; +}; + +struct ThreadConfig { + int num_threads; + + // SM90 + int num_tma_threads; + int num_math_threads; + + // SM100 + int num_non_epilogue_threads; + int num_epilogue_threads; + + static ThreadConfig sm90(const int& num_tma_threads, + const int& num_math_threads) { + auto config = ThreadConfig(); + config.num_threads = num_tma_threads + num_math_threads; + config.num_tma_threads = num_tma_threads; + config.num_math_threads = num_math_threads; + return config; + } + + static ThreadConfig sm100(const int& num_non_epilogue_threads, + const int& num_epilogue_threads) { + auto config = ThreadConfig(); + config.num_threads = num_non_epilogue_threads + num_epilogue_threads; + config.num_non_epilogue_threads = num_non_epilogue_threads; + config.num_epilogue_threads = num_epilogue_threads; + return config; + } +}; + +struct GemmConfig { + // Templated configs + GemmType gemm_type; + KernelType kernel_type; + at::ScalarType ab_dtype, cd_dtype; + cute::UMMA::Major major_a; + cute::UMMA::Major major_b; + bool with_accumulation; + int block_m, block_n, block_k; + int num_stages, num_last_stages; + + // Runtime configs + int num_sms; + + // Structured configs + MulticastConfig multicast_config; + SharedMemoryConfig smem_config; + ThreadConfig thread_config; +}; + +static bool is_multicast_legal(const int& shape_dim, const int& block_dim, + const int& num_multicast, const int& num_sms, + const bool& require_divisible) { + const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible; + return divisible and num_sms % num_multicast == 0; +} + +static int get_swizzle_mode(const int& block_size, const int& elem_size) { + // `> 0` means 
interleaving + // 16B actually means non-swizzling (but interleaving) + for (const int& mode: {128, 64, 32, 16}) { + if ((block_size * elem_size) % mode == 0) + return mode; + } + DG_HOST_UNREACHABLE("Unreachable"); +} + +template +static SharedMemoryConfig get_smem_config(const KernelType& kernel_type, + const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, const MulticastConfig& multicast_config) { + const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); + const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); + + const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m); + const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n); + const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size); + const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size); + const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size); + + // Different archs have different epilogue pipelines + const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); + + // A/B shared memory + const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; + const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; + + // SF shared memory + const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = + ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype); + const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); + + // M-barriers and tensor memory pointers + const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); + + // Sum them up + int smem_size = 0; + smem_size += smem_cd; + smem_size += num_stages * smem_a_per_stage; + smem_size += num_stages * smem_b_per_stage; + smem_size += num_stages * smem_sfa_per_stage; + smem_size += num_stages * smem_sfb_per_stage; + smem_size += smem_extra_sfb; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + + return SharedMemoryConfig { + .smem_size = smem_size, + .swizzle_a_mode = swizzle_a_mode, + .swizzle_b_mode = swizzle_b_mode, + .swizzle_cd_mode = swizzle_cd_mode, + }; +} + +template +static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, + const int& m, const int& n, const int& k, const int& num_groups, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const bool& with_accumulation, const int& num_sms) { + DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + + // Select M/N block sizes + // TODO: support `% 16 == 8` block size on SM90 + const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ? 
+ std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256}; + std::vector block_ns; + for (int i = 16; i <= 256; i += 16) + block_ns.push_back(i); + + // K block size is selected in a fixed manner + const auto& block_k = 128 / static_cast(c10::elementSize(ab_dtype)); + + // Some util functions + const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { + return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups; + }; + const auto& get_num_waves = [=](const int& block_m, const int& block_n) { + return ceil_div(get_num_blocks(block_m, block_n), num_sms); + }; + const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) { + const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms; + return num_last_blocks == 0 ? num_sms : num_last_blocks; + }; + + // Decide block sizes by waves + int best_block_m = 0, best_block_n = 0; + int best_num_waves = 0, best_last_util = 0; + for (const auto& block_m: block_ms) { + for (const auto& block_n: block_ns) { + const int& num_waves = get_num_waves(block_m, block_n); + const auto& last_util = get_last_wave_util(block_m, block_n); + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n)) + continue; + + bool success = false; + if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { + success = true; + } else if (num_waves == best_num_waves) { + // Check last wave utilization + success = last_util > best_last_util; + if (last_util == best_last_util) { + // Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n; + // Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m; + // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better + success |= block_m != best_block_m and block_n > best_block_n; + } + } + + // Replace with the new config if successful + if (success) { + best_block_m = block_m, best_block_n = block_n; + best_num_waves = num_waves, best_last_util = last_util; + } + } + } + DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); + + // Decide the number of TMA multicasts and whether broadcast on A + MulticastConfig best_multicast_config = {1, true}; + const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + gemm_type, m, n, best_block_m, best_block_n, num_sms); + const bool is_legal[2] = {is_legal_on_a, is_legal_on_b}; + bool order[2] = {false, true}; + if (best_block_m > best_block_n) + std::swap(order[0], order[1]); + for (const bool& is_multicast_on_a: order) { + if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { + best_multicast_config = {2, is_multicast_on_a}; + break; + } + } + + // Always pick the largest number of stage + constexpr int smem_capacity = ArchSpec::smem_capacity; + int best_num_stages = 0; + SharedMemoryConfig best_smem_config; + for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) { + if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) + continue; + + best_smem_config = get_smem_config(kernel_type, + m, n, k, + best_block_m, best_block_n, block_k, + major_a, major_b, + ab_dtype, cd_dtype, + num_stages, best_multicast_config); + if (best_smem_config.smem_size <= smem_capacity) { + best_num_stages = num_stages; + break; + } + } + DG_HOST_ASSERT(best_num_stages != 0); + + // Recompute the minimal number of SMs 
required + // NOTES: less L2 cache usage and less GPU frequency drop + int num_min_sms = num_sms; + if (ArchSpec::should_minimize_num_sms()) { + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); + num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); + DG_HOST_ASSERT(num_min_sms <= num_sms); + } + + const auto& config = GemmConfig { + .gemm_type = gemm_type, + .kernel_type = kernel_type, + .ab_dtype = ab_dtype, + .cd_dtype = cd_dtype, + .major_a = major_a, + .major_b = major_b, + .with_accumulation = with_accumulation, + .block_m = best_block_m, + .block_n = best_block_n, + .block_k = block_k, + .num_stages = best_num_stages, + .num_last_stages = ceil_div(k, block_k) % best_num_stages, + .num_sms = num_min_sms, + .multicast_config = best_multicast_config, + // ReSharper disable once CppLocalVariableMightNotBeInitialized + .smem_config = best_smem_config, + .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) + }; + + // Print configs for the first time + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, + ab_dtype, cd_dtype, with_accumulation, num_sms); + static std::set printed; + if (not printed.contains(key)) { + printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " + "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, " + "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " + "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " + "swizzle B: %d, swizzle CD: %d, threads: %d\n", + static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, + static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), + static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, + best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, + static_cast(best_multicast_config.is_multicast_on_a), + best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode, + best_smem_config.swizzle_cd_mode, config.thread_config.num_threads); + printed.insert(key); + } + } + return config; +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp new file mode 100644 index 00000000..722c3d1e --- /dev/null +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -0,0 +1,144 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +struct SM100ArchSpec { + static constexpr int smem_capacity = 232448; + + static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) { + return block_m / (config.is_multicast_on_a ? config.num_multicast : 1); + } + + static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) { + return block_n / (config.is_multicast_on_a ? 
1 : config.num_multicast); + } + + static int get_cd_store_block_m(const int& block_m) { + constexpr int layout_ad_m = 128; + return std::min(block_m, layout_ad_m); + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static std::pair get_sf_uttcp_aligned_block_sizes( + const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) { + constexpr int num_utccp_aligned_elems = 128; + DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0); + switch (ab_dtype) { + case torch::kBFloat16: return {0, 0}; + case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + default: DG_HOST_UNREACHABLE("Unknown dtype"); + } + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& block_m, const int& block_n) { + // Layout A/D does not support `block_m == 64` and `block_n % 16 != 0` + if (block_m == 64 or block_n % 16 != 0) + return false; + + // Performance is lower with 1D1D and `block_m == 256` + if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128) + return false; + + // 1D2D kernels' maximum block N is 128 + // 1D2D kernels require more friendly block Ns + if (kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0)) + return false; + + // Check tensor memory validity + int sf_block_m = 0, sf_block_n = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; + } + if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) + return false; + + // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, + // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA + return major_b == cute::UMMA::Major::K or block_n % 64 == 0; + } + + static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + return true; + } + + static bool should_minimize_num_sms() { + return false; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + // TODO: support other layouts + return { + is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous), + false, + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, + const at::ScalarType& cd_dtype) { + constexpr static int layout_ad_m = 128; + return (kernel_type == KernelType::Kernel1D1D ? 
std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2; + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { + if (ab_dtype == torch::kBFloat16) + return {0, 0}; + + int smem_sfa_per_stage = 0; + int smem_sfb_per_stage = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + smem_sfa_per_stage = sf_block_m * 4; + smem_sfb_per_stage = sf_block_n * 4; + } else { + smem_sfa_per_stage = block_m * 4; + smem_sfb_per_stage = 0; + } + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + return 0; + } + + static int get_barrier_smem_size(const int& num_stages) { + // TODO: remove SF barriers for BF16 GEMMs + // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers + // NOTES: 1D2D kernel will not use the with-SF full barriers + // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages + return num_stages * 8 * 3 + 2 * 8 * 2; + } + + static int get_tmem_ptr_smem_size() { + return 4; + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp new file mode 100644 index 00000000..a1cb5b4b --- /dev/null +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" + +namespace deep_gemm { + +struct SM90ArchSpec { + static constexpr int smem_capacity = 232448; + + static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) { + return block_m; + } + + static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) { + return block_n; + } + + static int get_cd_store_block_m(const int& block_m) { + return block_m; + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& block_m, const int& block_n) { + // FP32 output does not support `block_m == 256` + if (cd_dtype == at::kFloat and block_m == 256) + return false; + + // Must be some fixed block N selections + if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152)) + return false; + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160)) + return false; + + // Avoid bank conflicts for FP32 output + if (cd_dtype == torch::kFloat and block_n % 16 == 0) + return false; + + // The block sizes cannot be too large (for enough registers), so at least one dim less than 128 + return block_m <= 128 or block_n <= 128; + } + + static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + // Unrolling both stages and `num_former_iters` will cause large code size + if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) + return num_stages <= 4; + return true; + } + + static 
bool should_minimize_num_sms() { + return true; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + return { + is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), + is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked, + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { + return block_m * block_n * static_cast(c10::elementSize(cd_dtype)); + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { + if (ab_dtype == torch::kBFloat16) + return {0, 0}; + + int smem_sfa_per_stage = block_m * static_cast(sizeof(float)); + int smem_sfb_per_stage = 0; + // TODO: figure out here + if (kernel_type == KernelType::Kernel1D1D) + smem_sfb_per_stage = align(block_n * 4, block_k); + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2; + return align(ceil_div(k, block_k) * static_cast(sizeof(float)) * use_uniform_sfb, 8); + } + + static int get_barrier_smem_size(const int& num_stages) { + // For 1D1D kernels, there is an extra barrier for accumulation + return (num_stages + 1) * 8 * 2; + } + + static int get_tmem_ptr_smem_size() { + return 0; + } +}; + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp new file mode 100644 index 00000000..ed9c5305 --- /dev/null +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -0,0 +1,173 @@ +#pragma once + +#include +#include + +#include "../../utils/math.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +static std::pair get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) { + return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k); +} + +static int get_non_contiguous_dim(const cute::UMMA::Major& major) { + return major == cute::UMMA::Major::K ? 
-2 : -1; +} + +static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) { + for (const char& c: compiled_dims) { + if (name == c) + return dim; + } + return 0; +} + +static std::string to_string(const cute::UMMA::Major& major) { + switch (major) { + case cute::UMMA::Major::K: return "cute::UMMA::Major::K"; + case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN"; + } + DG_HOST_UNREACHABLE("Unknown major"); +} + +static std::string to_string(const GemmType& type) { + switch (type) { + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + } + DG_HOST_UNREACHABLE("Unknown GEMM type"); +} + +static std::string to_string(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return "int"; + case torch::kFloat: return "float"; + case torch::kBFloat16: return "cutlass::bfloat16_t"; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32; + case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) { + switch (mode) { + case 0: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 32: return CU_TENSOR_MAP_SWIZZLE_32B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: DG_HOST_UNREACHABLE("Unsupported swizzling mode"); + } +} + +static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, + int gmem_inner_dim, int gmem_outer_dim, + int smem_inner_dim, int smem_outer_dim, + const int& gmem_outer_stride, + const int& swizzle_mode) { + const auto& elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_inner_dim = swizzle_mode / elem_size; + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; + const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; + const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; + const cuuint32_t elem_strides[2] = {1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n", + gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, + gmem_outer_stride, swizzle_mode, elem_size); + } + DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()), + 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_m, const int& shape_k, + const int& block_m, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + if (num_groups > 1) + DG_HOST_ASSERT(major == 
cute::UMMA::Major::K); + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_n, const int& shape_k, + const int& block_n, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); + + // `num_groups` is always applied into the outer dimensions + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim * num_groups, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, + const int& shape_m, const int& shape_n, + const int& block_m, const int& block_n, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode) { + + // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` + // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required + return make_tma_2d_desc(t, + shape_n, shape_m * num_groups, + block_n, block_m, + outer_stride, + swizzle_mode); +} + +static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + int shape_mn, int shape_k, + const int& block_mn, const int& block_k, + const int& num_groups, + const int& swizzle_mode) { + DG_HOST_ASSERT(major == cute::UMMA::Major::MN); + + // TODO: maybe swizzle SF as well + DG_HOST_ASSERT(swizzle_mode == 0); + + shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); + return make_tma_2d_desc(t, + shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 
1 : 4)) * num_groups, + block_mn, 1, + shape_mn, + swizzle_mode); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp new file mode 100644 index 00000000..fe8887e4 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -0,0 +1,351 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_c; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + to_string(args.gemm_config.gemm_type), + args.gemm_config.with_accumulation, + to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_c, args.tensor_map_d)); + } +}; + +static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + 
config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, 1, 0); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + if (c->data_ptr() == d.data_ptr()) { + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + // ReSharper disable once CppExpressionWithoutSideEffects + d.copy_(c.value()); + } + } + + // Launch + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D1D, + m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = 
make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, num_groups, 0); + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_d, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D1D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, num_groups, 0); + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_d, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, 
args); +} + +static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0, sum_sf_k = 0; + for (const auto& k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::ranges::max_element(ks); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, + config.block_m, config.block_k, num_groups, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, + config.block_n, config.block_k, num_groups, 0); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + DG_HOST_ASSERT(c->data_ptr() == d.data_ptr()); + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } + + // Launch kernel + const SM100FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); + SM100FP8Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp new file mode 100644 index 00000000..02478a09 --- 
/dev/null +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -0,0 +1,242 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d2d_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + to_string(args.gemm_config.gemm_type), + to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value()); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + 
SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& 
n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + + // Launch + const SM100FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM100FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code); + SM100FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp new file mode 100644 index 00000000..2909ef3b --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -0,0 +1,255 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {} + >); +}}; +)", + // TODO: add CD dtype + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + 
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + to_string(args.gemm_config.gemm_type)); + } + + static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, 
const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& aligned_k = align(k, 128); + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& aligned_k = align(k, 128); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + 
config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .m = m, .n = n, .k = aligned_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp new file mode 100644 index 00000000..9d6e3021 --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -0,0 +1,199 @@ +#pragma once + +#include + +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" + +namespace deep_gemm { + +class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_and_pack_fp32_into_ue8m0< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast(args.mn))); + } +}; + +class PackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int num_groups, mn, sf_k, packed_sf_k; + int block_mn, block_packed_sf_k; + void *sf, *out, *ks; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&pack_fp32_into_ue8m0< + {}, {}, {}, {} + >); +}}; +)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k); + } + + static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k)); + } +}; + 
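// [Editor's note] Illustrative sketch, not part of the patch and not used by the library:
// the runtimes above pack FP32 scaling factors into UE8M0 (just the 8 exponent bits of an
// IEEE-754 float), four consecutive entries along the SF-K dimension per 32-bit word, which
// is why `packed_sf_k == ceil_div(sf_k, 4)` and the packed tensors are allocated as `torch::kInt`.
// The helper below shows that packing on the host, assuming positive power-of-two scales and
// that <cstdint>/<cstring> are available; the byte order within the packed word is an
// assumption here, not taken from the kernels.
static inline uint32_t reference_pack_fp32_into_ue8m0(const float (&sf)[4]) {
    uint32_t packed = 0;
    for (int i = 0; i < 4; ++ i) {
        uint32_t bits;
        std::memcpy(&bits, &sf[i], sizeof(bits));        // Raw IEEE-754 encoding of the scale
        const uint32_t exponent = (bits >> 23) & 0xffu;  // Keep only the 8-bit biased exponent
        packed |= exponent << (8 * i);                   // Byte `i` holds the `i`-th scale along K
    }
    return packed;
}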
+static std::tuple preprocess_sf(const torch::Tensor& sf) { + // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + const auto& dim = sf.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat); + const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf; + + const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf); + const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); + return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf}; +} + +static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + + // The last kernel already gives a column-major TMA aligned layout + if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) + return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; + + // Normal layout requires transposing + auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options()); + aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf); + return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto& packed_sf_k = ceil_div(sf_k, 4); + const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k}, + {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, + at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0); + + // Launch the kernel + if (batched_sf.is_contiguous()) { + constexpr int block_mn = 48; + constexpr int num_threads = 512; + const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) + }; + + const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); + TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); + } else { + DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1); + DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = 1, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .ks = nullptr, + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + } + return (dim == 2) ? 
out.squeeze(0) : out; +} + +static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf, + const torch::Tensor& ks_tensor, + const std::vector& ks) { + const auto& [sf_k, mn] = get_shape<2>(sf); + const auto& num_groups = static_cast(ks.size()); + + int ref_sf_k = 0, packed_sf_k = 0; + for (const auto& k: ks) + ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(sf.is_contiguous()); + DG_HOST_ASSERT(ref_sf_k == sf_k); + DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0); + + const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = num_groups, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = sf.data_ptr(), + .out = out.data_ptr(), + .ks = ks_tensor.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + return out; +} + +} // namespace deep_gemm diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp new file mode 100644 index 00000000..e1e916f2 --- /dev/null +++ b/csrc/python_api.cpp @@ -0,0 +1,402 @@ +#include +#include + +#include "jit/compiler.hpp" +#include "jit/device_runtime.hpp" +#include "utils/layout.hpp" + +#include "jit_kernels/impls/smxx_layout.hpp" +#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" + +#ifndef TORCH_EXTENSION_NAME +#define TORCH_EXTENSION_NAME deep_gemm_cpp +#endif + +namespace deep_gemm { +torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const std::optional& num_groups, + const std::tuple& recipe, + const bool& is_sfa, + const bool& disable_ue8m0_cast) { + const auto& gran_mn = is_sfa ? 
std::get<0>(recipe) : std::get<1>(recipe); + const auto& gran_k = std::get<2>(recipe); + const auto& arch_major = device_runtime->get_arch_major(); + + // Pre-transform checks + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); + + // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return get_mn_major_tma_aligned_tensor(sf); + + // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); + } + + // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); + + // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); + } + + // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); + + DG_HOST_UNREACHABLE("Unknown SF transformation"); +} + +torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::tuple& recipe) { + DG_HOST_ASSERT(sf.dim() == 2); + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + const auto& arch_major = device_runtime->get_arch_major(); + + // FP32 on SM90 + if (sf.scalar_type() == torch::kFloat and arch_major == 9) + DG_HOST_UNREACHABLE("Unimplemented"); + + // FP32 on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + + // INT on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + DG_HOST_UNREACHABLE("Unimplemented"); + + DG_HOST_UNREACHABLE("Unknown cases"); +} + +void fp8_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a.first); + const auto& [n , k_] = get_shape<2>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == 
torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check C as well + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast); + const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast); + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types"); + } +} + +void fp8_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +void fp8_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +void fp8_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a.first); + const auto& [num_groups, n, k_] = get_shape<3>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + 
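// [Editor's note] Reader aid, an assumption rather than text taken from the patch: in the
// contiguous grouped layout the groups' rows are concatenated along M, and `m_indices[i]`
// names the group that row `i` of A/D belongs to. For example, two groups with 3 and 2 rows
// would pass m_indices = [0, 0, 0, 1, 1]; callers are expected to pad each group to the
// alignment reported by `get_mk_alignment_for_contiguous_layout()`.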
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast); + const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types"); + } +} + +void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); +} + +void fp8_m_grouped_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a.first); + const auto& [num_groups_, n, k_] = get_shape<3>(b.first); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast); + const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, 
disable_ue8m0_cast); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types"); + } +} + +void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + if (c.has_value()) { + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().is_contiguous()); + } + + // Do nothing if empty + if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + return; + + // Transform SF with padding + const auto& [_, m] = get_shape<2>(a.first); + const auto& [__, n] = get_shape<2>(b.first); + const auto& sfa = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +} // namespace deep_gemm + +// ReSharper disable once CppParameterMayBeConstPtrOrRef +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + using namespace deep_gemm; + + m.doc() = "DeepGEMM C++ library"; + + // Runtime + m.def("get_num_sms", [&]() { + return device_runtime->get_num_sms(); + }); + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); + + // JIT + m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) { + DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC"); + compiler = std::make_shared(library_root_path, cuda_home_path_by_torch); + KernelRuntime::set_cuda_home(cuda_home_path_by_torch); + }); + + // Stable kernel APIs with automatic arch/layout dispatch + m.def("fp8_gemm_nt", &fp8_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_nn", &fp8_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tn", &fp8_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + 
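// NOTES (illustrative sketch, not part of this patch): the bindings below all follow the same
// pybind11 pattern -- positional tensor arguments first, then keyword arguments with defaults so
// Python callers may omit `c`, `recipe`, `compiled_dims` and `disable_ue8m0_cast`. A minimal
// standalone module using the same idiom (hypothetical function, for illustration only):
#if 0
#include <optional>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

static int demo(const int& x, const std::optional<int>& c, const std::string& compiled_dims) {
    return x + c.value_or(0) + static_cast<int>(compiled_dims.size());
}

PYBIND11_MODULE(demo_cpp, m) {
    m.def("demo", &demo,
          py::arg("x"),
          py::arg("c") = std::nullopt,         // optional, like the `c` accumulator tensor
          py::arg("compiled_dims") = "nk");    // keyword default, like the GEMM bindings
}
#endif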
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tt", &fp8_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout); + + // Raw kernels or functions + m.def("get_tma_aligned_size", &get_tma_aligned_size); + m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); + m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); + m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); + m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +} diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp new file mode 100644 index 00000000..493e4807 --- /dev/null +++ b/csrc/utils/exception.hpp @@ -0,0 +1,58 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +class DGException final : public std::exception { + std::string message = {}; + +public: + explicit DGException(const char *name, const char* file, const int line, const std::string& error) { + message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + } + + const char *what() const noexcept override { + return message.c_str(); + } +}; + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) 
static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_HOST_ASSERT +#define DG_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + throw DGException("Assertion", __FILE__, __LINE__, #cond); \ + } \ +} while (0) +#endif + +#ifndef DG_HOST_UNREACHABLE +#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) +#endif + +#ifndef DG_CUDA_DRIVER_CHECK +#define DG_CUDA_DRIVER_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + throw DGException("CUDA driver", __FILE__, __LINE__, ""); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_RUNTIME_CHECK +#define DG_CUDA_RUNTIME_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != cudaSuccess) { \ + throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast(e))); \ + } \ +} while (0) +#endif + +} // namespace deep_gemm diff --git a/csrc/utils/format.hpp b/csrc/utils/format.hpp new file mode 100644 index 00000000..bf617372 --- /dev/null +++ b/csrc/utils/format.hpp @@ -0,0 +1,6 @@ +#pragma once + +// Just a wrapper for the `fmt` headers +#define FMT_HEADER_ONLY +#include +#include diff --git a/csrc/utils/hash.hpp b/csrc/utils/hash.hpp new file mode 100644 index 00000000..fad1231f --- /dev/null +++ b/csrc/utils/hash.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include + +namespace deep_gemm { + +static uint64_t fnv1a(const std::string& data, const uint64_t& seed) { + uint64_t h = seed; + const uint64_t& prime = 0x100000001b3ull; + for (const char& c: data) { + h ^= static_cast(c); + h *= prime; + } + return h; +} + +static std::string get_hex_digest(const std::string& data) { + const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); + const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); + + // Split-mix 64 + const auto& split_mix = [](uint64_t z) { + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull; + z = (z ^ (z >> 27)) * 0x94d049bb133111ebull; + return z ^ (z >> 31); + }; + + std::ostringstream oss; + oss << std::hex << std::setfill('0') + << std::setw(16) << split_mix(state_0) + << std::setw(16) << split_mix(state_1); + return oss.str(); +} + +} // namespace deep_gemm diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp new file mode 100644 index 00000000..47d46c47 --- /dev/null +++ b/csrc/utils/layout.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include +#include + +#include "math.hpp" +#include "exception.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm { + +// Major-ness stuffs +static void major_check(const torch::Tensor& t) { + const auto dim = t.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + if (dim == 3) + DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1)); + DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1); +} + +static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) { + major_check(t); + return t.stride(-1) == 1 ? 
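// NOTES (illustrative sketch, not part of this patch): the `DG_*` macros defined above turn failed
// host-side checks into catchable `DGException`s (so Python callers see a normal exception instead
// of a process abort), and `get_hex_digest` yields a 32-hex-character digest from two seeded
// FNV-1a passes finalized with split-mix 64. A small usage sketch, assuming the headers above:
#if 0
#include <cuda_runtime.h>

void demo_checks(const std::string& kernel_source) {
    // Host-side precondition: throws `DGException` carrying file/line on failure
    DG_HOST_ASSERT(not kernel_source.empty());

    // CUDA runtime calls wrapped so the numeric error code lands in the exception message
    void* ptr = nullptr;
    DG_CUDA_RUNTIME_CHECK(cudaMalloc(&ptr, 1024));
    DG_CUDA_RUNTIME_CHECK(cudaFree(ptr));

    // Deterministic digest, usable e.g. as a cache key for JIT-generated sources
    const auto& digest = get_hex_digest(kernel_source);
    DG_HOST_ASSERT(digest.size() == 32);
}
#endif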
cute::UMMA::Major::K : cute::UMMA::Major::MN; +} + +static void check_major_type_cd(const torch::Tensor& t) { + // NOTES: the library only supports row-major output layouts + major_check(t); + DG_HOST_ASSERT(t.stride(-1) == 1); +} + +static bool fp8_requires_k_major() { + return device_runtime->get_arch_major() == 9; +} + +// Tensor utils +template +static auto get_shape(const torch::Tensor& t) { + return [&t] (std::index_sequence) { + return std::make_tuple(static_cast(t.sizes()[Is])...); + }(std::make_index_sequence()); +} + +// Recipe +static std::tuple +get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); + return {1, 128, 128}; + } else if (arch_major == 10) { + DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt); + return sfb_dtype == torch::kFloat ? + std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels + std::make_tuple(1, 1, 128); // 1D1D kernels + } + DG_HOST_UNREACHABLE("Unknown recipe"); +} + +// SF layouts +static torch::Tensor check_sf_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const int& gran_mn, const int& gran_k, + const std::optional& num_groups, + const bool& tma_stride_check = false, + const bool& contiguous_check = false, + const std::optional& type_check = std::nullopt) { + // Type check + if (type_check.has_value()) + DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); + + // Always do shape checks + const auto& sf_dtype = sf.scalar_type(); + DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); + DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.size(-3) == num_groups.value()); + DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn)); + DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 
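// NOTES (worked example, not part of this patch): for the default recipes returned above, the
// shape checks in `check_sf_layout` reduce to simple ceiling divisions. Taking a hypothetical
// problem size m = 4096, n = 7168, k = 2048:
//  - SM90 recipe {1, 128, 128} (FP32 SF):  SFA is [4096, ceil_div(2048, 128)] = [4096, 16],
//                                          SFB is [ceil_div(7168, 128), 16]  = [56, 16].
//  - SM100 recipe {1, 1, 128} (UE8M0 SF, packed 4-per-int32): the last dimension shrinks by 4x,
//                                          SFA is [4096, ceil_div(2048, 128 * 4)] = [4096, 4].
#if 0
static_assert((2048 + 127) / 128 == 16);        // K blocks of width 128
static_assert((7168 + 127) / 128 == 56);        // N blocks of width 128
static_assert((2048 + 511) / 512 == 4);         // packed UE8M0: 4 scales per int32
#endif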
1 : 4))); + + // TMA stride checks: TMA aligned and MN-major + if (tma_stride_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); + DG_HOST_ASSERT(sf.stride(-2) == 1); + DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); + } + + // Hopper SFB must be contiguous + if (contiguous_check) + DG_HOST_ASSERT(sf.is_contiguous()); + return sf; +} + +// Value matrix layout +static int get_mk_alignment_for_contiguous_layout() { + return 128; +} + +} // namespace deep_gemm diff --git a/csrc/utils/math.hpp b/csrc/utils/math.hpp new file mode 100644 index 00000000..264d2d10 --- /dev/null +++ b/csrc/utils/math.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "exception.hpp" + +namespace deep_gemm { + +template +static T ceil_div(const T& a, const T& b) { + return (a + b - 1) / b; +} + +template +static constexpr T align(const T& a, const T& b) { + return ceil_div(a, b) * b; +} + +static int get_tma_aligned_size(const int& x, const int& element_size) { + constexpr int kNumTMAAlignmentBytes = 16; + DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0); + return align(x, kNumTMAAlignmentBytes / element_size); +} + +} // namespace deep_gemm diff --git a/csrc/utils/system.hpp b/csrc/utils/system.hpp new file mode 100644 index 00000000..7189b7f1 --- /dev/null +++ b/csrc/utils/system.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include + +#include "exception.hpp" + +namespace deep_gemm { + +// ReSharper disable once CppNotAllPathsReturnValue +template +static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) { + const auto& c_str = std::getenv(name.c_str()); + if (c_str == nullptr) + return default_value; + + // Read the env and convert to the desired type + if constexpr (std::is_same_v) { + return std::string(c_str); + } else if constexpr (std::is_same_v) { + int value; + std::sscanf(c_str, "%d", &value); + return value; + } else { + DG_HOST_ASSERT(false and "Unexpected type"); + } +} + +static std::tuple call_external_command(std::string command) { + command = command + " 2>&1"; + const auto& deleter = [](FILE* f) { if (f) pclose(f); }; + std::unique_ptr pipe(popen(command.c_str(), "r"), deleter); + DG_HOST_ASSERT(pipe != nullptr); + + std::array buffer; + std::string output; + while (fgets(buffer.data(), buffer.size(), pipe.get())) + output += buffer.data(); + const auto& exit_code = WEXITSTATUS(pclose(pipe.release())); + return {exit_code, output}; +} + +static std::filesystem::path make_dirs(const std::filesystem::path& path) { + // OK if existed + std::error_code capture; + const bool& created = std::filesystem::create_directories(path, capture); + DG_HOST_ASSERT(created or capture.value() == 0); + if (created and get_env("DG_JIT_DEBUG")) + printf("Create directory: %s\n", path.c_str()); + return path; +} + +static std::string get_uuid() { + static std::random_device rd; + static std::mt19937 gen([]() { + return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count(); + }()); + static std::uniform_int_distribution dist; + + std::stringstream ss; + ss << getpid() << "-" + << std::hex << std::setfill('0') + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen); + return ss.str(); +} + +} // deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 8e6b2996..17e7a330 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -1,15 +1,41 @@ +import os import torch +import 
torch.utils.cpp_extension -from . import jit -from .jit_kernels import ( - gemm_fp8_fp8_bf16_nt, - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked, - wgrad_gemm_fp8_fp8_fp32_nt, - k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, - ceil_div, - set_num_sms, get_num_sms, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout +# Set some default environment provided at setup +try: + # noinspection PyUnresolvedReferences + from .envs import persistent_envs + for key, value in persistent_envs.items(): + if key not in os.environ: + os.environ[key] = value +except ImportError: + pass + +# Import functions from the CPP module +import deep_gemm_cpp +deep_gemm_cpp.init( + os.path.dirname(os.path.abspath(__file__)), # Library root directory path + torch.utils.cpp_extension.CUDA_HOME # CUDA home +) + +# Configs +from deep_gemm_cpp import ( + set_num_sms, + get_num_sms ) -from .utils import bench, bench_kineto, calc_diff + +# Kernels +from deep_gemm_cpp import ( + fp8_gemm_nt, fp8_gemm_nn, + fp8_gemm_tn, fp8_gemm_tt, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_fp8_gemm_nn_contiguous, + fp8_m_grouped_gemm_nt_masked, + k_grouped_fp8_gemm_tn_contiguous +) + +# Some utils +from . import testing +from . import utils +from .utils import * diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 00000000..8ce8aa09 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,213 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class KGroupedIndexType { + MN, + K, + SF_K, +}; + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx; + // Only used for masked layout + uint32_t current_m_cumsum; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum; + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_m_blocks * num_n_blocks; + } else if (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::MGroupedMasked) { + current_group_idx = current_m_cumsum = 0; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::KGroupedContiguous) { + current_group_idx = current_num_valid_groups = 0; + current_k_cumsum = current_sf_k_cumsum = 0; + current_shape_k = __ldg(grouped_layout + current_group_idx); + this->grouped_layout = grouped_layout; + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? 
num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? std::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + const auto offset = kWithGroupOffset ? 
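// NOTES (worked example, not part of this patch): the swizzle above walks a panel of
// `kNum1DBlocksPerGroup` blocks along the "primary" (multicast) dimension before advancing along
// the other one, so concurrently scheduled SMs keep hitting the same operand tiles in L2.
// With kIsMulticastOnA = false, kNum1DBlocksPerGroup = 16, num_m_blocks = 64, num_n_blocks = 32
// and block_idx = 1000:
//   num_blocks_per_group = 32 * 16 = 512, group_idx = 1, first_block_idx = 16, in_group_idx = 488
//   => m_block_idx = 16 + 488 % 16 = 24, n_block_idx = 488 / 16 = 30
#if 0
// Host-side replica of the index math (without the multicast fix-up), for checking by hand
struct BlockIdx2D { unsigned m, n; };
constexpr BlockIdx2D swizzle_demo(const unsigned block_idx) {
    constexpr unsigned kGroup = 16, num_m_blocks = 64, num_n_blocks = 32;
    const unsigned per_group = num_n_blocks * kGroup;
    const unsigned first = (block_idx / per_group) * kGroup;
    const unsigned in_group = block_idx % per_group;
    const unsigned in_this_group = (num_m_blocks - first < kGroup) ? num_m_blocks - first : kGroup;
    return {first + in_group % in_this_group, in_group / in_this_group};
}
static_assert(swizzle_demo(1000).m == 24 and swizzle_demo(1000).n == 30);
#endif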
current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == KGroupedIndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == KGroupedIndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == KGroupedIndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + if (current_shape_k > 0) { + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, 512u); + current_num_valid_groups ++; + } + if ((++ current_group_idx) != kNumGroups) + current_shape_k = __ldg(grouped_layout + current_group_idx); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = kNum1DBlocksPerGroup % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return 
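// NOTES (worked example, not part of this patch): `get_next_block` above implements persistent
// scheduling -- each thread block repeatedly claims tile index
// `(++current_iter) * gridDim.x + blockIdx.x`, i.e. it strides through the tile space by the grid
// size. E.g. with gridDim.x = 132 (one persistent block per SM on an H800), blockIdx.x = 5 would
// process swizzled tiles 5, 137, 269, ... until the index runs past `num_blocks` (or past the
// last group for the grouped variants), so no inter-block work queue is needed.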
__ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 00000000..2016a79a --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,169 @@ +#pragma once + +#include +#include +#include + +#include + +namespace deep_gemm::sm100 { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const int32_t& outer_idx) { + DG_STATIC_ASSERT(1 <= kNumMulticast and kNumMulticast <= 2, "Invalid multicast config"); + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + + // 2-CTA function will send signals to the leader CTA only + const auto copy_func = kNumMulticast == 1 ? cute::SM90_TMA_LOAD_2D::copy : cute::SM100_TMA_2SM_LOAD_2D::copy; + + // Issue multiple TMAs + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + copy_func(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return 
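// NOTES (worked example, not part of this patch): the shared-memory descriptors above encode the
// start address and the SBO/LBO byte offsets in 16-byte units, hence the `>> 4`; e.g. the SF
// descriptor's stride_byte_offset of 8 * 16 = 128 bytes is stored as 8. Likewise the TMA copy
// above splits a block into swizzle-atom-sized pieces of `kSwizzleMode / sizeof(dtype_t)`
// elements: with a 128B swizzle that is 128 elements for FP8 inputs (or 64 for BF16), so a
// BLOCK_INNER of 128 FP8 elements is covered by exactly one TMA issue per stage.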
cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + ((offset + k_idx * get_umma_desc_stride_k()) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(to_umma_layout_type(), + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(to_umma_layout_type(), + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) { + desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace `deep_gemm::sm100` diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh similarity index 62% rename from deep_gemm/include/deep_gemm/mma_utils.cuh rename 
to deep_gemm/include/deep_gemm/common/sm90_utils.cuh index 85b2ccc0..e0160636 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -1,15 +1,63 @@ #pragma once -#ifndef __CUDACC_RTC__ -#include -#endif - +#include #include #include -#include "utils.cuh" +namespace deep_gemm::sm90 { -namespace deep_gemm { +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; template struct SM90_U32x2_STSM_N { @@ -21,17 +69,6 @@ struct SM90_U32x2_STSM_N { } }; -template -struct SM90_U32x4_STSM_N { - __device__ __forceinline__ static void - copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { - const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), - *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; - asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); - } -}; - __forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } @@ -44,54 +81,13 @@ __forceinline__ __device__ void warpgroup_fence_operand(float& reg) { asm volatile("" : "+f"(reg) :: "memory"); } -__forceinline__ __device__ uint32_t get_lane_id() { - uint32_t lane_id; - asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ 
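// NOTES (worked example, not part of this patch): in the `FP8MMA` wrapper above,
// `kNumAccum = M * N / 128` is the number of FP32 accumulator registers held by each of the 128
// threads in a warpgroup: a 64xNx32 WGMMA produces M * N = 64 * N accumulator elements, so the
// 64x128x32 tile gives 64 * 128 / 128 = 64 registers per thread, and the largest selectable tile
// (64x192x32) gives 96.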
__forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) { - int4 ret; - asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); -} - template -__device__ void warpgroup_wait() { +__forceinline__ __device__ void warpgroup_wait() { DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); } +// TODO: replace with CUTLASS solution union GmmaDescriptor { __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} @@ -128,11 +124,11 @@ union GmmaDescriptor { }; template -__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, - int leading_byte_offset = 0, - int stride_byte_offset = 1024) { +__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { GmmaDescriptor desc; - auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); desc.bitfield.start_address_ = uint_ptr >> 4; desc.bitfield.layout_type_ = layout_type; desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; @@ -141,72 +137,15 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, return desc; } -template -struct FP8MMA { - - template - __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { - using namespace cute::SM90::GMMA; - MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? 
ScaleOut::One : ScaleOut::Zero)); - } - - __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); - } - - static constexpr int M = 64; - static constexpr int N = N_; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; - -template -struct FP8MMASelector { - - static constexpr auto select_mma() { - using namespace cute::SM90::GMMA; - if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); - if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); - } - - static constexpr auto select_type() { - return FP8MMA(); +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, + const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast) { + constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); + if (num_tma_multicast == 1) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); + } else if (cute::block_rank_in_cluster() == 0) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); } - - using type = decltype(select_type()); -}; - -enum class Layout { - RowMajor, - ColMajor -}; - -__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) { - return block_m == 64 ? 
1 : 2; -} - -template -__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { - DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); - return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads; } -} // namespace deep_gemm +} // namespace `deep_gemm::sm90` diff --git a/deep_gemm/include/deep_gemm/common/types.hpp b/deep_gemm/include/deep_gemm/common/types.hpp new file mode 100644 index 00000000..7e879533 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/types.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace deep_gemm { + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, +}; + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, +}; + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh new file mode 100644 index 00000000..a4ab6a34 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -0,0 +1,138 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? 
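// NOTES (worked example, not part of this patch): the integer helpers above appear throughout the
// kernels; a few concrete values, assuming the definitions in this header:
#if 0
static_assert(constexpr_ceil_div(7168u, 128u) == 56u);   // number of 128-wide K blocks
static_assert(constexpr_align(300u, 128u) == 384u);      // round a size up to a block multiple
static_assert(constexpr_gcd(128u, 96u) == 32u);
#endif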
a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +} // namespace `deep_gemm` diff --git a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh deleted file mode 100644 index 7b7e3d31..00000000 --- a/deep_gemm/include/deep_gemm/fp8_wgrad_gemm.cuh +++ /dev/null @@ -1,363 +0,0 @@ -#pragma once - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include -#include - -#include -#include -#include - -#include "mma_utils.cuh" -#include "scheduler.cuh" -#include "tma_utils.cuh" -#include "utils.cuh" - -namespace deep_gemm { - -template -__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_wgrad_gemm_kernel(uint32_t shape_k, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_scales_a, - const __grid_constant__ CUtensorMap tensor_map_scales_b, - const __grid_constant__ CUtensorMap tensor_map_d) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__) - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - - // Types - using WGMMA = typename FP8MMASelector::type; - using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); - - // Shared memory - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); - static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static 
constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U; - - // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); - constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; - - const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); - const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_id(); - - // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); - } - __syncwarp(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); - - // Data on shared memory - auto smem_d = reinterpret_cast(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_scales_a[kNumStages]; - float* smem_scales_b[kNumStages]; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages + 1]; - Barrier* empty_barriers[kNumStages + 1]; - - // Fill shared memory pointers - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) - + i * SMEM_SCALES_A_SIZE_PER_STAGE); - smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE) - + i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE); - } - - // Fill barriers - DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers"); - auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages - * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE)); - #pragma unroll - for (int i = 0; i < kNumStages + 1; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i; - } - - // Initialize barriers - DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { - // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, - // even with TMA multicast disabled, we want to make the behavior aligned - #pragma unroll - for (int i = 0; i < kNumStages; ++ i) { - full_barriers[i]->init(1); - empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); - } - full_barriers[kNumStages]->init(1); - 
empty_barriers[kNumStages]->init(1); - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); - (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); - } - - // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - auto launch_k_iterations = [&](const auto& func) { - if constexpr (kNumLastStages == 0) { - for (int k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(num_iterations - 1, NotDivisibleK{}); - } - }; - - // Register reconfigurations - constexpr int kNumTMARegisters = 40; - constexpr int kNumMathRegisters = 232; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(SHAPE_M); - - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. - const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - - // Issue TMA A - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - k_idx / BLOCK_K, num_tma_multicast_a); - - // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b); - tma_copy(&tensor_map_scales_b, reinterpret_cast(&full_barrier), - smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b); - - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); - - // Issue TMA D - empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1); - auto& full_barrier = *full_barriers[kNumStages]; - tma_copy(&tensor_map_d, reinterpret_cast(&full_barrier), - smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1); - full_barrier.arrive_and_expect_tx(SMEM_D_SIZE); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; - const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); - } - }; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; - float2 scales_b[WGMMA::kNumAccum / 4]; - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? 
kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - - // Read A scales - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - - // Read B scales at the first warpgroup wave - if (local_idx == 0) { - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) - scales_b[i] = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + col_idx * 2)); - __syncwarp(); - } - - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); - - // Promote with scales - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - const float &scale_b_0 = scales_b[i].x; - const float &scale_b_1 = scales_b[i].y; - shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; - } - } - } - - // Wait last TMA store to be finished - if (k_iter == 0 and scheduler.current_iter > 0) { - if (threadIdx.x == 0) { - cute::tma_store_wait<0>(); - empty_barriers[kNumStages]->arrive(); - } - __syncwarp(); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); - - // Wait TMA D arrivals - full_barriers[kNumStages]->wait(scheduler.current_iter & 1); - - // Accumulate to D shared memory - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2); - auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - float2 d_0 = ld_shared(smem_d_0 + i * 4); - st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]}); - float2 d_1 = ld_shared(smem_d_1 + i * 4); - st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]}); - } - } - - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if 
(threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M); - cute::tma_store_arrive(); - } - __syncwarp(); - } - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false && "This kernel only support sm_90a"); -#endif -} - -}; // namespace deep_gemm - -#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 00000000..28b5399a --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,3 @@ +#pragma once + +// TODO: add implement \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..360719aa --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,601 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_sfa, + const __grid_constant__ CUtensorMap tensor_map_sfb, + const __grid_constant__ CUtensorMap tensor_map_c, + const __grid_constant__ CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(std::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t); + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 
1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = std::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_d); + if constexpr (kWithAccumulation) + cute::prefetch_tma_descriptor(&tensor_map_c); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::float_e4m3_t* smem_a[kNumStages]; + cutlass::float_e4m3_t* smem_b[kNumStages]; + uint32_t* smem_sfa[kNumStages]; + uint32_t* smem_sfb[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill SFA/SFB + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + smem_sfb[i] = reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * 
SMEM_SFB_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + uint32_t phase = 0; + auto launch_k_iterations = [&](const auto& func) { + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); + const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + + // TODO: refactor here + if (num_last_stages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, false, num_last_stages); + func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; + } + }; + + auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { + DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, + "Too many epilogue stages, please modify the Python heuristic as well"); + accum_stage_idx == 0 ? 
func(0) : func(1); + }; + + // Dispatch warps into different roles + if (warp_idx == 0) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = std::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 
0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + } + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); + tma_copy(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); + num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); + } + + // Arrive at full barriers + if (cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + if (cute::elect_one_sync()) + full_barriers[s]->arrive(); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + // Wait tensor memory empty barrier arrival + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[s])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = std::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[s]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { + using cute_utccp_t = std::conditional_t; + + // SFA and SFB copy + // TODO: process shared memory descriptor by addition + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using cute_mma_t = std::conditional_t, + cute::SM100_MMA_MXF8F6F4_2x1SM_SS>; + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); 
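+ // NOTES (added for clarity; a hedged reading of `advance_umma_desc_lo`, see its definition): the low
+ // word of a UMMA shared-memory descriptor encodes the start address in 16-byte units, which is why the
+ // per-stage offsets above are divided by 16 and why the same descriptor can be retargeted at another
+ // M wave / K slice by adjusting only its low word, instead of rebuilding a full descriptor each time.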
+ cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_iter > 0 or s > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + with_sf_full_barriers[s]->wait(phase); + empty_barrier_arrive(s, false); + } + }); + }); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = std::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrival + full_barriers[s]->wait(phase); + + // Transpose for UTCCP at certain stages + const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems); + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[s]->arrive(0u); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait(phase); + with_sf_full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
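+ // (Added note, assuming the usual tcgen05 addressing scheme: a tensor memory address packs the lane in
+ // its upper 16 bits and the column in its lower 16 bits. Each epilogue warp can only touch its own
+ // 32-lane slice anyway, so leaving the lane half at zero, as done here, is sufficient.)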
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + const uint32_t iter_idx = w * kNumStores + s; + if (iter_idx >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (std::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + if (epilogue_thread_idx == 0) { + using cute_tma_t = std::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + }); + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + // TODO: do we actually need this? + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + // TODO: do we need 2 SM allocation? 
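+ // (Added worked example, assuming `get_num_aligned_tmem_cols` rounds up to the next power of two as
+ // `tcgen05.alloc` requires: with BLOCK_M = 128 (kNumMWaves = 1), BLOCK_N = 128 and 2 epilogue stages,
+ // the accumulators take 2 * 1 * 128 = 256 columns and SFA/SFB take 4 + 4 more, so the 264 required
+ // columns are allocated and freed as kNumTmemCols = 512.)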
+ if (epilogue_warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh new file mode 100644 index 00000000..dcfeed9d --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -0,0 +1,532 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_d, + const __grid_constant__ CUtensorMap tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_M == kNumEpilogueThreads, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const auto shape_k_scales = ceil_div(shape_k, BLOCK_K); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 
1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = std::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + // NOTES: do not use `LOAD_BLOCK_M` for SFA, as we need full SFA for promotion + constexpr bool kMustUseUniformedSFB = (BLOCK_K % BLOCK_N == 0); + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Must have 2 epilogue stages + constexpr uint32_t kNumEpilogueStages = 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::float_e4m3_t* smem_a[kNumStages]; + cutlass::float_e4m3_t* smem_b[kNumStages]; + float* smem_sfa[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill SFA/SFB + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) + smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 
<= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + const uint32_t num_iterations = ceil_div(shape_k, kNumStages * BLOCK_K); + auto launch_k_iterations = [=](const auto& func) { + if constexpr (kNumLastStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(num_iterations - 1, NotDivisibleK{}); + } + }; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // Register configurations + constexpr uint32_t kNumNonEpilogueRegisters = 64; + constexpr uint32_t kNumEpilogueRegisters = 216; + DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 65535, "Too many registers"); + + // Dispatch warps into different roles + if (warp_idx == 0) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::K)>( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_b_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::MN)>( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 
0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + + // Issue SFA TMA + tma_copy( + &tensor_map_sfa, full_barriers[s], + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_k_scales, 1, k_block_idx)); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE; + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? 
kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + // Wait TMA full + auto iter_idx = scheduler.current_iter * num_iterations + k_iter; + full_barriers[s]->wait(iter_idx & 1); + + // Wait tensor memory empty + auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; + auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_stage_phase ^ 1); + + // Issue UMMA in the leader CTA + if (s < kNumInnerStages) { + using cute_mma_t = std::conditional_t; + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + auto b_desc = make_umma_desc(smem_b[s], 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + auto a_desc = make_umma_desc(smem_a[s], w * LAYOUT_AD_M, k * UMMA_K); + cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k > 0, + runtime_instr_desc); + } + } + tcgen05_before_thread_sync(); + } + + // Commit to the TMA empty and tensor memory full barrier + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + } + }); + } + } else if (warp_idx < kNumNonEpilogueThreads / 32) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_thread_idx_in_warpgroup = epilogue_thread_idx % 128; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + const auto epilogue_warpgroup_idx = epilogue_thread_idx / 128; + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
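+ // (Added note: the allocator invoked by threads 32-63 earlier writes the granted tensor memory base
+ // column into `tmem_ptr_in_smem`; the trap-assert just below therefore pins that base to 0, which is
+ // what lets the epilogue use hard-coded column offsets such as `accum_stage_idx * kNumMWaves * BLOCK_N`.
+ // This is a reading of the check, not documented API behaviour.)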
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t kNumElemsPerLDTM = 16; + DG_STATIC_ASSERT(kNumElemsPerLDTM == 16 and BLOCK_N % kNumElemsPerLDTM == 0 and BLOCK_K % kNumElemsPerLDTM == 0, "Invalid LDTM width"); + + // SFB stuffs + uint32_t num_former_iters = BLOCK_N, num_full_iters = BLOCK_N; + if constexpr (not kMustUseUniformedSFB) { + num_former_iters = min(BLOCK_N, BLOCK_K - ((n_block_idx * BLOCK_N) % BLOCK_K)); + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N); + } + num_former_iters /= kNumElemsPerLDTM, num_full_iters /= kNumElemsPerLDTM; + const auto sfb_offset = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + const auto sfb_ptr = sfb + (sfb_offset + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + + // Launch promotion + float accum[BLOCK_N] = {0}; + launch_k_iterations([&](uint32_t k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + // Load SFB + float sf_0 = 0, sf_1 = 0; + if (s < kNumInnerStages) { + const auto k_block_idx = k_iter * kNumStages + s; + sf_0 = __ldg(sfb_ptr + k_block_idx); + sf_1 = num_former_iters < num_full_iters ? __ldg(sfb_ptr + k_block_idx + shape_k_scales) : 0; + } + + // Wait UMMA arrival + auto iter_idx = scheduler.current_iter * num_iterations + k_iter; + auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; + auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_stage_phase); + tcgen05_after_thread_sync(); + + // Commit to the TMA empty barrier for all CTAs after loading SFA + float sfa = s < kNumInnerStages ? ld_shared(smem_sfa[s] + epilogue_thread_idx) : 0; + sf_0 *= sfa, sf_1 *= sfa; + __syncwarp(); + if (lane_idx < kNumMulticast) + empty_barriers[s]->arrive(lane_idx); + __syncwarp(); + + // Do promotion like the SM90 kernel + if (s < kNumInnerStages) { + uint32_t values[kNumElemsPerLDTM]; + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerLDTM; ++ i) { + // Load from tensor memory + cute::SM100_TMEM_LOAD_32dp32b16x::copy( + accum_stage_idx * kNumMWaves * BLOCK_N + epilogue_warpgroup_idx * BLOCK_N + i * kNumElemsPerLDTM, + values[ 0], values[ 1], values[ 2], values[ 3], + values[ 4], values[ 5], values[ 6], values[ 7], + values[ 8], values[ 9], values[10], values[11], + values[12], values[13], values[14], values[15]); + cutlass::arch::fence_view_async_tmem_load(); + + // Promote + const auto sf = (kMustUseUniformedSFB or i < num_former_iters) ? 
sf_0 : sf_1; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerLDTM; ++ j) + accum[i * kNumElemsPerLDTM + j] += *reinterpret_cast(&values[j]) * sf; + } + } + + // Commit to the tensor memory empty barrier (only at the leader CTA) + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + }); + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + + // Write shared memory + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Epilogue store and addition + // Issue every swizzled atom and pipeline: store shared, add C, and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + // NOTES: if you want to do accumulation, please notice that you need two accumulation barriers + const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; + if constexpr (std::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + st_shared(smem_ptr, + *reinterpret_cast(&accum[offset + 0]), + *reinterpret_cast(&accum[offset + 1]), + *reinterpret_cast(&accum[offset + 2]), + *reinterpret_cast(&accum[offset + 3])); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v, "Invalid type"); + st_shared(smem_ptr, + cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]), + cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]), + cast_into_bf16_and_pack(accum[offset + 4], accum[offset + 5]), + cast_into_bf16_and_pack(accum[offset + 6], accum[offset + 7])); + } + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + if (epilogue_thread_idx_in_warpgroup == 0) { + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_d, local_smem_cd, + n_idx, m_idx + epilogue_warpgroup_idx * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + } + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + // TODO: do we actually need this? + if (epilogue_thread_idx_in_warpgroup == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + // TODO: do we need 2 SM allocation? 
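+ // (Added note, hedged reading rather than documented CUTLASS behaviour: tensor memory is not released
+ // implicitly, so the columns allocated by threads 32-63 at kernel start must be freed explicitly before
+ // exit; delegating the `free` to epilogue warp 1 lets it overlap with warp 0's `tma_store_wait` above.)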
+ if (epilogue_warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 00000000..0ccec3eb --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,3 @@ +#pragma once + +// TODO: add implement diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..28b5399a --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,3 @@ +#pragma once + +// TODO: add implement \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh similarity index 78% rename from deep_gemm/include/deep_gemm/fp8_gemm.cuh rename to deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 5c11cd3d..6fff0252 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -10,13 +10,14 @@ #include #include -#include "mma_utils.cuh" -#include "scheduler.cuh" -#include "tma_utils.cuh" -#include "utils.cuh" +#include +#include +#include namespace deep_gemm { +using namespace deep_gemm::sm90; + template __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { if (num_former_iters == kNumFormerIters) { @@ -28,59 +29,58 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); } -template -__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) -fp8_gemm_kernel(float* scales_b, int* grouped_layout, - uint32_t shape_m, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_scales_a, - const __grid_constant__ CUtensorMap tensor_map_d) { +__global__ void __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ CUtensorMap tensor_map_a, + const __grid_constant__ CUtensorMap tensor_map_b, + const __grid_constant__ CUtensorMap tensor_map_d, + const __grid_constant__ CUtensorMap tensor_map_sfa) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types using WGMMA = typename FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + // Overwrite shape 
constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); - static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); // Configs constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); - constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; - constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_id(); + const uint32_t lane_idx = get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (threadIdx.x == kNumMathThreads) { // NOTES: `reinterpret_cast` must be here, or NVRTC will fail cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); - - // `tensor_map_d` is only used in swizzling mode - // For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode - if constexpr (kSwizzleDMode > 0) - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_sfa)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); } __syncwarp(); @@ -92,8 +92,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); __nv_fp8_e4m3* smem_a[kNumStages]; __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_scales_a[kNumStages]; - float* smem_scales_b; + float* smem_sfa[kNumStages]; + float* smem_sfb; // TMA Barrier for both divisible and non-divisible cases Barrier* full_barriers[kNumStages]; @@ -104,12 +104,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, for (uint32_t i = 0; i < kNumStages; ++ i) { smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); + smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i 
* SMEM_SFA_SIZE_PER_STAGE); } - smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); + smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); // Fill barriers - auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { full_barriers[i] = barrier_start_ptr + i; @@ -129,7 +129,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Make initialized barrier visible in async proxy cutlass::arch::fence_view_async_shared(); - (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + cutlass::arch::fence_barrier_init(); } // Synchronize all threads to make barrier visible in normal memory model @@ -140,7 +140,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, struct NotDivisibleK {}; struct SkipComputation {}; struct NotSkipComputation {}; - auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) { + auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; @@ -149,15 +149,15 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { if (skip_computation) { - for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); - } else if (SHAPE_K % kFullKOfAllStages == 0) { - for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter) + } else if (shape_k % kFullKOfAllStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); } else { - for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); + func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); } }, func, kShouldOptimize ? num_former_iters : 0); }; @@ -168,7 +168,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data @@ -180,7 +180,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; // Assign TMA multicast number into A and B // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. @@ -194,30 +194,31 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, #pragma unroll for (uint32_t s = 0; s < kNumInnerStages; ++ s) { // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[s]; uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), num_tma_multicast_a); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K), + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), num_tma_multicast_a); // Issue TMA B tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), num_tma_multicast_b); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); } // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); full_barriers[s]->arrive(); } }, false, 0); @@ -227,7 +228,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, if constexpr (kNumTMAMulticast > 1) { #pragma unroll for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); } } } else { @@ -235,33 +236,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, cutlass::arch::warpgroup_reg_alloc(); // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; if constexpr (not kMustUseUniformedScaleB) { num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 
8; } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(local_sfb + i)); } cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Accumulation for WGMMA or CUDA promotion - constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M); + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; @@ -279,19 +280,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { constexpr bool kSkipComputation = std::is_same_v; constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : - (kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K); + constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? 
kNumStages : kNumLastStages); #pragma unroll for (uint32_t s = 0; s < kNumInnerStages; ++ s) { // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1; // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales); // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); // TODO: remove some useless computation for unaligned Ms #pragma unroll @@ -300,8 +300,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset); + auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset); // Commit WGMMA instructions #pragma unroll @@ -347,7 +347,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); empty_barrier_arrive(s); } }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); @@ -360,8 +360,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, "Unaligned TMA store or too many TMA store instructions"); DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); - DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, - "Swizzling and padding are not compatible"); // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) @@ -403,9 +401,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset } else { // No swizzling, just padding - // NOTES: padding must be zero for BF16 output - DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output"); - smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); } // NOTES: only 16 lanes' addresses are used @@ -421,13 +417,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, // Use TMA store to write back to global memory // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_block_idx * BLOCK_N + in_block_n_offset, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + 
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); cute::tma_store_arrive(); } __syncwarp(); @@ -441,4 +438,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, }; // namespace deep_gemm -#pragma clang diagnostic pop \ No newline at end of file +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 00000000..5b979a8c --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,139 @@ +#pragma once + +#include + +#include + +namespace deep_gemm { + +// NOTES: the two kernels below always pack the K dimension + +template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? 
ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? 
__ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh deleted file mode 100644 index 69ea2160..00000000 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ /dev/null @@ -1,163 +0,0 @@ -#pragma once - -#include "utils.cuh" - -namespace deep_gemm { - -enum class GemmType { - Normal, - GroupedContiguous, - GroupedMasked -}; - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" -template -struct Scheduler { - int current_iter = -1; - uint32_t num_aligned_m_blocks; - - // For normal GEMM - // Maybe not used in the masked grouped GEMM - uint32_t num_blocks; - uint32_t num_blocks_in_group; - bool is_peer_cta_alive = true; - - // For grouped GEMM - int* grouped_layout; - - // Only used for masked layout - uint32_t curr_group_idx, curr_cumsum; - - __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, - int* grouped_layout = nullptr) { - num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); - if constexpr (kGemmType == GemmType::Normal) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - } else if (kGemmType == GemmType::GroupedContiguous) { - num_blocks = num_aligned_m_blocks * kNumNBlocks; - this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::GroupedMasked) { - curr_group_idx = curr_cumsum = 0; - this->grouped_layout = grouped_layout; - } - } - - // ReSharper disable once CppNotAllPathsReturnValue - __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { - if constexpr (kGemmType == GemmType::Normal) { - return true; - } else if constexpr (kGemmType == GemmType::GroupedContiguous) { - return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; - } else if constexpr (kGemmType == GemmType::GroupedMasked) { - return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx); - } - } - - __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { - if (num_blocks_in_group == 1) - return false; - if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) 
{ - return true; - } else { - DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type"); - if constexpr (kIsTMAMulticastOnA) { - return true; - } else { - auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); - auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); - return group_idx == peer_group_idx; - } - } - } - - __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx, - uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); - - // Swizzle for better L2 usages - auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks; - auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks; - auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_block_idx = group_idx * kNum1DBlocksPerGroup; - auto in_group_idx = block_idx % num_blocks_per_group; - num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); - - // Fix unaligned TMA multicast - if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { - if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { - num_blocks_in_group = num_blocks_in_group ^ 1; - } else { - in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; - first_block_idx += num_blocks_in_group ^ 1; - num_blocks_in_group = 1; - } - } - - // Convert to final M/N block indices - if constexpr (kIsTMAMulticastOnA) { - m_block_idx = in_group_idx / num_blocks_in_group; - n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; - } else { - m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; - n_block_idx = in_group_idx / num_blocks_in_group; - } - } - - template - __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size, - const uint32_t& block_idx, const uint32_t& m_block_idx=0) { - if constexpr (kGemmType == GemmType::Normal) { - return block_idx * block_size; - } else if constexpr (kGemmType == GemmType::GroupedContiguous) { - auto offset = kIgnoreGroupedForGroupedContiguous ? 
0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)); - return offset * shape_dim + block_idx * block_size; - } else if constexpr (kGemmType == GemmType::GroupedMasked) { - return curr_group_idx * shape_dim + block_idx * block_size; - } - } - - __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; - - if constexpr (kGemmType == GemmType::GroupedMasked) { - uint32_t num_m_blocks; - while (true) { - // End of the task - if (curr_group_idx == kNumGroups) - return false; - - // Within the current group - num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); - auto current_m_block_cumsum = curr_cumsum + num_m_blocks; - if (next_block_idx < current_m_block_cumsum * kNumNBlocks) - break; - - // Move to check the next group - curr_group_idx ++, curr_cumsum = current_m_block_cumsum; - } - - get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); - } else { - if (next_block_idx >= num_blocks) - return false; - - // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned - is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) - num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) - (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound - get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); - } - return true; - } -}; - -#pragma clang diagnostic pop - -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh deleted file mode 100644 index 795dca6a..00000000 --- a/deep_gemm/include/deep_gemm/tma_utils.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "utils.cuh" - -namespace deep_gemm { - -// TODO: move this function to other files -__device__ __forceinline__ void -tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) { - constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - if (num_tma_multicast == 1) { - cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); - } else if (cute::block_rank_in_cluster() == 0) { - cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); - } -} - -} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh deleted file mode 100644 index 598a4146..00000000 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) -#endif - -template -__device__ __host__ constexpr T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? 
a : constexpr_gcd(b, a % b); -} diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py deleted file mode 100644 index 06a51940..00000000 --- a/deep_gemm/jit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler -from .runtime import Runtime diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py deleted file mode 100644 index d3f1f762..00000000 --- a/deep_gemm/jit/compiler.py +++ /dev/null @@ -1,284 +0,0 @@ -import functools -import hashlib -import os -import re -import subprocess -import time -import uuid -from typing import Any, Dict, List, Tuple, Type - -import cuda.bindings -import cuda.bindings.nvrtc as nvrtc -from torch.utils.cpp_extension import CUDA_HOME - -from . import interleave_ffma -from .runtime import Runtime, RuntimeCache - -runtime_cache = RuntimeCache() - - -def hash_to_hex(s: str) -> str: - md5 = hashlib.md5() - md5.update(s.encode('utf-8')) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_jit_include_dir() -> str: - return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include') - - -@functools.lru_cache(maxsize=None) -def get_deep_gemm_version() -> str: - md5 = hashlib.md5() - - # Update include directories - include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm') - assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}' - for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): - with open(os.path.join(include_dir, filename), 'rb') as f: - md5.update(f.read()) - - # Update `interleave_ffma.py` - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f: - md5.update(f.read()) - return md5.hexdigest()[0:12] - - -@functools.lru_cache(maxsize=None) -def get_nvcc_compiler() -> Tuple[str, str]: - paths = [] - if os.getenv('DG_JIT_NVCC_COMPILER'): - paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) - paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) - - # Try to find the first available NVCC compiler - least_version_required = '12.3' - version_pattern = re.compile(r'release (\d+\.\d+)') - for path in paths: - if os.path.exists(path): - command = [path, '--version'] - result = subprocess.run(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, text=True) - match = version_pattern.search(result.stdout) - version = match.group(1) - assert match, f'Cannot get the version of NVCC compiler {path}' - assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' - return path, version - raise RuntimeError('Cannot find any available NVCC compiler') - - -@functools.lru_cache(maxsize=None) -def get_default_user_dir(): - if 'DG_JIT_CACHE_DIR' in os.environ: - path = os.getenv('DG_JIT_CACHE_DIR') - os.makedirs(path, exist_ok=True) - return path - return os.path.join(os.path.expanduser('~'), '.deep_gemm') - - -@functools.lru_cache(maxsize=None) -def get_tmp_dir(): - return os.path.join(get_default_user_dir(), 'tmp') - - -@functools.lru_cache(maxsize=None) -def get_cache_dir(): - return os.path.join(get_default_user_dir(), 'cache') - - -def make_tmp_dir(): - tmp_dir = get_tmp_dir() - os.makedirs(tmp_dir, exist_ok=True) - return tmp_dir - - -def put(path, data): - # Write and do POSIX atomic replace - tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}') - with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f: - f.write(data) 
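The `put` helper here stages data in a uniquely named temporary file and then publishes it with `os.replace`, so cache readers never observe a partially written file. A minimal standalone sketch of the same write-then-replace pattern (the `atomic_write` name and bytes-only signature are assumptions for illustration):

    import os
    import uuid

    def atomic_write(path: str, data: bytes) -> None:
        # Stage in a uniquely named temporary file, then atomically swap it in
        tmp_path = f'{path}.tmp.{uuid.uuid4()}'
        with open(tmp_path, 'wb') as f:
            f.write(data)
        # `os.replace` is atomic on POSIX: readers see either the old file or the new one
        os.replace(tmp_path, path)
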
- os.replace(tmp_file_path, path) - - -class Compiler: - @classmethod - def signature(cls) -> str: - pass - - @staticmethod - def __version__() -> Tuple[int, int]: - pass - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - pass - - @staticmethod - def flags() -> List[str]: - cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20)) - return [f'-std=c++{cpp_standard}', - '--ptxas-options=--register-usage-level=10' + - (',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''), - # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases - '--diag-suppress=39,161,174,177,186,940'] - - @staticmethod - def include_dirs() -> List[str]: - return [get_jit_include_dir()] - - @classmethod - def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: - # Compiler flags - flags = cls.flags() - - # Build signature - enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0)) - signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' - name = f'kernel.{name}.{hash_to_hex(signature)}' - path = os.path.join(get_cache_dir(), name) - - # Check runtime cache or file system hit - global runtime_cache - cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs) - if cached_runtime is not None: - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Using cached JIT runtime {name} during build') - return cached_runtime - - # Compile into a temporary CU file - os.makedirs(path, exist_ok=True) - cubin_path = os.path.join(path, 'kernel.cubin') - tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') - - start_time = time.time() - cls.compile(name, code, tmp_cubin_path) - end_time = time.time() - elapsed_time = end_time - start_time - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.') - - # Interleave FFMA reuse - if enable_sass_opt: - interleave_ffma.process(tmp_cubin_path) - - # Atomic replace files - os.replace(tmp_cubin_path, cubin_path) - - # Put cache and return - runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True) - assert runtime is not None - return runtime - - -class NVCCCompiler(Compiler): - @staticmethod - def __version__() -> Tuple[int, int]: - _, version = get_nvcc_compiler() - major, minor = map(int, version.split('.')) - return major, minor - - @classmethod - def signature(cls) -> str: - return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' - - @classmethod - def flags(cls) -> List[str]: - cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] - return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '-gencode=arch=compute_90a,code=sm_90a', - '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', - f'--compiler-options={",".join(cxx_flags)}'] - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - # Write the code - path = os.path.join(get_cache_dir(), name) - src_path = os.path.join(path, 'kernel.cu') - put(src_path, code) - command = [get_nvcc_compiler()[0], - src_path, '-o', target_path, - *cls.flags()] - if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): - print(f'Compiling JIT runtime {name} with command {command}') - - result = subprocess.run(command, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}') - assert False, f'Failed to compile {src_path}' - - -class NVRTCCompiler(Compiler): - @staticmethod - def __version__() -> Tuple[int, int]: - res, major, minor = nvrtc.nvrtcVersion() - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - # Failed to get the actual NVRTC version, use cuda-bindings version instead - major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) - return major, minor - - @classmethod - def signature(cls) -> str: - return f'nvrtc+{cls.__version__()}' - - @staticmethod - def include_dirs() -> List[str]: - if CUDA_HOME is None: - raise RuntimeError('CUDA_HOME is required for NVRTC compilation') - return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] - - @classmethod - def flags(cls) -> List[str]: - flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '--gpu-architecture=sm_90a', '-default-device'] - # NOTES: PCH is vital for compilation speed - if cls.__version__() >= (12, 8): - flags += ['--pch'] - if int(os.getenv('DG_JIT_DEBUG', 0)): - flags += ['--pch-verbose=true'] - return flags - - @classmethod - def compile(cls, name: str, code: str, target_path: str) -> None: - # Create program - code_bytes = bytes(code, 'utf-8') - result, program = nvrtc.nvrtcCreateProgram( - code_bytes, bytes(name, 'utf-8'), 0, [], []) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}' - - # Compile - options = [bytes(flag, 'utf-8') for flag in cls.flags()] - if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)): - print(f'Compiling JIT runtime {name} with options: {options}') - compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0] - - # Print compiler log - if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS: - result, log_size = nvrtc.nvrtcGetProgramLogSize(program) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}' - - log_bytes = bytes(log_size) - result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}' - print(f'Compiler log: {log_bytes.decode("utf-8")}') - - # Exit if failed - assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}' - - # Create CUBIN - result, cubin_size = nvrtc.nvrtcGetCUBINSize(program) - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}' - cubin_bytes = bytes(cubin_size) - result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] - assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}' - - # Write into the file system - put(target_path, cubin_bytes) - - # Destroy handler - assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}' - - -def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime: - compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler - return compiler_cls.build(name, code, runtime_cls, kwargs) diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py deleted file mode 100644 index 7899a221..00000000 --- a/deep_gemm/jit/interleave_ffma.py +++ /dev/null @@ -1,137 +0,0 @@ -import argparse -import mmap -import os -import re -import 
subprocess -from torch.utils.cpp_extension import CUDA_HOME - - -def run_cuobjdump(file_path): - command = [f'{CUDA_HOME}/bin/cuobjdump', '-sass', file_path] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - return result.stdout - - -def extract_ffma(sass): - lines = sass.splitlines() - collected = [] - current = [] - - arch_name, func_name = 'N/A', 'N/A' - skip_next_line = False - for line in lines: - if 'code for' in line: - arch_name = line.lstrip().lstrip('code for ').rstrip() - elif 'Function :' in line: - func_name = line.lstrip().lstrip('Function :').rstrip() - elif 'FFMA' in line: - current.append(line) - skip_next_line = True - elif skip_next_line: - current.append(line) - skip_next_line = False - else: - if len(current) >= 16: - assert len(current) % 2 == 0 - collected.append((f'{arch_name}::{func_name}', current)) - current = [] - - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f'Found {len(collected)} FFMA segments') - return collected - - -def extract_hex_from_line(line): - match = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line) - assert match - return int(match.group(1), 16) - - -def validate(m, offset, le_bytes, num_lines): - assert len(le_bytes) == num_lines // 2 - assert m[offset:offset + 16] == le_bytes[0] - for i in range(1, num_lines // 2): - if m[offset + i * 16:offset + i * 16 + 16] != le_bytes[i]: - return False - return True - - -def parse_registers(line): - line = re.sub(r'/\*.*?\*/', '', line) - line = line.replace(';', '') - tokens = line.strip().split(',') - registers = [] - for token in tokens: - token = token.strip() - words = token.split() - for word in words: - if word.startswith('R'): - reg = word.split('.')[0] - registers.append(reg) - return registers - - -def modify_segment(m, name, ffma_lines): - num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2 - assert num_lines % 2 == 0 - - le_bytes, new_le_bytes = [], [] - reused_list = [] - dst_reg_set = set() - last_reused, last_dst_reg = False, '' - num_changed = 0 - for i in range(num_lines // 2): - dst_reg = parse_registers(ffma_lines[i * 2])[-2] - low_line, high_line = ffma_lines[i * 2], ffma_lines[i * 2 + 1] - low_hex, high_hex = extract_hex_from_line(low_line), extract_hex_from_line(high_line) - le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) - reused = (high_hex & 0x0800000000000000) != 0 - if reused: - is_first_occurred = dst_reg not in dst_reg_set - if is_first_occurred or (last_reused and dst_reg == last_dst_reg): - # Modify the `reuse` and `yield` bits - assert high_hex & 0x0800200000000000, f'{hex(high_hex)}' - high_hex ^= 0x0800200000000000 - reused = False - num_changed += 1 - else: - reused_list.append(i) - dst_reg_set.add(dst_reg) - new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little')) - last_reused, last_dst_reg = reused, dst_reg - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}') - - # Find the offset - offsets = [] - offset = m.find(le_bytes[0]) - while offset != -1: - offsets.append(offset) - offset = m.find(le_bytes[0], offset + 1) - offsets = list(filter(lambda x: validate(m, x, le_bytes, num_lines), offsets)) - - # Replace with `new_le_bytes` - for offset in offsets: - for i in range(num_lines // 2): - m[offset + i * 16:offset + i * 16 + 16] = new_le_bytes[i] - - -def process(path): - if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)): - print(f'Processing {path}') 
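For reference, `modify_segment` above patches selected FFMA instruction pairs by XOR-ing the high 64-bit control word with `0x0800200000000000`, which modifies the `reuse` and `yield` bits. A toy sketch of that bit manipulation in isolation (the helper name is hypothetical; the reuse-bit value matches the check in the original code, while labelling the second bit as the yield bit is an assumption):

    REUSE_BIT = 0x0800000000000000  # same bit tested via `reused = (high_hex & ...) != 0` above
    YIELD_BIT = 0x0000200000000000  # remaining bit of the 0x0800200000000000 mask (assumed yield bit)

    def clear_reuse_and_yield(high_hex: int) -> int:
        mask = REUSE_BIT | YIELD_BIT
        assert high_hex & mask, hex(high_hex)  # same sanity check as in `modify_segment`
        return high_hex ^ mask                 # flip both control bits, as the patcher does

    # Example: an instruction whose high word has both bits set loses them after patching
    assert clear_reuse_and_yield(0x0800200000000123) == 0x0000000000000123
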
- output = run_cuobjdump(path) - segments = extract_ffma(output) - with open(path, 'r+b') as f: - mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) - for segment in segments: - modify_segment(mm, *segment) - mm.close() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse') - parser.add_argument('--so', help='Path to the SO file') - args = parser.parse_args() - - process(args.so) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py deleted file mode 100644 index 7a63bf1c..00000000 --- a/deep_gemm/jit/runtime.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import subprocess -import time -import torch -import cuda.bindings.driver as cbd - -from typing import Any, Dict, Optional, Type -from torch.utils.cpp_extension import CUDA_HOME - - -class Runtime: - def __init__(self, path: str) -> None: - self.path = path - self.lib = None - self.kernel = None - assert self.is_path_valid(self.path) - - @staticmethod - def is_path_valid(path: str) -> bool: - # Exists and is a directory - if not os.path.exists(path) or not os.path.isdir(path): - return False - - # Contains all necessary files - files = ['kernel.cubin'] - return all(os.path.exists(os.path.join(path, file)) for file in files) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - raise NotImplemented - - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - raise NotImplemented - - def __call__(self, **kwargs) -> cbd.CUresult: - # Load CUBIN - if self.kernel is None: - start_time = time.time_ns() - - # Load CUBIN - path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') - result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' - - # Extract the kernel name - # TODO: use `cuda-bindings` API to do this (requires at least 12.8) - command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail'] - check_illegal = lambda line: any([name in line for name in illegal_names]) - kernel_names = [line.split()[-1] for line in result.stdout.splitlines() - if line.startswith('STT_FUNC') and not check_illegal(line)] - assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' - - # Load kernel from the library - result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' - - end_time = time.time_ns() - elapsed_time = (end_time - start_time) / 1e6 - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') - - # noinspection PyArgumentList - return self.launch(self.kernel, kwargs) - - def __del__(self) -> None: - if self.lib is not None: - res = cbd.cuLibraryUnload(self.lib)[0] - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to unload library {self.path}: {res}') - - -class RuntimeCache: - def __init__(self) -> None: - self.cache = {} - - def __setitem__(self, path: str, runtime: Runtime) -> None: - self.cache[path] = runtime - - def get(self, path: str, runtime_cls: Type[Runtime], - name: str = '', kwargs: Dict[str, Any] = None, - force_enable_cache: bool = False) -> Optional[Runtime]: - # In Python runtime - if path in 
self.cache: - return self.cache[path] - - # Already compiled - use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) - if use_cache and os.path.exists(path) and Runtime.is_path_valid(path): - # Print heuristic for the first time - if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))): - simplified_kwargs = dict() - for key, value in kwargs.items() if kwargs is not None else dict().items(): - value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value - value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value - simplified_kwargs[key] = value - print(f'Put kernel {name} with {simplified_kwargs} into runtime cache') - - runtime = runtime_cls(path) - self.cache[path] = runtime - return runtime - return None diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py deleted file mode 100644 index f1fa7bb2..00000000 --- a/deep_gemm/jit_kernels/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .gemm import gemm_fp8_fp8_bf16_nt -from .m_grouped_gemm import ( - m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked -) -from .wgrad_gemm import ( - wgrad_gemm_fp8_fp8_fp32_nt, - k_grouped_wgrad_gemm_fp8_fp8_fp32_nt -) -from .utils import ( - ceil_div, set_num_sms, get_num_sms, - get_col_major_tma_aligned_tensor, - get_m_alignment_for_contiguous_layout -) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py deleted file mode 100644 index c376c38e..00000000 --- a/deep_gemm/jit_kernels/gemm.py +++ /dev/null @@ -1,242 +0,0 @@ -import math -import torch -from functools import lru_cache -from typing import Tuple - -from ..jit import build -from .runtime import ( - FP8GemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout - - -def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int, - require_divisible: bool = False) -> bool: - divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible - return divisible and num_sms % num_tma_multicast == 0 - - -def get_swizzle_mode(block_n: int) -> int: - elem_size = 2 - for mode_bytes in (128, 64, 32): - if (block_n * elem_size) % mode_bytes == 0: - return mode_bytes - return 0 - - -def get_block_n_padding_for_smem_d(block_n: int) -> int: - # NOTES: padding is for solving bank conflicts, but wastes shared memory space - elem_size, requirement = 2, (4, 8) - bank_stride = (block_n * elem_size) // 4 - padding = (requirement[0] - bank_stride) % requirement[1] - return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size - - -def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, - is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: - assert block_k == 128 - - # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d( - block_n) if swizzle_mode == 0 else 0 - - # NOTES: `scales_b` in a total manner or per-stage manner - smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * 
block_k if is_wgrad else 0 - smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 - smem_barrier = (num_stages + int(is_wgrad)) * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += num_stages * smem_scales_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - - # Swizzle and padding are not compatible - assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 - - return smem_size, swizzle_mode, block_n_padding - - -@lru_cache(maxsize=None) -def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, - is_fp32_out: bool = False, is_wgrad: bool = False) -> \ - Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: - if not is_grouped_contiguous: - block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) - else: - block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) - - # Avoid bank conflicts for FP32 output - if is_fp32_out: - block_ns = [x for x in block_ns if x % 16 == 8] - - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) - - # Decide block sizes by waves - best_block_m, best_block_n = None, None - for block_m in block_ms: - # NOTES: the block sizes cannot be too large, so at least one dim less than 128 - for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): - success = False - num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) - if best_block_m is None or best_block_n is None: - success = True - elif num_waves < best_num_waves: - success = True - elif num_waves == best_num_waves: - # Check last wave utilization - util = get_last_wave_util(block_m, block_n) - best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util - if util == best_util: - # Case 1: same `block_m`, smaller `block_n` (wasted) - success |= block_m == best_block_m and block_n < best_block_n - # Case 2: same `block_n`, smaller `block_m` (wasted) - success |= block_n == best_block_n and block_m < best_block_m - # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better - success |= block_m != best_block_m and block_n > best_block_n - best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) - assert best_block_m is not None and best_block_n is not None - - # Always pick the longest one - # NOTES: for double B scales, the best number of stages may be reduced - best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 - stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))) - if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: - # Unrolling both stages and `num_former_iters` will cause large code size - stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) - for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) - if 
best_smem_config[0] <= sm90_capacity: - best_num_stages = num_stages - break - assert best_smem_config is not None - assert best_num_stages is not None - - # Decide the number of TMA multicasts and whether broadcast on A - best_tma_multicast_config = (1, True) - - # Try to multicast on the larger block side first - # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even - is_multicast_legal = { - 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, - } - for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): - if m >= 512 and is_multicast_legal[i]: - best_tma_multicast_config = (2, i == 'A') - break - - # Recompute the minimal number of SMs required - # NOTES: less L2 cache usage and less GPU frequency drop - num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] - assert num_min_sms <= num_sms - - return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config - - -def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor) -> None: - """ - Perform a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. - The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 8. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, - the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m, n]`, representing the result. 
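A minimal usage sketch for this (now removed) API, matching the shapes and dtypes listed above; the sizes are arbitrary examples and the all-ones scaling factors are placeholders rather than real quantization output:

    import torch
    import deep_gemm  # pre-refactor API, as deleted by this diff

    m, n, k = 128, 4096, 7168
    lhs = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
    lhs_scales = torch.ones(m, (k + 127) // 128, device='cuda', dtype=torch.float32)
    rhs = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
    rhs_scales = torch.ones((n + 127) // 128, (k + 127) // 128, device='cuda', dtype=torch.float32)
    out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
    deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
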
- """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - n, k_ = rhs.shape - m_, n_ = out.shape - - # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and k > 0 - assert lhs_scales.shape == (m, ceil_div(k, 128)) - assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - - # LHS scales must be transposed for TMA loads, but not for RHS scales - # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - - # K must be aligned to 128 - aligned_k = ceil_div(k, 128) * 128 - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) - tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) - tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) - - kwargs = { - # Templated arguments - 'GEMM_TYPE': GemmType.Normal, - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': aligned_k, - 'NUM_GROUPS': 1, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - # Runtime arguments - 'SCALES_B': rhs_scales, - 'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device), - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py deleted file mode 100644 index ca2fc79a..00000000 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -from typing import Tuple - -from ..jit import build -from .gemm import get_best_configs -from .runtime import ( - FP8GemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms - - -def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, m_indices: torch.Tensor) -> None: - """ - Perform a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 
1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - On the M axis, inputs are grouped into several batches, of which batch sizes aligned to - `get_m_alignment_for_contiguous_layout()` (128). - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`, - the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. - m_indices: a tensor of shape `[m_sum]` with type `torch.int`. - `m_indices[i]` records the group which the i-th row of the LHS belongs to, - which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. - Values of `m_indices` in every-m-alignment-block must also be the same. - """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - num_groups, n, k_ = rhs.shape - m_, n_ = out.shape - m__ = m_indices.numel() - - # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ - assert lhs_scales.shape == (m, ceil_div(k, 128)) - assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert m_indices.dtype == torch.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and m_indices.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Do nothing if `m` is zero - if m == 0: - return - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - m, n, k, 1, num_sms, is_grouped_contiguous=True) - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedContiguous, rhs, n, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedContiguous, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups) - - kwargs = { - # Templated arguments - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedContiguous, - # Runtime arguments - 'SCALES_B': 
rhs_scales, - 'GROUPED_LAYOUT': m_indices, - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) - - -def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: - """ - Perform a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. - - Requirements: - LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling tensor requires a TMA-aligned transposed format, if your input does not match the requirement, - this function will do a transposing with a set of slow PyTorch operations. - Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch - should be separately transposed. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. - The second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. - out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. - masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute - in the i-th group. - expected_m: a value hint (which is a value on CPU) for the M expectation of each batch, - correctly setting this value may lead to better performance. 
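As an editorial aside, a minimal sketch of how this masked grouped entry point (removed by this PR) was called, following the shapes documented above. The group count, problem sizes, FP8 casts and unit scaling factors below are placeholders of my own rather than a meaningful quantization, and a Hopper GPU with the pre-2.0 `deep_gemm` package is assumed:

```python
import torch
import deep_gemm  # pre-2.0 API; this function is deleted by this PR

num_groups, m_max, n, k = 4, 256, 4096, 7168

# Placeholder FP8 operands and FP32 scaling factors with the documented shapes
lhs = (torch.randn(num_groups, m_max, k, device='cuda').to(torch.float8_e4m3fn),
       torch.ones(num_groups, m_max, k // 128, device='cuda', dtype=torch.float32))
rhs = (torch.randn(num_groups, n, k, device='cuda').to(torch.float8_e4m3fn),
       torch.ones(num_groups, (n + 127) // 128, k // 128, device='cuda', dtype=torch.float32))
out = torch.empty(num_groups, m_max, n, device='cuda', dtype=torch.bfloat16)

# Each group only computes its first `masked_m[i]` rows of the output
masked_m = torch.tensor([192, 256, 128, 224], device='cuda', dtype=torch.int32)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m=192)
```

In the refactored API added later in this diff, the same scenario goes through `deep_gemm.fp8_m_grouped_gemm_nt_masked`, as exercised by the updated `tests/test_core.py`.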
- """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - num_groups, m, k = lhs.shape - num_groups_, n, k_ = rhs.shape - num_groups__, m_, n_ = out.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) - assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - assert lhs.is_contiguous() and rhs.is_contiguous() - assert out.is_contiguous() and masked_m.is_contiguous() - - # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) - assert rhs_scales.is_contiguous() - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - expected_m, n, k, num_groups, num_sms, is_grouped_masked=True) - - # Extra checks for TMA store - if num_groups > 1 and m > block_m: - assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedMasked, lhs, m, k, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedMasked, rhs, n, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedMasked, out, m, n, n, block_m, block_n, num_groups, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups) - - kwargs = { - # Templated arguments - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'SWIZZLE_D_MODE': smem_config[1], - 'BLOCK_N_PADDING': smem_config[2], - 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - 'GEMM_TYPE': GemmType.GroupedMasked, - # Runtime arguments - 'SCALES_B': rhs_scales, - 'GROUPED_LAYOUT': masked_m, - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8GemmRuntime.generate(kwargs) - runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) - runtime(**kwargs) diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py deleted file mode 100644 index e65e85aa..00000000 --- a/deep_gemm/jit_kernels/runtime.py +++ /dev/null @@ -1,318 +0,0 @@ -import ctypes -import os -import enum -import torch -import cuda.bindings.driver as cbd -from typing import Any, Dict, Tuple - -from .utils import get_tma_aligned_size -from ..jit.runtime import Runtime - - -class GemmType(enum.Enum): - Normal = 0 - GroupedContiguous = 1 - GroupedMasked = 2 - - def 
__str__(self) -> str: - return { - 0: 'Normal', - 1: 'GroupedContiguous', - 2: 'GroupedMasked', - }[self.value] - - -tmap_type_map: Dict[Any, str] = { - torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, - torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, - torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, - torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, - torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, - torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, - torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, -} - -swizzle_type_map = { - 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, - 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, - 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, - 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, -} - - -def get_num_math_warpgroups(block_m: int) -> int: - return 1 if block_m == 64 else 2 - - -def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: - assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' - return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads - - -def make_2d_tma_copy_desc(t: torch.Tensor, - gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t, - smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t], - swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap: - tensor_dtype = tmap_type_map[t.dtype] - res, tensor_map = cbd.cuTensorMapEncodeTiled( - tensor_dtype, - 2, - t.data_ptr(), - gmem_dims, - (gmem_outer_stride,), - smem_dims, - (cbd.cuuint32_t(1), cbd.cuuint32_t(1)), - cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle_type, - cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, - cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, - ) - - if res != cbd.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to encode tensor map: {res}') - return tensor_map - - -def make_2d_tma_desc(t: torch.Tensor, - gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int, - smem_inner_dim: int, smem_outer_dim: int, - swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap: - gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim)) - smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim)) - return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type) - - -def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor, - shape_m: int, shape_k: int, m_stride: int, - block_m: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - return make_2d_tma_desc(t, - shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, - block_k, block_m) - - -def 
make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor, - shape_n: int, shape_k: int, n_stride: int, - block_n: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - return make_2d_tma_desc(t, - shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride, - block_k, block_n) - - -def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor, - shape_m: int, shape_n: int, m_stride: int, - block_m: int, block_n: int, - num_groups: int, - swizzle_mode: int) -> cbd.CUtensorMap: - # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` - # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required - return make_2d_tma_desc(t, - shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, - block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m, - swizzle_type_map[swizzle_mode]) - - -def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, - shape_mn: int, shape_k: int, - block_mn: int, block_k: int, - num_groups: int) -> cbd.CUtensorMap: - # Make TMA aligned to 16 bytes - shape_mn = get_tma_aligned_size(shape_mn, t.element_size()) - return make_2d_tma_desc(t, - shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn, - block_mn, 1, - cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) - - -class FP8GemmRuntime(Runtime): - def __init__(self, path: str) -> None: - super().__init__(path) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - code = f''' -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - -#include -#include - -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&fp8_gemm_kernel< - {kwargs['N']}, - {kwargs['K']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['BLOCK_N_PADDING']}, - {kwargs['SWIZZLE_D_MODE']}, - {kwargs['NUM_GROUPS']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, - GemmType::{kwargs['GEMM_TYPE']} - >); -}}; -''' - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Generated FP8 GEMM code:\n{code}') - return code - - # noinspection PyMethodOverriding - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - num_tma_threads = 128 - num_math_threads_per_group = 128 - - result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' - - attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] - attr_val.clusterDim.y = 1 - attr_val.clusterDim.z = 1 - attr = cbd.CUlaunchAttribute() - attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value = attr_val - - config = cbd.CUlaunchConfig() - config.numAttrs = 1 - config.attrs = [attr] - config.gridDimX = kwargs['NUM_SMS'] - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) - config.blockDimY = 1 - config.blockDimZ = 1 - config.sharedMemBytes = kwargs['SMEM_SIZE'] - config.hStream = kwargs['STREAM'] - - arg_values = ( - kwargs['SCALES_B'].data_ptr(), - 
kwargs['GROUPED_LAYOUT'].data_ptr(), - kwargs['M'], - kwargs['TENSOR_MAP_A'], - kwargs['TENSOR_MAP_B'], - kwargs['TENSOR_MAP_SCALES_A'], - kwargs['TENSOR_MAP_D'], - ) - arg_types = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_uint32, - None, - None, - None, - None, - ) - return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) - - -class FP8WGradGemmRuntime(Runtime): - def __init__(self, path: str) -> None: - super().__init__(path) - - @staticmethod - def generate(kwargs: Dict[str, Any]) -> str: - code = f''' -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - -#include -#include - -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&fp8_wgrad_gemm_kernel< - {kwargs['M']}, - {kwargs['N']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_LAST_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'} - >); -}}; -''' - if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Generated FP8 WGrad GEMM code:\n{code}') - return code - - # noinspection PyMethodOverriding - @staticmethod - def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: - num_tma_threads = 128 - num_math_threads_per_group = 128 - - result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] - assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' - - attr_val = cbd.CUlaunchAttributeValue() - attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] - attr_val.clusterDim.y = 1 - attr_val.clusterDim.z = 1 - attr = cbd.CUlaunchAttribute() - attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION - attr.value = attr_val - - config = cbd.CUlaunchConfig() - config.numAttrs = 1 - config.attrs = [attr] - config.gridDimX = kwargs['NUM_SMS'] - config.gridDimY = 1 - config.gridDimZ = 1 - config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) - config.blockDimY = 1 - config.blockDimZ = 1 - config.sharedMemBytes = kwargs['SMEM_SIZE'] - config.hStream = kwargs['STREAM'] - - arg_values = ( - kwargs['K'], - kwargs['TENSOR_MAP_A'], - kwargs['TENSOR_MAP_B'], - kwargs['TENSOR_MAP_SCALES_A'], - kwargs['TENSOR_MAP_SCALES_B'], - kwargs['TENSOR_MAP_D'], - ) - arg_types = ( - ctypes.c_uint32, - None, - None, - None, - None, - None, - ) - return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py deleted file mode 100644 index c6da56b0..00000000 --- a/deep_gemm/jit_kernels/utils.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch - -_num_sms = None - - -def set_num_sms(num_sms: int) -> None: - """ - Set the maximum SM count for all GEMM kernels to use. - - Arguments: - num_sms: the desired maximum SM count for all GEMM kernels to use. - """ - global _num_sms - assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count - _num_sms = num_sms - - -def get_num_sms() -> int: - """ - Get the current maximum limit of SM count for all GEMM kernels to use. - If the count is never specified, the function will return the number of device SMs. - - Returns: - Current maximum limit of SM count for all GEMM kernels to use. 
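A small usage sketch for the two SM-count helpers above, which are removed here together with the rest of `deep_gemm/jit_kernels/utils.py`. It assumes the pre-2.0 package re-exports them at the top level; the cap of 108 SMs is only an example value:

```python
import torch
import deep_gemm  # pre-2.0 layout assumed

# Defaults to the full SM count of the current device on first query
print(deep_gemm.get_num_sms())

# Optionally cap the SM count used by all GEMM kernels (must not exceed the device limit)
device_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
deep_gemm.set_num_sms(min(108, device_sms))
assert deep_gemm.get_num_sms() == min(108, device_sms)
```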
- """ - global _num_sms - if _num_sms is None: - _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count - return _num_sms - - -def ceil_div(x: int, y: int) -> int: - """ - Perform ceiling division of two integers. - - Args: - x: the dividend. - y: the divisor. - - Returns: - The result of the ceiling division. - """ - return (x + y - 1) // y - - -def get_m_alignment_for_contiguous_layout(): - """ - When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. - Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well - with GEMM block shape. - - Returns: - Group-level alignment requirement for grouped contiguous layout, which is always 128. - """ - return 128 - - -def get_tma_aligned_size(x: int, element_size: int) -> int: - """ - Global memory address of TMA must be 16-byte aligned. - Since we use column-major layout for the LHS scaling tensor, - the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. - - Arguments: - x: original M-axis shape of the LHS scaling tensor. - element_size: element size of the LHS scaling tensor. - - Returns: - M-axis shape of the LHS scaling tensor after padding. - """ - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return ceil_div(x, alignment) * alignment - - -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: - """ - Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. - If the input tensor is already column-major layout and 16-byte aligned along the M axis - (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. - - Arguments: - x: usually the LHS scaling tensor in GEMM. - - Returns: - The LHS scaling tensor of TMA-aligned transposed format. 
- """ - # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - assert x.dim() in (2, 3) - remove_dim = False - m, n = x.shape[-2], x.shape[-1] - aligned_m = get_tma_aligned_size(m, x.element_size()) - if x.dim() == 2: - if x.stride(0) == 1 and x.stride(1) == aligned_m: - return x - x, remove_dim = x.unsqueeze(0), True - - b = x.shape[0] - - # The last kernel gives a column-major TMA aligned layout - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: - return x.squeeze(0) if remove_dim else x - - # Normal layout requires transposing - aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) - aligned_x[:, :m, :] = x - aligned_x = aligned_x[:, :m, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x diff --git a/deep_gemm/jit_kernels/wgrad_gemm.py b/deep_gemm/jit_kernels/wgrad_gemm.py deleted file mode 100644 index 00b8cd10..00000000 --- a/deep_gemm/jit_kernels/wgrad_gemm.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch -from typing import List, Tuple - -from ..jit import build -from .runtime import ( - FP8WGradGemmRuntime, GemmType, - make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .gemm import get_best_configs -from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size - - -def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor): - """ - Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. - Results will be accumulated into the output tensor. - - Requirements: - LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1. - The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4. - RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format. - If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations. - - Arguments: - lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, - the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. - rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`, - the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`. - out: the FP32 output tensor of shape `[m, n]`, which will be accumulated. 
- """ - lhs, lhs_scales = lhs - rhs, rhs_scales = rhs - m, k = lhs.shape - n, k_ = rhs.shape - m_, n_ = out.shape - - # Type and shape checks - assert m == m_ and n == n_ and k == k_ - assert n > 0 and m > 0 - assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m) - assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n) - assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 - assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 - assert out.dtype == torch.float - assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1 - - # LHS and RHS scales must be transposed for TMA load - # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels - def get_valid_scales(scales: torch.Tensor, mn: int): - if scales.shape == (ceil_div(k, 128), mn): - # For k-grouped GEMMs - scales = scales.permute(1, 0) - assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn - else: - scales = get_col_major_tma_aligned_tensor(scales) - return scales - - lhs_scales = get_valid_scales(lhs_scales, m) - rhs_scales = get_valid_scales(rhs_scales, n) - - # Do nothing if `k` is zero - if k == 0: - return - - # K must be aligned to 128 - aligned_k = ceil_div(k, 128) * 128 - - # Auto-tuning with compilation - num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( - m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True) - num_last_stages = ceil_div(k, 128) % num_stages - block_k = 128 - num_tma_threads = 128 - num_math_threads_per_group = 128 - - tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1) - tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1) - tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1]) - tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1) - tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1) - - kwargs = { - # Templated arguments - 'GEMM_TYPE': GemmType.Normal, - 'NUM_TMA_THREADS': num_tma_threads, - 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': aligned_k, - 'NUM_GROUPS': 1, - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'NUM_STAGES': num_stages, - 'NUM_LAST_STAGES': num_last_stages, - 'NUM_TMA_MULTICAST': tma_multicast_config[0], - 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], - # Runtime arguments - 'NUM_SMS': num_sms, - 'SMEM_SIZE': smem_config[0], - 'TENSOR_MAP_A': tensor_map_a, - 'TENSOR_MAP_B': tensor_map_b, - 'TENSOR_MAP_SCALES_A': tensor_map_scales_a, - 'TENSOR_MAP_SCALES_B': tensor_map_scales_b, - 'TENSOR_MAP_D': tensor_map_d, - 'STREAM': torch.cuda.current_stream().cuda_stream, - 'DEVICE_INDEX': out.device.index - } - - # Generate, build and run the kernel - code = FP8WGradGemmRuntime.generate(kwargs) - runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs) - runtime(**kwargs) - - -def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, - batch_sizes: List[int]): - """ - Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling. 
- Results will be accumulated into the output tensor. - - Requirements: - This function handles multiple batches with varying k-dimensions, processing each batch sequentially. - Each batch's LHS, RHS, and output tensors must be contiguous. - The RHS and RHS scaling factors are required to be transposed. - The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format. - - Arguments: - lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data, - and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows. - The second element is an FP32 scaling tensor for LHS with shape `[⌈k / 128⌉ for k in batch_sizes), m]`, - representing the per-128-channel scaling factors. - rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data, - and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows. - The second element is an FP32 scaling tensor for RHS with shape `[⌈k / 128⌉ for k in batch_sizes), n]`, - representing the per-128-channel scaling factors. - out: The FP32 output tensor of shape [num_batches, m, n], which will be accumulated. - batch_sizes: A list of integers specifying the k-dimension for each batch. - """ - lhs, lhs_scales = lhs[0].view(-1), lhs[1] - rhs, rhs_scales = rhs[0].view(-1), rhs[1] - num_batches, m, n = out.shape - - lhs_offset, rhs_offset, scales_offset = 0, 0, 0 - - for i in range(num_batches): - k = batch_sizes[i] - lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k) - rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k) - lhs_scales_slice = lhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] - rhs_scales_slice = rhs_scales[scales_offset:scales_offset + ceil_div(k, 128)] - wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i]) - - lhs_offset += m * k - rhs_offset += n * k - scales_offset += ceil_div(k, 128) diff --git a/deep_gemm/testing/__init__.py b/deep_gemm/testing/__init__.py new file mode 100644 index 00000000..2537dbf1 --- /dev/null +++ b/deep_gemm/testing/__init__.py @@ -0,0 +1,3 @@ +from . 
import bench, numeric +from .bench import * +from .numeric import * diff --git a/deep_gemm/utils.py b/deep_gemm/testing/bench.py similarity index 78% rename from deep_gemm/utils.py rename to deep_gemm/testing/bench.py index 55a9affa..7e77866d 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/testing/bench.py @@ -1,8 +1,6 @@ import os import sys -import time import torch -import torch.distributed as dist def bench(fn, num_warmups: int = 5, num_tests: int = 10, @@ -31,7 +29,7 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10, end_event.record() torch.cuda.synchronize() - return start_event.elapsed_time(end_event) / num_tests + return start_event.elapsed_time(end_event) / num_tests / 1e3 class empty_suppress: @@ -77,8 +75,9 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, - trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, with_multiple_kernels: bool = False): # Conflict with Nsight Systems using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) @@ -96,12 +95,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() with profiler: for i in range(2): - # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead - if barrier_comm_profiling: - lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') - lhs @ rhs - dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() @@ -116,7 +109,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: # Parse the profiling table assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tupled = isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) @@ -145,21 +138,4 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: break kernel_times.append(total_time / total_num) - return tuple(kernel_times) if is_tupled else kernel_times[0] - - -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def count_bytes(tensors): - total = 0 - for t in tensors: - if isinstance(t, tuple): - total += count_bytes(t) - else: - total += t.numel() * t.element_size() - return total + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py new file mode 100644 index 00000000..d06a03b9 --- /dev/null +++ b/deep_gemm/testing/numeric.py @@ -0,0 +1,19 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def 
count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/deep_gemm/utils/__init__.py b/deep_gemm/utils/__init__.py new file mode 100644 index 00000000..e8f859a2 --- /dev/null +++ b/deep_gemm/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/deep_gemm/utils/layout.py b/deep_gemm/utils/layout.py new file mode 100644 index 00000000..ac8c070b --- /dev/null +++ b/deep_gemm/utils/layout.py @@ -0,0 +1,11 @@ +from deep_gemm_cpp import ( + get_tma_aligned_size, + get_mk_alignment_for_contiguous_layout, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor +) + +# Some alias +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py new file mode 100644 index 00000000..884a7112 --- /dev/null +++ b/deep_gemm/utils/math.py @@ -0,0 +1,48 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % 128 == 0 + m, n = x.shape + x_view = x.view(-1, 128, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) diff --git a/develop.sh b/develop.sh new file mode 100755 index 00000000..58798613 --- /dev/null +++ b/develop.sh @@ -0,0 +1,25 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Link CUTLASS includes +ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include +ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include + +# Remove old dist file, build, and build +rm -rf build dist +rm -rf *.egg-info +python setup.py build + +# Find the .so file in build directory and 
create symlink in current directory +so_file=$(find build -name "*.so" -type f | head -n 1) +if [ -n "$so_file" ]; then + ln -sf "$so_file" . +else + echo "Error: No SO file found in build directory" >&2 + exit 1 +fi + +# Open users' original directory +cd "$original_dir" diff --git a/figures/design.png b/figures/design.png deleted file mode 100644 index b3761d60..00000000 Binary files a/figures/design.png and /dev/null differ diff --git a/indexing/main.cu b/indexing/main.cu deleted file mode 100644 index 5b15256a..00000000 --- a/indexing/main.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "deep_gemm/fp8_gemm.cuh" -#include "deep_gemm/fp8_wgrad_gemm.cuh" - -using namespace deep_gemm; - -int main() { - return 0; -} diff --git a/install.sh b/install.sh new file mode 100755 index 00000000..6b675d6c --- /dev/null +++ b/install.sh @@ -0,0 +1,13 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel +pip install dist/*.whl + +# Open users' original directory +cd "$original_dir" diff --git a/setup.py b/setup.py index b39efd03..72db0500 100644 --- a/setup.py +++ b/setup.py @@ -2,34 +2,28 @@ import setuptools import shutil import subprocess +from setuptools import find_packages from setuptools.command.build_py import build_py -from setuptools.command.develop import develop +from torch.utils.cpp_extension import CppExtension, CUDA_HOME current_dir = os.path.dirname(os.path.realpath(__file__)) -jit_include_dirs = ('deep_gemm/include/deep_gemm', ) -third_party_include_dirs = ( +cxx_flags = ['-std=c++20', '-O3', '-fPIC', '-Wno-psabi'] +sources = ['csrc/python_api.cpp'] +build_include_dirs = [ + f'{CUDA_HOME}/include', + 'deep_gemm/include', + 'third-party/cutlass/include', + 'third-party/fmt/include', +] +build_libraries = ['cuda', 'cudart'] +build_library_dirs = [ + f'{CUDA_HOME}/lib64', + f'{CUDA_HOME}/lib64/stub' +] +third_party_include_dirs = [ 'third-party/cutlass/include/cute', 'third-party/cutlass/include/cutlass', -) - - -class PostDevelopCommand(develop): - def run(self): - develop.run(self) - self.make_jit_include_symlinks() - - @staticmethod - def make_jit_include_symlinks(): - # Make symbolic links of third-party include directories - for d in third_party_include_dirs: - dirname = d.split('/')[-1] - src_dir = f'{current_dir}/{d}' - dst_dir = f'{current_dir}/deep_gemm/include/{dirname}' - assert os.path.exists(src_dir) - if os.path.exists(dst_dir): - assert os.path.islink(dst_dir) - os.unlink(dst_dir) - os.symlink(src_dir, dst_dir, target_is_directory=True) +] class CustomBuildPy(build_py): @@ -37,9 +31,21 @@ def run(self): # First, prepare the include directories self.prepare_includes() - # Then run the regular build + # Second, make clusters' cache setting default into `envs.py` + self.generate_default_envs() + + # Finally, run the regular build build_py.run(self) + def generate_default_envs(self): + code = '# Pre-installed environment variables\n' + code += 'persistent_envs = dict()\n' + for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_DISABLE_SHORTCUT_CACHE'): + code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else '' + + with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f: + f.write(code) + def prepare_includes(self): # Create temporary build directory instead of modifying package directory build_include_dir = 
os.path.join(self.build_lib, 'deep_gemm/include') @@ -67,19 +73,28 @@ def prepare_includes(self): except: revision = '' + # noinspection PyTypeChecker setuptools.setup( name='deep_gemm', - version='1.0.0' + revision, - packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'], + version='2.0.0' + revision, + packages=find_packages('.'), package_data={ 'deep_gemm': [ - 'include/deep_gemm/*', + 'include/deep_gemm/**/*', 'include/cute/**/*', 'include/cutlass/**/*', ] }, + ext_modules=[ + CppExtension(name='deep_gemm_cpp', + sources=sources, + include_dirs=build_include_dirs, + libraries=build_libraries, + library_dirs=build_library_dirs, + extra_compile_args=cxx_flags) + ], + zip_safe=False, cmdclass={ - 'develop': PostDevelopCommand, 'build_py': CustomBuildPy, }, ) diff --git a/tests/generators.py b/tests/generators.py new file mode 100644 index 00000000..a0597ad0 --- /dev/null +++ b/tests/generators.py @@ -0,0 +1,212 @@ +import enum +import random +import torch +from typing import Generator, Tuple, List + +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, + get_mk_alignment_for_contiguous_layout +) + + +class KernelType(enum.Enum): + # For SM100 GEMMs + Kernel1D1D = 0 + Kernel1D2D = 1 + + def is_1d1d(self): + return self.value == 0 + + def is_1d2d(self): + return self.value == 1 + + +class MajorTypeAB(enum.Enum): + KMajor = 0 + MNMajor = 1 + + def is_k_major(self): + return self.value == 0 + + def is_mn_major(self): + return self.value == 1 + + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def get_ue8m0_usage(kernel_type: KernelType) -> bool: + if get_arch_major() == 9: + return False + return kernel_type.is_1d1d() + + +def get_kernel_types() -> tuple: + return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D) + + +def get_out_dtype() -> tuple: + return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float) + + +def get_major_ab(freeze_a: bool) -> tuple: + if get_arch_major() == 9: + return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), ) + if freeze_a: + return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor) + return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \ + (MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + + +def enumerate_normal() -> Generator: + for kernel_type in get_kernel_types(): + for m in (128, 4096): + for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]: + for major_a, major_b in get_major_ab(False): + for out_dtype in get_out_dtype(): + for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True): + yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype + + +def enumerate_m_grouped_contiguous() -> Generator: + for kernel_type in get_kernel_types(): + for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): + for major_a, major_b in get_major_ab(True): + yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b + + +def enumerate_m_grouped_masked() -> Generator: + max_m = 4096 + for kernel_type in get_kernel_types(): + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for n, k in ((4096, 7168), (7168, 2048), ): + yield kernel_type, num_groups, max_m, 
m, n, k + + +def enumerate_k_grouped_contiguous(): + # TODO: support SM90 kernels + if get_arch_major() == 9: + return [] + + # Must with FP32 accumulation and 1D1D kernels + for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 + ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 + (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] + yield num_groups, m, n, ks, expected_k_per_group + + +def enumerate_sf_layout(): + for with_transpose in (True, False): + for mn in (4096, 4097, 8192): + for k in (128, 7168, 7296): + for num_groups in (1, 2, 4) if with_transpose else (1, ): + if num_groups > 1 and (mn * ceil_div(k, 128)) % 4 != 0: + continue + if not with_transpose and mn % 4 != 0: + continue + yield mn, k, with_transpose, num_groups + + +def enumerate_k_grouped_sf_layout(): + alignment = get_mk_alignment_for_contiguous_layout() + assert alignment % 128 == 0 + for mn in (4096, 7168): + for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)): + ks = [align(int(random.uniform(0.7, 1.3) * avg_k), alignment) for _ in range(num_groups)] + yield mn, ks, num_groups + + +def generate_normal(m: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + accumulate: bool, out_dtype: torch.dtype, + use_ue8m0: bool): + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ + torch.empty((m, n), device='cuda', dtype=out_dtype) + c = d if accumulate else None + ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + + a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) + a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) + b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) + return a_fp8, b_fp8, c, d, ref_d + + +def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool) -> \ + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] + m = sum(aligned_ms) + + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): + actual_end = start + actual_m + aligned_end = start + aligned_m + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = -1 + ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t() + start = aligned_end + ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) + + assert major_a.is_k_major() + a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, 
ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].mT.contiguous().mT, b_fp8[1]) + return m, a_fp8, b_fp8, m_indices, d, ref_d + + +def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) + b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + + return a_fp8, b_fp8, masked_m, d, ref_d + + +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool): + assert get_mk_alignment_for_contiguous_layout() % 128 == 0 + k = sum(ks) + + a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16) + b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16) + c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32 + d = c + ref_d = torch.empty_like(c) + + start = 0 + for i, group_k in enumerate(ks): + end = start + group_k + ref_d[i] = c[i] + (a[start:end].T @ b[start:end]) + start = end + + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) + return k, a_fp8, b_fp8, c, d, ref_d diff --git a/tests/test_core.py b/tests/test_core.py index 3b88539c..d9ddc75d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,297 +1,161 @@ -# PyTorch has its own NVRTC, which may have a lower version than the system -# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch -import cuda.bindings.nvrtc as nvrtc -print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') - +import copy import random +import time import torch -from typing import List, Tuple import deep_gemm -from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor -from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout - - -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) - return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) - - -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - 
assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - - -def construct(m: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = x @ y.t() - - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - alignment = get_m_alignment_for_contiguous_layout() - group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] - m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) - - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - m_indices = torch.empty(m, device='cuda', dtype=torch.int32) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) - - start = 0 - for i, group_m in enumerate(group_ms): - actual_end = start + group_m - aligned_end = start + ceil_div(group_m, alignment) * alignment - m_indices[start:actual_end] = i - m_indices[actual_end:aligned_end] = -1 - ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() - start = aligned_end - ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) - - assert m % 4 == 0, f'TMA alignment error: {m}' - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - return m, x_fp8, y_fp8, m_indices, out, ref_out - - -def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) - ref_out = torch.einsum('gmk,gnk->gmn', x, y) - - assert max_m % 4 == 0, f'TMA alignment error: {max_m}' - x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - 
x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
-        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
-
-    # Transpose earlier so that the testing will not trigger transposing kernels
-    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
-
-    # Construct mask
-    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
-    for j in range(num_groups):
-        masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
-    assert masked_m.amax().item() <= max_m
-    return x_fp8, y_fp8, masked_m, out, ref_out
-
-
-def construct_wgrad(m: int, k: int, n: int) -> \
-        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
-    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
-    residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
-    out = residual.clone()
-    ref_out = residual + (x.float() @ y.float().t())
-
-    x_fp8 = per_token_cast_to_fp8(x)
-    y_fp8 = per_token_cast_to_fp8(y)
-
-    # NOTES: please do inplace add on the `out` later
-    return x_fp8, y_fp8, residual, out, ref_out
-
-
-def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
-        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
-    num_groups, total_k = len(k_sizes), sum(k_sizes)
-
-    x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
-    y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
-    out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
-    ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
-
-    # Fill tensors with data and compute reference output
-    x_offset, y_offset = 0, 0
-    for idx, k in enumerate(k_sizes):
-        x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-        y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
-
-        x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten())
-        y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten())
-        ref_out[idx] = x_chunk.float() @ y_chunk.float().t()
-
-        x_offset += m * k
-        y_offset += n * k
-
-    x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
-    y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)
-
-    total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
-    x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
-    y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)
-
-    # Cast to FP8 and prepare scale factors
-    x_offset, y_offset, scale_offset = 0, 0, 0
-    for k in k_sizes:
-        x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k))
-        y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k))
-
-        x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten())
-        y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten())
-
-        num_scales = ceil_div(k, 128)
-        x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T)
-        y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T)
-
-        x_offset += m * k
-        y_offset += n * k
-        scale_offset += num_scales
-
-    return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
+from deep_gemm.testing import (
+    bench, bench_kineto,
+    calc_diff, count_bytes
+)
+
+from generators import (
+    KernelType, get_ue8m0_usage,
+    enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
+    generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
+)
 def test_gemm() -> None:
     print('Testing GEMM:')
-    for m in (64, 128, 4096):
-        for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
-            x_fp8, y_fp8, out, ref_out = construct(m, k, n)
-            deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
-            diff = calc_diff(out, ref_out)
-            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
-
-            # noinspection PyShadowingNames
-            def test_func():
-                deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
-
-            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-            print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
-                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+    for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal():
+        major_opt = 'N' if major_a.is_k_major() else 'T'
+        major_opt += 'T' if major_b.is_k_major() else 'N'
+        out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
+        acc_opt = f'acc={int(accumulate)}'
+        kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
+        use_ue8m0 = get_ue8m0_usage(kernel_type)
+        disable_ue8m0_cast = not use_ue8m0
+
+        for test_alias in (False, True):
+            a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
+            func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
+            if test_alias:
+                a = a if major_a.is_k_major() else (a[0].T, a[1].T)
+                b = b if major_b.is_k_major() else (b[0].T, b[1].T)
+                assert a[0].is_contiguous() and b[0].is_contiguous()
+            getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
+            diff = calc_diff(d, ref_d)
+            assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
+                                  f'{diff:.5f}, alias={test_alias}')
+        a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
+
+        # Test launch overhead
+        launch_start_t = time.time_ns()
+        deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
+        launch_end_t = time.time_ns()
+        torch.cuda.synchronize()
+
+        # noinspection PyShadowingNames
+        def test_func():
+            deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
+
+        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
+        print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}):'
+              f' launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
+              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
+              f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
     print()
 def test_m_grouped_gemm_contiguous() -> None:
-    print('Testing grouped contiguous GEMM:')
-
-    for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168),
-                                                   (8, 4096, 7168, 4096), (8, 4096, 2048, 7168),
-                                                   (32, 256, 7168, 4096), (32, 256, 2048, 7168)):
-        # NOTES: we should mask the unfilled part before calculating difference
-        m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
-        out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
-        diff = calc_diff(out, ref_out)
-        assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+    print('Testing m-grouped contiguous GEMM:')
+
+    for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous():
+        major_opt = 'N' if major_a.is_k_major() else 'T'
+        major_opt += 'T' if major_b.is_k_major() else 'N'
+        kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
+        use_ue8m0 = get_ue8m0_usage(kernel_type)
+        disable_ue8m0_cast = not use_ue8m0
+
+        for test_alias in (False, True):
+            m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
+            func_name = f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
+            if test_alias:
+                assert major_a.is_k_major()
+                b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
+                assert a[0].is_contiguous() and b[0].is_contiguous()
+            getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
+            d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
+            diff = calc_diff(d, ref_d)
+            assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'
+        m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
         # noinspection PyShadowingNames
         def test_func():
-            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
+            deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        valid_m = (m_indices != -1).sum().item()
-        print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-              f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
-              f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
+        print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): '
+              f'{t * 1e6:4.0f} us | '
+              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
+              f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
     print()
 def test_m_grouped_gemm_masked() -> None:
-    print('Testing grouped masked GEMM:')
-
-    for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
-        for k, n in ((7168, 4096), (2048, 7168), ):
-            # Test correctness
-            for i in range(10):
-                x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n)
-                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
-                for j in range(num_groups):
-                    diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
-                    assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
-
-            # noinspection PyShadowingNames
-            def test_func():
-                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
-
-            # Test performance with fixed shapes
-            # noinspection PyUnboundLocalVariable
-            valid_m = masked_m.sum().item()
-            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-            print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
-                  f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
-    print()
-
+    print('Testing m-grouped masked GEMM:')
-def test_wgrad_gemm():
-    print('Testing weight gradient GEMM:')
+    # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
+    for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
+        kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
+        use_ue8m0 = get_ue8m0_usage(kernel_type)
+        disable_ue8m0_cast = not use_ue8m0
-    for k in (4096, 8192):
-        for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
-            # Test correctness
-            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
-            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
-            diff = calc_diff(out, ref_out)
-            assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+        # Test correctness
+        for i in range(10):
+            a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
+            deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
+            for j in range(num_groups):
+                diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
+                assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
-            # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
-            x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
+        # Construct full cases
+        a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
-            # noinspection PyShadowingNames
-            def test_func():
-                deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
+        # noinspection PyShadowingNames
+        def test_func():
+            deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
-            t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
-            print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
-                  f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
+        # Test performance with fixed shapes
+        valid_m = masked_m.sum().item()
+        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
+        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
+              f'{t * 1e6:4.0f} us | '
+              f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
+              f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
     print()
-def test_k_grouped_wgrad_gemm():
-    print('Testing grouped weight gradient GEMM:')
-
-    for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
-        for m, n in ((7168, 4096), (2048, 7168)):
-            # Vary k sizes around base_k
-            k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
-            k_sizes.append(base_k * num_groups - sum(k_sizes))
-
-            # Test correctness
-            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
-            deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
-
-            for idx in range(num_groups):
-                diff = calc_diff(out[idx], ref_out[idx])
-                assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
-
-            # Construct new tensors to avoid L2 cache acceleration
-            x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
-            total_k = sum(k_sizes)
-
-            def test_func():
-                deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
-
-            t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
-            print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
-                  f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
-                  f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
+def test_k_grouped_gemm_contiguous() -> None:
+    print('Testing k-grouped contiguous GEMM:')
+
+    for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
+        use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
+
+        for test_empty_groups in (False, True):
+            new_ks = copy.deepcopy(ks)
+            if test_empty_groups:
+                new_ks[random.randint(0, num_groups - 1)] = 0
+            k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0)
+            new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
+            deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c)
+            diff = calc_diff(d, ref_d)
+            assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}'
+
+        # Test performance
+        k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0)
+        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
+
+        # noinspection PyShadowingNames
+        def test_func():
+            deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c)
+
+        t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
+        print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
+              f'{t * 1e6:4.0f} us | '
+              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
+              f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
     print()
@@ -307,6 +171,4 @@ def test_func():
     test_gemm()
     test_m_grouped_gemm_contiguous()
    test_m_grouped_gemm_masked()
-
-    test_wgrad_gemm()
-    test_k_grouped_wgrad_gemm()
+    test_k_grouped_gemm_contiguous()
diff --git a/tests/test_jit.py b/tests/test_jit.py
deleted file mode 100644
index 26b7b36c..00000000
--- a/tests/test_jit.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import ctypes
-import os
-import torch
-import cuda.bindings.driver as cbd
-from typing import Any, Dict
-
-from deep_gemm import jit
-
-# Essential debugging staffs
-os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
-os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
-
-
-class VectorAddRuntime(jit.Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path)
-
-    @staticmethod
-    def generate(kwargs: Dict[str, Any]) -> str:
-        return f"""
-#ifdef __CUDACC_RTC__
-#include
-#else
-#include
-#endif
-
-#include
-#include
-
-template
-__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
-    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
-    if (i < n) {{
-        c[i] = a[i] + b[i];
-    }}
-}}
-
-static void __instantiate_kernel() {{
-    auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>);
-}}
-"""
-
-    # noinspection PyShadowingNames,PyMethodOverriding
-    @staticmethod
-    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
-        assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
-        assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
-        assert kwargs['A'].dim() == 1
-
-        config = cbd.CUlaunchConfig()
-        config.gridDimX = (kwargs['A'].numel() + 127) // 128
-        config.gridDimY = 1
-        config.gridDimZ = 1
-        config.blockDimX = 128
-        config.blockDimY = 1
-        config.blockDimZ = 1
-        config.hStream = kwargs['STREAM']
-
-        arg_values = (
-            kwargs['A'].data_ptr(),
-            kwargs['B'].data_ptr(),
-            kwargs['C'].data_ptr(),
-            kwargs['A'].numel(),
-        )
-        arg_types = (
-            ctypes.c_void_p,
-            ctypes.c_void_p,
-            ctypes.c_void_p,
-            ctypes.c_uint32,
-        )
-
-        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
-
-
-if __name__ == '__main__':
-    print('Generated code:')
-    kwargs = {'T': 'float'}
-    code = VectorAddRuntime.generate(kwargs)
-    print(code)
-    print()
-
-    for compiler_name in ('NVCC', 'NVRTC'):
-        # Get compiler
-        compiler_cls = getattr(jit, f'{compiler_name}Compiler')
-        print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
-
-        # Build
-        print('Building ...')
-        func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs)
-
-        # Run and check
-        a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
-        b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
-        c = torch.empty_like(a)
-        ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
-        assert ret == cbd.CUresult.CUDA_SUCCESS, ret
-        torch.testing.assert_close(c, a + b)
-        print(f'JIT test for {compiler_name} passed\n')
diff --git a/tests/test_layout.py b/tests/test_layout.py
new file mode 100644
index 00000000..6cad6426
--- /dev/null
+++ b/tests/test_layout.py
@@ -0,0 +1,104 @@
+import time
+import torch
+import random
+from deep_gemm.testing import bench_kineto, count_bytes
+from deep_gemm.utils import (
+    align, ceil_div,
+    per_token_cast_to_fp8, per_channel_cast_to_fp8,
+    get_tma_aligned_size,
+    get_mn_major_tma_aligned_packed_ue8m0_tensor,
+    get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
+)
+
+from generators import (
+    enumerate_sf_layout,
+    enumerate_k_grouped_sf_layout
+)
+
+
+def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> torch.Tensor:
+    assert x.dtype == torch.float and x.dim() in (2, 3)
+
+    # First, convert into UE8M0 `uint8_t`
+    ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
+
+    # Second, make padded packed tensors
+    mn, k = x.shape[-2], x.shape[-1]
+    remove_dim = False
+    if x.dim() == 2:
+        x, remove_dim = x.unsqueeze(0), True
+    b = x.shape[0]
+    aligned_mn = get_tma_aligned_size(mn, 4)
+    aligned_k = align(k, 4)
+    padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8)
+    padded[:, :mn, :k] = ue8m0_tensor
+    padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4)
+
+    # Finally, transpose
+    transposed = torch.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int).mT
+    transposed[:, :, :] = padded
+    aligned_x = transposed[:, :mn, :]
+    return aligned_x.squeeze(0) if remove_dim else aligned_x
+
+
+def test_sf_layout_kernels() -> None:
+    print('Testing SF layout kernels:')
+    for mn, k, with_transpose, num_groups in enumerate_sf_layout():
+        x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda')
+        x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True)
+        fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1)
+        fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+        # Correctness
+        packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
+        ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
+        assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
+        assert packed_sf.shape == ref_packed_sf.shape
+        assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
+
+        # Test launch overhead
+        launch_start_t = time.time_ns()
+        get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
+        launch_end_t = time.time_ns()
+
+        # Performance
+        t = bench_kineto(lambda: get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf), 'pack_fp32_into_ue8m0')
+        print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): '
+              f'launch {(launch_end_t - launch_start_t) / 1e3:3.0f} us | {t * 1e6:4.0f} us | '
+              f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
+def test_k_grouped_sf_layout_kernels() -> None:
+    print('Testing k-grouped SF layout kernels:')
+    for mn, ks, num_groups in enumerate_k_grouped_sf_layout():
+        sf_ks = [k // 128 for k in ks]
+        packed_sf_ks = [ceil_div(k, 512) for k in ks]
+        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
+        x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda')
+        x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True)
+
+        # Correctness
+        packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks)
+        split_packed_sf = packed_sf.split(packed_sf_ks)
+        split_fp32_sf = fp32_sf.split(sf_ks)
+        for i in range(num_groups):
+            ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(split_fp32_sf[i].T).T
+            assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}'
+
+        # Performance
+        t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks), 'pack_fp32_into_ue8m0')
+        print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}):'
+              f'{t * 1e6:4.0f} us | '
+              f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 1e9 / t:4.0f} GB/s')
+    print()
+
+
+if __name__ == '__main__':
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.manual_seed(1)
+    random.seed(1)
+
+    test_sf_layout_kernels()
+    test_k_grouped_sf_layout_kernels()
diff --git a/third-party/cutlass b/third-party/cutlass
index eefa1713..b244379d 160000
--- a/third-party/cutlass
+++ b/third-party/cutlass
@@ -1 +1 @@
-Subproject commit eefa171318b79cbe2e78514d4cce5cd0fe919d0c
+Subproject commit b244379d9b15574e07b73b814b88bd2233f0b3ce
diff --git a/third-party/fmt b/third-party/fmt
new file mode 160000
index 00000000..553ec11e
--- /dev/null
+++ b/third-party/fmt
@@ -0,0 +1 @@
+Subproject commit 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
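
Note on the UE8M0 packing exercised by the torch reference in tests/test_layout.py above: for a positive FP32 scale factor that is an exact power of two (which is what the use_ue8m0=True casts are expected to produce), the UE8M0 encoding is simply the 8-bit biased exponent, i.e. bits 23..30 of the FP32 word, so `(x.view(torch.int) >> 23).to(torch.uint8)` recovers it directly; groups of four such bytes are then packed into one int32 before the MN-major, TMA-aligned transpose. A minimal standalone illustration of the bit trick follows (not DeepGEMM code; the example values are made up):

import torch

# Positive power-of-two scale factors, as the `use_ue8m0=True` casts are expected to produce.
scales = torch.tensor([1.0, 2.0, 0.5, 4096.0], dtype=torch.float)

# Bits 23..30 of an FP32 word hold the biased exponent, which is exactly the UE8M0 byte.
ue8m0 = (scales.view(torch.int) >> 23).to(torch.uint8)
print(ue8m0)  # tensor([127, 128, 126, 139], dtype=torch.uint8)

# The reference implementation then packs every 4 UE8M0 bytes into one int32
# before transposing into the MN-major, TMA-aligned layout.
packed = ue8m0.view(torch.int32)
print(packed.shape)  # torch.Size([1]) -- one int32 per 4 scale factors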
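
The updated tests obtain their (FP8 tensor, FP32 scale-factor) pairs from helpers such as `per_token_cast_to_fp8` and `per_channel_cast_to_fp8` in `deep_gemm.utils`, with one scale per 128 elements along k (the k-grouped SF test above computes `k // 128` scales per group). Those helpers are not part of this diff; the sketch below only illustrates the general idea of a 1x128 fine-grained cast, and the group size of 128, the E4M3 clamp at 448, and the returned scale layout are assumptions rather than the library's exact behavior:

import torch

def per_token_cast_to_fp8_sketch(x: torch.Tensor):
    # Rough sketch of a fine-grained FP8 cast: one FP32 scale per 1x128 group
    # along the last dimension, chosen so each group fits into E4M3's max
    # normal value (448). The real deep_gemm.utils implementation may differ
    # (padding, UE8M0 rounding of scales, TMA-aligned scale layout, ...).
    m, k = x.shape
    assert k % 128 == 0
    groups = x.view(m, k // 128, 128).float()
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scales = amax / 448.0
    x_fp8 = (groups / scales).to(torch.float8_e4m3fn).view(m, k)
    return x_fp8, scales.squeeze(-1)

x = torch.randn((128, 7168), device='cuda', dtype=torch.bfloat16)
x_fp8, sf = per_token_cast_to_fp8_sketch(x)
print(x_fp8.shape, sf.shape)  # torch.Size([128, 7168]) torch.Size([128, 56])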