AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support by petrex · Pull Request #151360 · pytorch/pytorch · GitHub
Closed · Changes from all commits (18 commits)
43 changes: 31 additions & 12 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -1847,8 +1847,12 @@ int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fa
switch (scaling_type) {
case ScalingType::BlockWise1x32:
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
#if CUDA_VERSION >= 12080
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
#ifdef USE_ROCM
return HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#else
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#endif // USE_ROCM
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
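
For context, the UE8M0 scale mode selected here describes per-block scales stored in the OCP e8m0 format: each 1x32 block of the operand shares one exponent-only scale. The sketch below is illustrative only and does not model the hipBLASLt/cuBLASLt scale-tensor layout; it just shows the arithmetic the e8m0 encoding implies.

import torch

# Illustrative only: decode an e8m0 byte and apply it to one 1x32 block.
# e8m0 is exponent-only: value = 2**(byte - 127); 0xFF is reserved for NaN.
def decode_e8m0(byte_val: int) -> float:
    assert 0 <= byte_val <= 0xFE, "0xFF encodes NaN in e8m0"
    return 2.0 ** (byte_val - 127)

block = torch.randn(32)          # one 1x32 block of a quantized operand
scale = decode_e8m0(130)         # shared block scale = 2**3 = 8.0
dequantized = block * scale      # effective values the GEMM accumulates
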
@@ -1946,12 +1950,26 @@ void scaled_gemm(
// hipblaslt supported row-wise before cublas, and did so their own way (via
// the SCALE_POINTERSs), but then migrated to match how cublas does it (via
// the SCALE_MODEs). Here we check for this early custom mode.
bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
if (use_rowwise) {
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
Contributor comment:
This check need not be inside #if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)

#if ROCM_VERSION >= 70000
if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) {
// TODO: add constraints based on hipblaslt internals
TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
"Matrix dimensions must be multiples of 32 for MX format. "
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
}
#else
// rowwise isn't supported using cublaslt or older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
if (result_scale_ptr != nullptr) {
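
As an aside on the two paths in this hunk: the legacy hipBLASLt extension covers row-wise scaling (one scale per row), while the new branch covers 1x32 block-wise MX scaling. A rough sketch of the difference in scale granularity follows; amax is used as a stand-in for the real quantizer, and actual MX scales are rounded to powers of two and stored as e8m0, which is not shown here.

import torch

x = torch.randn(4, 64)                                   # small fp32 operand

row_scales = x.abs().amax(dim=1, keepdim=True)           # row-wise: shape (4, 1)
block_scales = x.abs().reshape(4, 2, 32).amax(dim=-1)    # 1x32 block-wise: shape (4, 2)
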
@@ -1990,15 +2008,16 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}

// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif
// For other data types, use the get_scale_mode function based on scaling type
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
// Note that AMD/ROCm follows OCP Spec 1.0, which is different from NVIDIA's implementation. See get_scale_mode() for details.
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif // if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))

CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();
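
The OCP-vs-NVIDIA note added above concerns the layout of the block-scale tensors, not their count. Independent of layout, a 1x32 block-scaled operand carries one e8m0 scale per 32 elements along the reduction dimension. The sketch below only counts scales and deliberately ignores any vendor-specific (possibly swizzled) physical arrangement.

import math

# Illustrative only: number of e8m0 block scales for an M x K operand
# scaled in 1x32 blocks along K. Physical layout is vendor-specific.
def num_block_scales(m: int, k: int, block: int = 32) -> int:
    return m * math.ceil(k / block)

print(num_block_scales(128, 256))   # 128 * 8 = 1024 scales
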
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDADataType.h
@@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
case c10::ScalarType::Float4_e2m1fn_x2:
return CUDA_R_4F_E2M1;
#endif
9 changes: 9 additions & 0 deletions aten/src/ATen/cuda/tunable/GemmHipblaslt.h
@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
return static_cast<hipDataType>(500);
}

template <>
constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
#if ROCM_VERSION >= 70000
return HIP_R_4F_E2M1;
#else
return static_cast<hipDataType>(33);
#endif
}

template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
return 1;
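Since Float4_e2m1fn_x2 packs two OCP e2m1 (fp4) elements per byte, it may help to recall what e2m1 can represent: 1 sign bit, 2 exponent bits, 1 mantissa bit, giving magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The enumeration below is illustrative and not part of this PR.

# Enumerate the values representable by the OCP e2m1 (fp4) element format.
def e2m1_values():
    vals = set()
    for sign in (1.0, -1.0):
        for e in range(4):            # 2 exponent bits, bias 1
            for m in (0, 1):          # 1 mantissa bit
                if e == 0:
                    v = m * 0.5       # subnormals: 0 and 0.5
                else:
                    v = (1 + 0.5 * m) * 2.0 ** (e - 1)
                vals.add(sign * v)
    return sorted(vals)

print(e2m1_values())
# [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
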
54 changes: 46 additions & 8 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -1283,15 +1283,35 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
if (use_fast_accum) {
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
}
#ifdef USE_ROCM
if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e5m2 is only supported for ROCm 6.5 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e4m3fn is only supported for ROCm 6.5 and above");
}
#endif
if (bias) {
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
"Bias must be either Half or BFloat16, but got ", bias->scalar_type());
TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != kFloat,
"Bias is not supported when out_dtype is set to Float32");

TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 ||
bias->scalar_type() == ScalarType::Half,
"Bias must be BFloat16 or Half, but got ", bias->scalar_type());

TORCH_CHECK((out.scalar_type() != kFloat &&
out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());

TORCH_CHECK(out.scalar_type() != ScalarType::Half ||
bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
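
The reformatted bias checks above encode a small compatibility table between out_dtype and the bias dtype. A hedged summary of the combinations they allow, as I read the checks (sketch only):

import torch

# out_dtype -> required bias dtype (None means bias is rejected)
allowed_bias = {
    torch.float32:  None,             # bias unsupported for fp32 output
    torch.bfloat16: torch.bfloat16,   # bf16 output requires bf16 bias
    torch.float16:  torch.float16,    # fp16 output requires fp16 bias
}
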
@@ -1353,6 +1373,22 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#if ROCM_VERSION >= 70000
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");

TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
"Matrix dimensions must be multiples of 32 for block-wise scaling");

TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
}
#endif

cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
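
The new BlockWise1x32 branch above rejects shapes that are not multiples of 32 on gfx950. One way a caller could satisfy that constraint is to zero-pad the operands up front; the helper below is illustrative and not part of this PR.

import torch
import torch.nn.functional as F

def pad_to_multiple_of_32(t: torch.Tensor) -> torch.Tensor:
    # F.pad pads the last dim by (left, right) and the second-to-last by (top, bottom).
    pad_rows = (-t.size(0)) % 32
    pad_cols = (-t.size(1)) % 32
    return F.pad(t, (0, pad_cols, 0, pad_rows))

a = pad_to_multiple_of_32(torch.randn(100, 70))   # 100x70 -> 128x96
b = pad_to_multiple_of_32(torch.randn(70, 50))    # 70x50  -> 96x64
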
@@ -1430,12 +1466,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
2 changes: 2 additions & 0 deletions test/dynamo/test_repros.py
@@ -66,6 +66,7 @@
parametrize,
serialTest,
skipIfHpu,
skipIfRocm,
skipIfWindows,
TEST_WITH_ROCM,
)
@@ -7405,6 +7406,7 @@ def f(x, s0, s1, s2):
out = f_compiled(x, s0, s1, s2)
self.assertEqual(out_ref, out)

@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
@requires_cuda
def test_partitioner_saves_weights_for_bw(self):