@@ -1,7 +1,6 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
- #include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -17,7 +16,6 @@
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
- #include <ATen/native/cuda/ScaledGroupMM.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -1365,84 +1363,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
  return out;
}

- namespace {
- c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
-     const Tensor& mat_b,
-     const std::optional<at::Tensor>& offs
- ) {
-   const bool a_is_2d = mat_a.dim() == 2;
-   const bool b_is_2d = mat_b.dim() == 2;
-   if (a_is_2d) {
-     if (b_is_2d) {
-       return {offs->size(0), mat_a.size(0), mat_b.size(1)};
-     } else {
-       TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
-       return {mat_a.size(0), mat_b.size(-1)};
-     }
-   } else {
-     if (b_is_2d) {
-       // this case is not actually encountered for MoE gemms
-       TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
-       return {mat_a.size(1), mat_b.size(1)};
-     } else { // regular bmm
-       TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
-       return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
-     }
-   }
- }
-
- bool transposed(const Tensor& mat) {
-   IntArrayRef tensor_strides = mat.strides();
-   IntArrayRef tensor_sizes = mat.sizes();
-   int end_dim = mat.dim() - 1;
-   if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max<int64_t>(1, tensor_sizes[end_dim - 1]))) {
-     return true;
-   } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max<int64_t>(1, tensor_sizes[end_dim]))) {
-     return false;
-   } else {
-     TORCH_CHECK(false, "Tensor should not be self-overlapping");
-   }
- }
-
- void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
-   if (mat.dim() == 2) {
-     TORCH_CHECK(
-         scale.dim() == 1,
-         "scale must be a 1D tensor, but got ",
-         scale.dim(),
-         "D, arg ",
-         arg_idx);
-     TORCH_CHECK(
-         scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
-     TORCH_CHECK(
-         scale.size(0) == mat.size(dim) * scale_multiplier,
-         "scale must have the same length as mat for arg ",
-         arg_idx);
-   } else {
-     TORCH_CHECK(
-         scale.dim() == 2,
-         "scale must be a 2D tensor, but got ",
-         scale.dim(),
-         "D for arg ",
-         arg_idx);
-     TORCH_CHECK(
-         scale.stride(1),
-         "scale_a must be contiguous in the last dimension for arg ",
-         arg_idx);
-     TORCH_CHECK(
-         scale.size(0) == mat.size(0),
-         "scale must have the same batch dimension as mat for arg ",
-         arg_idx);
-     TORCH_CHECK(
-         scale.size(1) == mat.size(1 + dim),
-         "scale must have the same first dimension as mat for arg ",
-         arg_idx);
-   }
- }
-
-
- }
-
Tensor
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
          const Tensor& scale_a,
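
The hunk above removes the anonymous-namespace helpers behind the grouped GEMM path. As a reading aid, here is a minimal, self-contained sketch of the output-shape rules that compute_grouped_gemm_output_size encoded; the standalone form is hypothetical (std::vector<int64_t> and a plain `groups` count stand in for c10::SmallVector and offs->size(0)):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// `groups` plays the role of offs->size(0) in the removed helper.
Shape grouped_gemm_output_shape(const Shape& a, const Shape& b, int64_t groups) {
  const bool a_is_2d = a.size() == 2;
  const bool b_is_2d = b.size() == 2;
  if (a_is_2d && b_is_2d) {
    // Two 2d operands: one GEMM per group -> [groups, M, N].
    return {groups, a[0], b[1]};
  }
  if (a_is_2d) {
    // 2d x 3d: rows of `a` are split into groups, and each group multiplies
    // one matrix of the batched `b` -> [M_total, N].
    if (groups != b[0]) throw std::invalid_argument("matrix batch sizes have to match");
    return {a[0], b.back()};
  }
  if (b_is_2d) {
    // 3d x 2d -> [M, N]; the removed code notes this case is unused for MoE gemms.
    if (groups != a[0]) throw std::invalid_argument("matrix batch sizes have to match");
    return {a[1], b[1]};
  }
  // 3d x 3d: a regular batched matmul -> [B, M, N].
  if (a[0] != b[0]) throw std::invalid_argument("batched dimension has to match");
  return {a[0], a[1], b.back()};
}

int main() {
  // Example: a is [256, 128] with 4 row groups, b is a [4, 128, 64] batch.
  const Shape out = grouped_gemm_output_shape({256, 128}, {4, 128, 64}, /*groups=*/4);
  for (int64_t d : out) std::cout << d << ' ';  // prints: 256 64
  std::cout << '\n';
}
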
@@ -1456,82 +1376,4 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
  return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}

-
- Tensor
- _scaled_grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
-           const Tensor& scale_a, const Tensor& scale_b,
-           const std::optional<at::Tensor>& offs,
-           const std::optional<at::Tensor>& bias,
-           const std::optional<at::Tensor>& scale_result,
-           std::optional<c10::ScalarType> out_dtype,
-           bool use_fast_accum) {
- #ifndef USE_ROCM
-   bool allowed_device = _scaled_mm_allowed_device();
-   TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
-
-   TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
-   TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
-   TORCH_CHECK(!transposed(mat_a), "Expected mat1 to not be transposed");
-   TORCH_CHECK(transposed(mat_b), "Expected mat2 to be transposed");
-   TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
-   TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
-   const bool a_is_2d = mat_a.dim() == 2;
-   const bool b_is_2d = mat_b.dim() == 2;
-   TORCH_CHECK(
-       mat_a.size(-1) % 16 == 0,
-       "Expected trailing dimension of mat_a to be divisible by 16 ",
-       "but got mat1 shape: (",
-       mat_a.sizes(),
-       ").");
-   TORCH_CHECK(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
-       "Expected mat_b shape to be divisible by 16 ",
-       "but got mat_b shape: (",
-       mat_b.sizes(),
-       ").");
-
-
-
-   TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
-
-   if (offs.has_value()) {
-     TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
-     TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
-   }
-
-   // Both Per-Tensor and Row-wise scaling expect fp32 tensors
-   TORCH_CHECK(
-       scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
-       "Both scale_a and scale_b must be float (fp32) tensors.");
-
-   const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
-   check_scale(mat_a, scale_a, 0, 0, scale_multiplier);
-   check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
-
-   const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-   TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
-   const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
-   Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
-
-
-   at::cuda::detail::f8f8bf16_grouped_mm(
-       mat_a,
-       mat_b,
-       scale_a,
-       scale_b,
-       offs,
-       bias,
-       use_fast_accum,
-       out);
-   return out;
-
-
-
-
- #else
-   TORCH_CHECK(false, "grouped gemm is not supported on ROCM");
- #endif
-
- }
-
-
} // namespace at::native
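
The removed _scaled_grouped_mm_cuda entry point relied on the transposed() helper to require a row-major mat_a and a column-major mat_b. A minimal, self-contained sketch of that stride test (hypothetical standalone form, taking plain size/stride vectors instead of an at::Tensor):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Reports whether the last two dimensions are laid out column-major
// ("transposed") rather than row-major, mirroring the removed helper's logic.
bool last_two_dims_transposed(const std::vector<int64_t>& sizes,
                              const std::vector<int64_t>& strides) {
  const size_t end = sizes.size() - 1;
  // Column-major: unit stride along the second-to-last dim, and the last dim
  // strides over at least a full column.
  if (strides[end - 1] == 1 &&
      strides[end] >= std::max<int64_t>(1, sizes[end - 1])) {
    return true;
  }
  // Row-major: unit stride along the last dim, and the second-to-last dim
  // strides over at least a full row.
  if (strides[end] == 1 &&
      strides[end - 1] >= std::max<int64_t>(1, sizes[end])) {
    return false;
  }
  throw std::invalid_argument("Tensor should not be self-overlapping");
}

int main() {
  // A contiguous [4, 8] matrix has strides {8, 1}: row-major, not transposed.
  std::cout << last_two_dims_transposed({4, 8}, {8, 1}) << '\n';  // prints 0
  // The same storage viewed as its [8, 4] transpose has strides {1, 8}.
  std::cout << last_two_dims_transposed({8, 4}, {1, 8}) << '\n';  // prints 1
}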