@@ -1587,14 +1587,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
    @parametrize("use_torch_compile", [False, True])
    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
        device = "cuda"
-        m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
        a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
        b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
        scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
        scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
        offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
        out = f(a, b.t(), scale_a, scale_b, offs=offs,
                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
        offs_cpu = offs.cpu()
@@ -1618,7 +1623,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile)
    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
        device = "cuda"
        s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
        a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
        b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
        self.assertTrue(a.is_contiguous() is not strided)
@@ -1631,7 +1636,12 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

        f = torch._scaled_grouped_mm
-        f = torch.compile(f, dynamic=False) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1643,7 +1653,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
            ascalelist.append(scale_a[start:offs_cpu[i]])
            outlist.append(out[start:offs_cpu[i]])
            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)


    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1655,7 +1665,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
        device = "cuda"
        s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
        a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
        b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
        self.assertTrue(a.is_contiguous() is not strided)
@@ -1664,7 +1674,12 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)

        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
        out = f(a, b.transpose(-2, -1), scale_a, scale_b,
                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)

@@ -1680,7 +1695,7 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
        device = "cuda"
        s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 128, 64, 4
        a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
        b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
        self.assertTrue(a.is_contiguous() is not strided)
@@ -1693,7 +1708,12 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
        offs[0] = offs[1]

        f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
        offs_cpu = offs.cpu()
@@ -1704,7 +1724,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
            bscalelist.append(scale_b[start:offs_cpu[i]])
            outlist.append(out[:, start:offs_cpu[i]])
            start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)


    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
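
For context on the pattern this diff introduces: when `use_torch_compile` is true, the tests now compile the private `torch._scaled_grouped_mm` op with Inductor's max-autotune enabled and the GEMM backend list restricted to Triton, so the grouped-GEMM Triton templates are what actually gets exercised. Below is a minimal standalone sketch of that call pattern, mirroring the 2d-2d test above; it is not part of the diff and assumes a recent PyTorch build exposing `torch._scaled_grouped_mm` plus a CUDA device with `float8_e4m3fn` support (e.g. H100).

```python
import torch

# Sketch only: mirrors the updated test_scaled_grouped_gemm_2d_2d path above.
device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16

# FP8 operands; groups are laid out along K and delimited by `offs`.
a = torch.randn(m, k * n_groups, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device=device).to(torch.float8_e4m3fn)
scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

# Compile the grouped-GEMM op with max-autotune limited to Triton templates,
# i.e. the same Inductor options the updated tests pass.
f = torch.compile(
    torch._scaled_grouped_mm,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    })

out = f(a, b.t(), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=True)
```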