8000 Add more GPU architectures support by RayWang96 · Pull Request #112 · deepseek-ai/DeepGEMM · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
url = git@github.com:NVIDIA/cutlass.git
[submodule "third-party/fmt"]
path = third-party/fmt
url = git@github.com:fmtlib/fmt.git
55 changes: 22 additions & 33 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,44 +1,33 @@
# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT
# TODO: add CUDA utils' library via CMake
cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_VERBOSE_MAKEFILE ON)

find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)

file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
execute_process(
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
RESULT_VARIABLE NVCC_RESULT
OUTPUT_VARIABLE NVCC_OUTPUT
ERROR_VARIABLE NVCC_ERROR_OUTPUT
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

if (NVCC_RESULT EQUAL "0")
set(NVCC_SUPPORTS_SM90 TRUE)
message(STATUS "NVCC supports SM90")
else()
message(STATUS "NVCC does not support SM90")
endif()
set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

if (NVCC_SUPPORTS_SM90)
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10")
# The main Python API entrance
pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp)
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda)

cuda_add_library(example_gemm STATIC indexing/main.cu)
# Enable kernel code indexing with CMake-based IDEs
cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)
168 changes: 53 additions & 115 deletions README.md

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions csrc/indexing/main.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Indexing-only translation unit: pulls in every kernel implementation header
// so CMake-based IDEs (e.g. CLion) can index the JIT-compiled CUDA kernels.
// The real compilation of these kernels happens at runtime via the JIT
// compiler; this target is never shipped or executed for its own sake.
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

// No runtime behavior: the executable exists purely to force the headers
// above through the compiler for IDE indexing.
int main() {
return 0;
}
31 changes: 31 additions & 0 deletions csrc/jit/cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <filesystem>
#include <memory>
#include <unordered_map>

#include "kernel_runtime.hpp"

namespace deep_gemm {

class KernelRuntimeCache {
    // Maps a kernel's on-disk cache directory to its loaded runtime
    std::unordered_map<std::filesystem::path, std::shared_ptr<KernelRuntime>> cache;

public:
    // TODO: consider cache capacity
    KernelRuntimeCache() = default;

    // Return the runtime cached for `dir_path`, loading and memoizing it on
    // first use. Returns `nullptr` when the directory fails the validity check.
    std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
        // Fast path: already loaded
        const auto& found = cache.find(dir_path);
        if (found != cache.end())
            return found->second;

        // Miss: only load directories that hold a valid kernel
        if (not KernelRuntime::check_validity(dir_path))
            return nullptr;
        return cache[dir_path] = std::make_shared<KernelRuntime>(dir_path);
    }
};

static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();

} // namespace deep_gemm
172 changes: 172 additions & 0 deletions csrc/jit/compiler.hpp
4B12
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <filesystem>
#include <fstream>
#include <regex>
#include <string>

#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/hash.hpp"
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"

namespace deep_gemm {

// Abstract JIT compiler: owns the on-disk kernel cache layout and the
// compile-or-fetch logic; concrete toolchains (e.g. NVCC) implement `compile`.
class Compiler {
    std::string library_version;
    std::filesystem::path library_root_path;

    // Hash every `.cuh` header under `<root>/include/deep_gemm` so that any
    // header edit changes the version string and invalidates cached kernels.
    std::string get_library_version() const {
        // Recursively walk through all subdirectories and update hash
        std::stringstream ss;
        for (const auto& entry: std::filesystem::recursive_directory_iterator(library_include_path / "deep_gemm")) {
            if (entry.is_regular_file() and entry.path().extension() == ".cuh") {
                std::ifstream file(entry.path(), std::ios::binary);
                std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
                ss << content;
            }
        }
        return get_hex_digest(ss.str());
    }

public:
    // `signature` and `flags` identify the toolchain and are folded into the
    // kernel cache key, so changing either re-compiles cached kernels
    std::string signature, flags;
    std::filesystem::path library_include_path;
    std::filesystem::path cache_dir_path;

    explicit Compiler(const std::filesystem::path& library_root_path) {
        // Static library paths
        this->library_root_path = library_root_path;
        this->library_include_path = library_root_path / "include";
        this->library_version = get_library_version();

        // Cache settings: default to `$HOME/.deep_gemm`, overridable via `DG_JIT_CACHE_DIR`
        cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
        if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
            cache_dir_path = env_cache_dir_path;

        // The compiler flags applied to all derived compilers
        signature = "unknown-compiler";
        std::string ptxas_flags = "--ptxas-options=--register-usage-level=10";
        if (get_env<int>("DG_JIT_PTXAS_VERBOSE", 0))
            ptxas_flags += ",--verbose";
        flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags);
    }

    virtual ~Compiler() = default;

    // Directory for in-flight temporary files, created on demand
    std::filesystem::path make_tmp_dir() const {
        return make_dirs(cache_dir_path / "tmp");
    }

    // A unique, not-yet-existing path inside the temporary directory
    std::filesystem::path get_tmp_file_path() const {
        return make_tmp_dir() / get_uuid();
    }

    // Write `data` to `path` atomically: write a temporary file first, then
    // `rename` it into place so readers never observe a partial file.
    void put(const std::filesystem::path& path, const std::string& data) const {
        const auto tmp_file_path = get_tmp_file_path();

        // Write into the temporary file; keep the write outside the assertion
        // macro so the side effect survives if assertions are compiled out
        std::ofstream out(tmp_file_path, std::ios::binary);
        const bool write_succeeded = static_cast<bool>(out.write(data.data(), data.size()));
        DG_HOST_ASSERT(write_succeeded);
        out.close();

        // Atomically replace
        std::filesystem::rename(tmp_file_path, path);
    }

