Enhance support for Float8 and Float4 data types in scaled_gemm and related functions. Update error messages to reflect ROCm 6.5 compatibility. Add HIP data type mapping for Float4_e2m1fn_x2. Ensure proper version checks for ROCm in CUDA operations.
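The commit message also mentions a HIP data type mapping for Float4_e2m1fn_x2, which is not part of the hunk shown below. As a minimal sketch of that kind of mapping, using stand-in enums (the real code maps c10::ScalarType to hipDataType, and the HIP_R_4F_E2M1 enumerator name here is an assumption, not confirmed ROCm API):

```cpp
#include <stdexcept>

// Stand-in enums for illustration only; the real mapping uses
// c10::ScalarType and hipDataType. HIP_R_4F_E2M1 is an assumed name.
enum class ScalarType { Float8_e4m3fn, Float8_e5m2, Float4_e2m1fn_x2 };
enum class HipDataType { HIP_R_8F_E4M3, HIP_R_8F_E5M2, HIP_R_4F_E2M1 };

// Maps a packed FP4 dtype (two e2m1 values per byte) to the
// corresponding HIP data type, mirroring the version gate in the diff.
HipDataType scalarTypeToHipDataType(ScalarType t) {
  switch (t) {
    case ScalarType::Float8_e4m3fn:
      return HipDataType::HIP_R_8F_E4M3;
    case ScalarType::Float8_e5m2:
      return HipDataType::HIP_R_8F_E5M2;
    case ScalarType::Float4_e2m1fn_x2:
      return HipDataType::HIP_R_4F_E2M1;  // requires ROCm >= 6.5
  }
  throw std::runtime_error("unsupported dtype");
}
```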
    TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
  }
+ #ifdef USE_ROCM
+   if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
+     TORCH_CHECK(ROCM_VERSION >= 60500, "Float4_e2m1fn_x2 is only supported for ROCm 6.5 and above");
+   }
+   if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+     TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e5m2 is only supported for ROCm 6.0 and above");
+   }
+   if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+     TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e4m3fn is only supported for ROCm 6.0 and above");
+   }
+ #endif
  if (bias) {
    TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
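For context on the numeric thresholds: the checks are consistent with PyTorch's build encoding the ROCm version as a single integer, major*10000 + minor*100 + patch, so 6.5.0 becomes 60500 and 6.0.0 becomes 60000. A small self-contained sketch of that encoding, assuming that standard scheme:

```cpp
#include <cassert>

// Sketch only: reproduces the integer encoding the ROCM_VERSION
// checks above rely on (major*10000 + minor*100 + patch).
constexpr int rocm_version(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

static_assert(rocm_version(6, 5, 0) == 60500, "Float4_e2m1fn_x2 gate");
static_assert(rocm_version(6, 0, 0) == 60000, "Float8 gate");

int main() {
  // A build against ROCm 6.4.1 passes the Float8 checks but would
  // trip the Float4_e2m1fn_x2 TORCH_CHECK.
  assert(rocm_version(6, 4, 1) >= 60000);
  assert(rocm_version(6, 4, 1) < 60500);
  return 0;
}
```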