@@ -1440,16 +1440,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
         scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
-        out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.t(), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
         alist, blist, ascalelist, bscalelist = [], [], [], []
         start = 0
@@ -1466,7 +1469,8 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1478,8 +1482,10 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
         scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
         offs_cpu = offs.cpu()
         alist, ascalelist, outlist = [], [], []
@@ -1496,7 +1502,8 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1507,8 +1514,10 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
         scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
         self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum)
 
@@ -1517,7 +1526,8 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
-    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided):
+    @parametrize("use_torch_compile", [False, True])
+    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 16, 4
@@ -1529,8 +1539,10 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided):
         scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
         offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
 
-        out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                                       out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = torch._scaled_grouped_mm
+        f = torch.compile(f) if use_torch_compile else f
+        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
+                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
         offs_cpu = offs.cpu()
         blist, bscalelist, outlist = [], [], []
         start = 0
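All four tests now exercise the same dispatch: build the operands, then route the call either through eager mode or through torch.compile depending on the new use_torch_compile parameter. Below is a minimal standalone sketch of that wrapping, not part of the diff; the helper name and the example shapes are illustrative, and running it assumes an SM90+ GPU build where torch._scaled_grouped_mm and float8 inputs are supported.

import torch

def run_grouped_mm(a, b, scale_a, scale_b, offs, fast_accum, use_torch_compile):
    # Same pattern as in the tests: torch.compile takes a callable and
    # returns a compiled callable, so the call site is identical for the
    # eager and compiled paths.
    f = torch._scaled_grouped_mm
    f = torch.compile(f) if use_torch_compile else f
    return f(a, b.t(), scale_a, scale_b, offs=offs,
             out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

# Usage mirrors the 2d_2d case above: m, n, k, n_groups = 16, 16, 16, 4
# (all sizes divisible by 16), float8_e4m3fn operands, float32 scales,
# and offs marking each group's boundary along the shared k dimension.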