Enhance support for Float8 and Float4 data types in scaled_gemm and related functions · pytorch/pytorch@87d6e71 · GitHub
Commit 87d6e71

Author: Peter Y. Yeh (committed)
Enhance support for Float8 and Float4 data types in scaled_gemm and related functions. Update error messages to reflect ROCm 6.5 compatibility. Add HIP data type mapping for Float4_e2m1fn_x2. Ensure proper version checks for ROCm in CUDA operations.
Parent: d99e956 · Commit: 87d6e71
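For orientation, the code paths touched by this commit sit behind the `_scaled_mm` operator. Below is a minimal, hedged sketch of how that entry point is exercised from the ATen C++ API with per-tensor-scaled Float8 inputs; the `at::_scaled_mm` signature is assumed from the currently generated C++ API, and the shapes, layouts, and unit scales are illustrative, not taken from this commit.

#include <ATen/ATen.h>
#include <optional>

// Sketch only: exercises the per-tensor FP8 path served by scaled_gemm.
// Assumes a CUDA or ROCm build with Float8 support; shapes and scales are made up.
at::Tensor scaled_mm_example() {
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);
  // mat1 row-major, mat2 column-major (a transposed view), K divisible by 16.
  auto a = at::randn({64, 128}, opts).to(at::kFloat8_e4m3fn);
  auto b = at::randn({256, 128}, opts).to(at::kFloat8_e4m3fn).t();
  // Per-tensor float32 scales.
  auto scale_a = at::ones({}, opts);
  auto scale_b = at::ones({}, opts);
  return at::_scaled_mm(a, b, scale_a, scale_b,
                        /*bias=*/std::nullopt, /*scale_result=*/std::nullopt,
                        /*out_dtype=*/at::kBFloat16, /*use_fast_accum=*/false);
}

The block-scaled variants gated in this commit (`torch.float8_e8m0fnu` / `torch.float8_e4m3fn` scale tensors) go through the same operator but hit the CUBLASLT_MATMUL_DESC_*_SCALE_MODE attributes set in scaled_gemm.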

File tree

4 files changed: +29 -10 lines

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 4 additions & 4 deletions
@@ -1910,18 +1910,18 @@ void scaled_gemm(
   }
 
   if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
+#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 60500)
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
 #else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 or ROCm 6.5 and above");
 #endif // if CUDA_VERSION >= 12080
   } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
-#if CUDA_VERSION >= 12080
+#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 60500)
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
 #else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above");
+    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 or ROCm 6.5 and above");
 #endif // if CUDA_VERSION >= 12080
   }
 
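A note on the two scale modes gated here: the constant names indicate one unsigned e8m0 scale per 32-element block (VEC32_UE8M0, the MX-style layout) versus one e4m3 scale per 16-element block (VEC16_UE4M3). The sketch below only computes how many block scales each mode implies for an operand; the block sizes are read off the constant names and are an assumption, not something this diff states.

#include <cstdint>
#include <cstdio>

// How many block scales an (rows x k) operand needs for a given block size
// along k. Purely illustrative; 32 and 16 are inferred from the VEC32/VEC16
// constant names above.
constexpr int64_t num_block_scales(int64_t rows, int64_t k, int64_t block) {
  return rows * ((k + block - 1) / block);
}

int main() {
  std::printf("UE8M0 (block 32) scales for a 64x128 operand: %lld\n",
              static_cast<long long>(num_block_scales(64, 128, 32)));
  std::printf("UE4M3 (block 16) scales for a 64x128 operand: %lld\n",
              static_cast<long long>(num_block_scales(64, 128, 16)));
  return 0;
}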

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
   return static_cast<hipDataType>(500);
 }
 
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
+#if ROCM_VERSION >= 60500
+  return HIP_R_4F_E2M1;
+#else
+  return static_cast<hipDataType>(30);
+#endif
+}
+
 template <typename T>
 int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 12 additions & 6 deletions
@@ -1271,15 +1271,21 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   if (use_fast_accum) {
     TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
   }
+#ifdef USE_ROCM
+  if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
+    TORCH_CHECK(ROCM_VERSION >= 60500, "Float4_e2m1fn_x2 is only supported for ROCm 6.5 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e5m2 is only supported for ROCm 6.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e4m3fn is only supported for ROCm 6.0 and above");
+  }
+#endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
-                "Bias must be either Half or BFloat16, but got ", bias->scalar_type());
-    TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
-                bias->scalar_type() == ScalarType::BFloat16,
-                "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
-    TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
-                "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
+                "Bias must be BFloat16 or Half, but got ", bias->scalar_type());
   }
   {
     auto bias_ = bias.value_or(Tensor());

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 4 additions & 0 deletions
@@ -7345,6 +7345,10 @@
     ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),

0 commit comments