    // Build (or fetch from cache) the kernel named `name` with source `code`.
    // The cache key covers the library version, compiler signature, flags and
    // the full source text, so any of them changing forces a rebuild.
    std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
        const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code);
        const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));

        // Hit the runtime cache
        if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
            return runtime;

        // Create the kernel directory
        make_dirs(dir_path);

        // Compile into a temporary CUBIN
        const auto tmp_cubin_path = get_tmp_file_path();
        compile(code, dir_path, tmp_cubin_path);

        // Move the CUBIN into the (already created) cache directory
        std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");

        // Put into the runtime cache
        const auto& runtime = kernel_runtime_cache->get(dir_path);
        DG_HOST_ASSERT(runtime != nullptr);
        return runtime;
    }

    // Compile `code`, placing source artifacts under `dir_path` and the
    // resulting CUBIN at `cubin_path`; implemented per-toolchain
    virtual void compile(const std::string& code, const std::filesystem::path& dir_path, const std::filesystem::path& cubin_path) const = 0;
};

// NVCC-backed compiler: locates `nvcc`, validates its version, and compiles
// kernel sources into standalone CUBINs for the current GPU architecture.
class NVCCCompiler final: public Compiler {
    std::filesystem::path nvcc_path;

    // Run `nvcc --version` and parse the `release X.Y` token into {major, minor}.
    // Requires at least 12.3; warns below 12.9 (needed for best performance).
    std::pair<int, int> get_nvcc_version() const {
        DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));

        // Call the version command
        const auto& command = std::string(nvcc_path) + " --version";
        const auto& [return_code, output] = call_external_command(command);
        DG_HOST_ASSERT(return_code == 0);

        // The version should be at least 12.3, for the best performance with 12.9
        int major, minor;
        std::smatch match;
        DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
        // Keep the parse outside the assertion macro so it survives if
        // assertions are compiled out, and check both fields were read
        const int parsed_fields = std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
        DG_HOST_ASSERT(parsed_fields == 2);
        DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
        if (major < 12 or (major == 12 and minor < 9))
            printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n");
        return {major, minor};
    }

public:
    // `cuda_home_path_by_torch` is the CUDA home reported by PyTorch; it is
    // used to locate `nvcc` unless `DG_JIT_NVCC_COMPILER` overrides the path.
    NVCCCompiler(const std::filesystem::path& library_root_path,
                 const std::filesystem::path& cuda_home_path_by_torch):
        Compiler(library_root_path) {
        // Override the compiler signature
        nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc";
        if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
            nvcc_path = env_nvcc_path;
        const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
        signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);

        // Override the compiler flags
        flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
                            "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
                            "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
                            flags, library_include_path.c_str(), device_runtime->get_arch());
    }

    void compile(const std::string& code, const std::filesystem::path& dir_path, const std::filesystem::path& cubin_path) const override {
        // Write the code into the cache directory
        const auto& code_path = dir_path / "kernel.cu";
        put(code_path, code);

        // Compile
        const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
            printf("Running NVCC command: %s\n", command.c_str());
        const auto& [return_code, output] = call_external_command(command);
        if (return_code != 0) {
            printf("NVCC compilation failed: %s\n", output.c_str());
            DG_HOST_ASSERT(false and "NVCC compilation failed");
        }

        // Print PTXAS log (raw compiler output, already newline-terminated)
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
            printf("%s", output.c_str());
    }
};

static std::shared_ptr<Compiler> compiler = nullptr;

} // namespace deep_gemm
50 changes: 50 additions & 0 deletions csrc/jit/device_runtime.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "../utils/exception.hpp"

namespace deep_gemm {

// Caches per-process GPU device information (properties, arch, SM count) so
// repeated queries avoid going back to the CUDA runtime.
class DeviceRuntime {
    int num_sms = 0;                              // 0 means "not yet resolved from the device"
    std::shared_ptr<cudaDeviceProp> cached_prop;  // Lazily-fetched device properties

public:
    // Defaulted constructor; `explicit` on a no-arg constructor adds nothing
    // and would only block `= {}` initialization, so it is omitted
    DeviceRuntime() = default;

    // Return the current device's properties, fetched once and cached
    std::shared_ptr<cudaDeviceProp> get_prop() {
        if (cached_prop == nullptr)
            cached_prop = std::make_shared<cudaDeviceProp>(*at::cuda::getCurrentDeviceProperties());
        return cached_prop;
    }

    // {compute capability major, minor}, e.g. {9, 0} on SM90 devices
    std::pair<int, int> get_arch_pair() {
        const auto prop = get_prop();
        return {prop->major, prop->minor};
    }

    // Compute capability as a single integer, e.g. 90 for SM90
    int get_arch() {
        const auto& [major, minor] = get_arch_pair();
        return major * 10 + minor;
    }

    int get_arch_major() {
        return get_arch_pair().first;
    }

    // Cap the number of SMs kernels may use; must be within [0, device SM count].
    // Takes `int` by value: cheaper and simpler than a const reference.
    void set_num_sms(const int new_num_sms) {
        DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount);
        num_sms = new_num_sms;
    }

    // Number of SMs to use; lazily defaults to all SMs on the device
    int get_num_sms() {
        if (num_sms == 0)
            num_sms = get_prop()->multiProcessorCount;
        return num_sms;
    }
};

static auto device_runtime = std::make_shared<DeviceRuntime>();

} // namespace deep_gemm
Loading
0