try relanding cublaslt autotuning support for TunableOp by bilal2vec · Pull Request #153316 · pytorch/pytorch
Draft · wants to merge 22 commits into base: main
16 changes: 8 additions & 8 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -105,7 +105,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
" but got ", \
X)

-namespace {
+namespace at::cuda::blas {

static cublasOperation_t _cublasOpFromChar(char op) {
// NOLINTNEXTLINE(bugprone-switch-missing-default-case)
@@ -124,7 +124,7 @@ static cublasOperation_t _cublasOpFromChar(char op) {
"_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

-static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
+void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
// Note: leading dimensions generally are checked that they are > 0
// and at least as big the result requires (even if the value won't
// be used).
@@ -138,7 +138,7 @@ static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
*lda = std::max<int64_t>(m, 1);
}

-static void _cublasAdjustLdLevel3(
+void _cublasAdjustLdLevel3(
char transa,
char transb,
int64_t m,
@@ -217,7 +217,7 @@ static size_t _parseChosenWorkspaceSize()
return workspace_size * 1024;
}

-static size_t _getWorkspaceSize() {
+size_t _getWorkspaceSize() {
static size_t workspace_size = _parseChosenWorkspaceSize();
return workspace_size;
}
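[Editor's note] For context, a minimal sketch of the parse-once pattern that _getWorkspaceSize uses above: a function-local static means the environment is consulted exactly once per process, and C++11 guarantees thread-safe initialization. The environment variable name matches the one PyTorch reads for the cuBLASLt workspace, but the default value and error handling here are illustrative assumptions, not the real implementation.

#include <cstddef>
#include <cstdlib>

// Sketch: read CUBLASLT_WORKSPACE_SIZE (interpreted in KiB, mirroring the
// "workspace_size * 1024" above). The 1024 KiB default is hypothetical;
// the real default depends on the PyTorch version and GPU architecture.
static size_t parseWorkspaceSizeSketch() {
  size_t size_kib = 1024;
  if (const char* env = std::getenv("CUBLASLT_WORKSPACE_SIZE")) {
    size_kib = std::strtoul(env, nullptr, 10);  // no validation in this sketch
  }
  return size_kib * 1024;  // bytes
}

static size_t getWorkspaceSizeSketch() {
  // Initialized once on first call; all later calls return the cached value.
  static size_t workspace_size = parseWorkspaceSizeSketch();
  return workspace_size;
}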
@@ -1415,19 +1415,19 @@ inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(DType, C_Dtype)) {
bool transb_ = ((transb != 'n') && (transb != 'N'));

if (transa_ && transb_) {
-static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
+static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{&params};
gemm(&params);
}
else if (transa_ && !transb_) {
-static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
+static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{&params};
gemm(&params);
}
else if (!transa_ && transb_) {
-static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
+static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{&params};
gemm(&params);
}
else if (!transa_ && !transb_) {
-static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
+static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{&params};
gemm(&params);
}
else {
…
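[Editor's note] A compressed sketch of what this dispatch pattern does and what the gemm{} to gemm{&params} change implies. Each (transa, transb) combination is a distinct template instantiation with its own function-local static, so tuning state is cached per combination; passing the params pointer to the constructor (as this PR does) makes the first problem's description available at construction time. Names here are illustrative, not the PyTorch implementation; note the design consequence that the static is constructed only once, so later calls with different shapes reuse the instance and pass their params through operator().

#include <iostream>

enum class BlasOp { N, T };

template <typename DType, BlasOp OpA, BlasOp OpB>
struct GemmTunableOpSketch {
  // Hypothetical constructor mirroring gemm{&params}: the op can inspect
  // the first problem (e.g. to enumerate candidate kernels) when created.
  explicit GemmTunableOpSketch(const void* /*params*/) {
    std::cout << "constructed op: A=" << (OpA == BlasOp::T ? 'T' : 'N')
              << " B=" << (OpB == BlasOp::T ? 'T' : 'N') << '\n';
  }
  void operator()(const void* /*params*/) {
    // look up (or benchmark) candidate kernels for this shape, run the best
  }
};

template <typename DType>
void gemm_tunable_sketch(char transa, char transb, const void* params) {
  const bool ta = (transa != 'n') && (transa != 'N');
  const bool tb = (transb != 'n') && (transb != 'N');
  if (ta && tb) {
    static GemmTunableOpSketch<DType, BlasOp::T, BlasOp::T> gemm{params};
    gemm(params);
  } else if (ta) {
    static GemmTunableOpSketch<DType, BlasOp::T, BlasOp::N> gemm{params};
    gemm(params);
  } else if (tb) {
    static GemmTunableOpSketch<DType, BlasOp::N, BlasOp::T> gemm{params};
    gemm(params);
  } else {
    static GemmTunableOpSketch<DType, BlasOp::N, BlasOp::N> gemm{params};
    gemm(params);
  }
}

int main() {
  float dummy = 0.f;
  gemm_tunable_sketch<float>('n', 't', &dummy);
  gemm_tunable_sketch<float>('n', 't', &dummy);  // reuses the same static
}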
99 changes: 99 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -18,6 +18,105 @@

namespace at::cuda::blas {

cublasOperation_t _cublasOpFromChar(char op);
void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda);
void _cublasAdjustLdLevel3(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
int64_t* lda,
int64_t* ldb,
int64_t* ldc);
uint32_t _getAlignment(uintptr_t address);
size_t _parseChosenWorkspaceSize();
size_t _getWorkspaceSize();
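[Editor's note] These declarations export helpers that were previously file-local in Blas.cpp (hence the removed static qualifiers in the .cpp hunks above), so other translation units such as the cuBLASLt tuning path can call them. A hedged sketch of intended use follows; the call site is hypothetical, and the clamp semantics are inferred from the visible _cublasAdjustLdLevel2 body above.

#include <cstdint>
#include <cublas_v2.h>
#include <ATen/cuda/CUDABlas.h>

// Sketch: prologue a gemv-style caller might run before a level-2 BLAS call.
void gemv_prologue_sketch(char trans, int64_t m, int64_t n, int64_t lda) {
  // Translate 'n'/'t'/'c' into the cuBLAS enum.
  cublasOperation_t op = at::cuda::blas::_cublasOpFromChar(trans);
  // Fix up a degenerate leading dimension so cuBLAS argument validation
  // (lda >= max(1, m)) passes even when a dimension is 0 or 1.
  at::cuda::blas::_cublasAdjustLdLevel2(m, n, &lda);
  (void)op;  // would be forwarded to cublasSgemv etc. with the adjusted lda
}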

namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is currently the only place the
// cublas_lt interface is used, but these wrappers can be moved to a
// dedicated header once the cublas_lt interface is used in multiple places.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
}
}
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}

protected:
std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
cublasLtMatmulDescOpaque_t,
&cublasLtMatmulDescDestroy> {
public:
CuBlasLtMatmulDescriptor(
cublasComputeType_t compute_type,
cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
cublasLtMatrixLayoutOpaque_t,
&cublasLtMatrixLayoutDestroy> {
public:
CuBlasLtMatrixLayout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
int64_t ld,
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
cublasLtMatmulPreferenceOpaque_t,
&cublasLtMatmulPreferenceDestroy> {
public:
CuBlasLtMatmulPreference() {
cublasLtMatmulPreference_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};

} // namespace
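[Editor's note] For readers unfamiliar with the cuBLASLt descriptor API these wrappers manage, a hedged usage sketch follows: it strings the three wrapper classes together the way a cublasLtMatmul heuristic query typically does, which is the building block a cuBLASLt autotuning path needs. The attribute and function names are from the public cuBLASLt API; the shapes, workspace size, and candidate count are illustrative, and device buffer setup is elided.

#include <cublasLt.h>
// Assumes the CuBlasLt* wrappers and TORCH_CUDABLAS_CHECK from this header.

// Sketch: ask cuBLASLt for candidate algorithms for an FP32 m x n = (m x k)(k x n)
// problem; a tuning pass could then benchmark the returned candidates.
int query_heuristics_sketch(cublasLtHandle_t lt_handle,
                            int64_t m, int64_t n, int64_t k,
                            size_t workspace_bytes) {
  CuBlasLtMatmulDescriptor desc(CUBLAS_COMPUTE_32F, CUDA_R_32F);
  cublasOperation_t no_trans = CUBLAS_OP_N;
  desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, no_trans);
  desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, no_trans);

  // Column-major layouts with tight leading dimensions.
  CuBlasLtMatrixLayout a_layout(CUDA_R_32F, m, k, /*ld=*/m);
  CuBlasLtMatrixLayout b_layout(CUDA_R_32F, k, n, /*ld=*/k);
  CuBlasLtMatrixLayout c_layout(CUDA_R_32F, m, n, /*ld=*/m);

  CuBlasLtMatmulPreference pref;
  pref.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_bytes);

  cublasLtMatmulHeuristicResult_t results[8];
  int returned = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      lt_handle, desc.descriptor(), a_layout.descriptor(),
      b_layout.descriptor(), c_layout.descriptor(), c_layout.descriptor(),
      pref.descriptor(), /*requestedAlgoCount=*/8, results, &returned));
  return returned;  // number of usable candidate algorithms
}

The RAII base class above guarantees each opaque handle is destroyed exactly once when the wrapper goes out of scope, so a query like this cannot leak descriptors on early return or exception.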

// RAII guard that sets the CuBLAS pointer mode and restores it to
// its previous value when the guard is destroyed
class PointerModeGuard {
…