Revert "[WIP] Initial implementation of Grouped Gemm API (#148531)" · pytorch/pytorch@c983e11 · GitHub

Commit c983e11

Revert "[WIP] Initial implementation of Grouped Gemm API (#148531)"
This reverts commit ff29791. Reverted #148531 on behalf of https://github.com/janeyx99 due to: "Sorry but this broke ROCm jobs on trunk" ([comment](#148531 (comment)))
1 parent f1787ee commit c983e11

File tree

9 files changed: +5 additions, -1606 deletions


aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 0 additions & 158 deletions
@@ -1,7 +1,6 @@
 #include <cstdint>
 #include <c10/util/typeid.h>
 #include <c10/util/Exception.h>
-#include <c10/util/SmallVector.h>
 #include <c10/core/Scalar.h>
 #include <c10/core/ScalarType.h>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -17,7 +16,6 @@
 #include <ATen/native/Resize.h>
 #include <c10/util/MaybeOwned.h>
 #include <ATen/native/cuda/RowwiseScaledMM.h>
-#include <ATen/native/cuda/ScaledGroupMM.h>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -1365,84 +1363,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   return out;
 }

-namespace {
-c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
-                                                              const Tensor& mat_b,
-                                                              const std::optional<at::Tensor>& offs) {
-  const bool a_is_2d = mat_a.dim() == 2;
-  const bool b_is_2d = mat_b.dim() == 2;
-  if (a_is_2d) {
-    if (b_is_2d) {
-      return {offs->size(0), mat_a.size(0), mat_b.size(1)};
-    } else {
-      TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
-      return {mat_a.size(0), mat_b.size(-1)};
-    }
-  } else {
-    if (b_is_2d) {
-      // this case is not actually encountered for MoE gemms
-      TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
-      return {mat_a.size(1), mat_b.size(1)};
-    } else { // regular bmm
-      TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
-      return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
-    }
-  }
-}
-
-bool transposed(const Tensor& mat) {
-  IntArrayRef tensor_strides = mat.strides();
-  IntArrayRef tensor_sizes = mat.sizes();
-  int end_dim = mat.dim() - 1;
-  if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max<int64_t>(1, tensor_sizes[end_dim - 1]))) {
-    return true;
-  } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max<int64_t>(1, tensor_sizes[end_dim]))) {
-    return false;
-  } else {
-    TORCH_CHECK(false, "Tensor should not be self-overlapping");
-  }
-}
-
-void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
-  if (mat.dim() == 2) {
-    TORCH_CHECK(
-        scale.dim() == 1,
-        "scale must be a 1D tensor, but got ",
-        scale.dim(),
-        "D, arg ",
-        arg_idx);
-    TORCH_CHECK(
-        scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
-    TORCH_CHECK(
-        scale.size(0) == mat.size(dim) * scale_multiplier,
-        "scale must have the same length as mat for arg ",
-        arg_idx);
-  } else {
-    TORCH_CHECK(
-        scale.dim() == 2,
-        "scale must be a 2D tensor, but got ",
-        scale.dim(),
-        "D for arg ",
-        arg_idx);
-    TORCH_CHECK(
-        scale.stride(1),
-        "scale_a must be contiguous in the last dimension for arg ",
-        arg_idx);
-    TORCH_CHECK(
-        scale.size(0) == mat.size(0),
-        "scale must have the same batch dimension as mat for arg ",
-        arg_idx);
-    TORCH_CHECK(
-        scale.size(1) == mat.size(1 + dim),
-        "scale must have the same first dimension as mat for arg ",
-        arg_idx);
-  }
-}
-
-}
-
 Tensor
 _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
           const Tensor& scale_a,
@@ -1456,82 +1376,4 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
   return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
 }

-
-Tensor
-_scaled_grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
-          const Tensor& scale_a, const Tensor& scale_b,
-          const std::optional<at::Tensor>& offs,
-          const std::optional<at::Tensor>& bias,
-          const std::optional<at::Tensor>& scale_result,
-          std::optional<c10::ScalarType> out_dtype,
-          bool use_fast_accum) {
-#ifndef USE_ROCM
-  bool allowed_device = _scaled_mm_allowed_device();
-  TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
-
-  TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
-  TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
-  TORCH_CHECK(!transposed(mat_a), "Expected mat1 to not be transposed");
-  TORCH_CHECK(transposed(mat_b), "Expected mat2 to be transposed");
-  TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
-  TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
-  const bool a_is_2d = mat_a.dim() == 2;
-  const bool b_is_2d = mat_b.dim() == 2;
-  TORCH_CHECK(
-      mat_a.size(-1) % 16 == 0,
-      "Expected trailing dimension of mat_a to be divisible by 16 ",
-      "but got mat1 shape: (",
-      mat_a.sizes(),
-      ").");
-  TORCH_CHECK(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
-      "Expected mat_b shape to be divisible by 16 ",
-      "but got mat_b shape: (",
-      mat_b.sizes(),
-      ").");
-
-  TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
-
-  if (offs.has_value()) {
-    TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
-    TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
-  }
-
-  // Both Per-Tensor and Row-wise scaling expect fp32 tensors
-  TORCH_CHECK(
-      scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
-      "Both scale_a and scale_b must be float (fp32) tensors.");
-
-  const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
-  check_scale(mat_a, scale_a, 0, 0, scale_multiplier);
-  check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
-
-  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
-  const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
-  Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
-
-  at::cuda::detail::f8f8bf16_grouped_mm(
-      mat_a,
-      mat_b,
-      scale_a,
-      scale_b,
-      offs,
-      bias,
-      use_fast_accum,
-      out);
-  return out;
-
-#else
-  TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
-#endif
-}
-
 } // namespace at::native
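
For orientation, the shape rules encoded in the reverted compute_grouped_gemm_output_size helper can be restated as a small standalone sketch. This is an illustrative reconstruction, not code from the commit: the function name grouped_gemm_output_size, the std::vector types, and the groups parameter (standing in for offs->size(0)) are assumptions made for the example.

// Minimal sketch (not from the commit) of the output-size rules used by the
// reverted helper; `groups` stands in for offs->size(0).
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> grouped_gemm_output_size(
    const std::vector<int64_t>& a_sizes,  // mat_a sizes, 2D or 3D
    const std::vector<int64_t>& b_sizes,  // mat_b sizes, 2D or 3D
    int64_t groups) {                     // number of offsets
  const bool a_is_2d = a_sizes.size() == 2;
  const bool b_is_2d = b_sizes.size() == 2;
  if (a_is_2d && b_is_2d) {
    // 2D x 2D: one output matrix per group.
    return {groups, a_sizes[0], b_sizes[1]};
  }
  if (a_is_2d) {
    // 2D x 3D: mat_b is batched; rows of mat_a are split by the offsets.
    if (groups != b_sizes[0]) throw std::invalid_argument("matrix batch sizes have to match");
    return {a_sizes[0], b_sizes.back()};
  }
  if (b_is_2d) {
    // 3D x 2D: noted in the reverted code as not encountered for MoE gemms.
    if (groups != a_sizes[0]) throw std::invalid_argument("matrix batch sizes have to match");
    return {a_sizes[1], b_sizes[1]};
  }
  // 3D x 3D: regular batched matmul.
  if (a_sizes[0] != b_sizes[0]) throw std::invalid_argument("batched dimension has to match");
  return {a_sizes[0], a_sizes[1], b_sizes.back()};
}

Under these rules, two 2D operands with G offsets yield a (G, mat_a.size(0), mat_b.size(1)) output, which is the bf16 tensor the reverted _scaled_grouped_mm_cuda allocated before calling f8f8bf16_grouped_mm.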

aten/src/ATen/native/cuda/RowwiseScaledMM.cu

Lines changed: 5 additions & 5 deletions
@@ -946,6 +946,7 @@ void dispatch_fp8_rowwise_kernel_on_input_dtypes(
   }
 }

+template <typename... Types>
 void dispatch_fp8_rowwise_kernel_on_bias_dtype(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -956,13 +957,12 @@
     at::Tensor out) {
   if (bias.has_value() && bias->dtype() == at::kBFloat16) {
     dispatch_fp8_rowwise_kernel_on_input_dtypes<
-        cutlass::bfloat16_t>
-        (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
+        cutlass::bfloat16_t,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
   } else {
     dispatch_fp8_rowwise_kernel_on_input_dtypes<
-        float>
-        //Types...>
-        (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
+        float,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
   }
 }
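
The RowwiseScaledMM.cu hunks restore the variadic Types... parameter on the bias-dtype dispatcher so that extra template arguments are forwarded to the input-dtype dispatcher rather than dropped (the reverted change had commented out the //Types...> forwarding in the float branch). A minimal, self-contained sketch of this chained-dispatch pattern is shown below; the dispatch_on_bias_dtype, dispatch_on_input_dtypes, and kernel_impl names are hypothetical stand-ins, not the PyTorch functions.

// Minimal sketch (hypothetical names) of forwarding a template parameter pack
// through two layers of dispatch, as the restored `template <typename... Types>` allows.
#include <iostream>

template <typename BiasT, typename... Types>
void kernel_impl() {
  // A real kernel would be instantiated on <BiasT, Types...>; here we only
  // report how many extra types were forwarded.
  std::cout << "bias type + " << sizeof...(Types) << " forwarded type(s)\n";
}

template <typename BiasT, typename... Types>
void dispatch_on_input_dtypes() {
  kernel_impl<BiasT, Types...>();
}

// Outer dispatcher: picks the bias dtype at runtime and forwards the rest of
// the parameter pack unchanged.
template <typename... Types>
void dispatch_on_bias_dtype(bool bias_is_bf16) {
  if (bias_is_bf16) {
    dispatch_on_input_dtypes<unsigned short /* stand-in for cutlass::bfloat16_t */, Types...>();
  } else {
    dispatch_on_input_dtypes<float, Types...>();
  }
}

int main() {
  dispatch_on_bias_dtype<double, int>(/*bias_is_bf16=*/false);  // forwards two extra types
  return 0;
}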
