basic compile support for grouped_mm · pytorch/pytorch@89a051a · GitHub
Commit 89a051a

basic compile support for grouped_mm

ghstack-source-id: 3c696c5
Pull Request resolved: #153384

1 parent daca611 · commit 89a051a
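Editorial note (not part of the commit): the change lets a function that calls the private torch._grouped_mm binding be compiled end to end with torch.compile. A minimal sketch of the resulting usage, assuming an SM90 (Hopper) GPU and reusing the shapes and call pattern from the test added below; the helper name `grouped` is invented for illustration:

    # Sketch only, not from the commit: assumes an SM90 GPU and the private
    # torch._grouped_mm signature exercised by the test below.
    import torch

    @torch.compile(fullgraph=True)
    def grouped(a, b, offs):
        # b is (n_groups, n, k); transpose the last two dims so each group
        # multiplies a slice of `a` (rows, k) by a (k, n) matrix.
        return torch._grouped_mm(
            a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16
        )

    a = torch.randn(64, 16, device="cuda", dtype=torch.bfloat16)       # (m * n_groups, k)
    b = torch.randn(4, 32, 16, device="cuda", dtype=torch.bfloat16)    # (n_groups, n, k)
    offs = torch.arange(16, 65, 16, device="cuda", dtype=torch.int32)  # cumulative group ends
    out = grouped(a, b, offs)                                          # (64, 32)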

2 files changed: +46 −0 lines changed

test/inductor/test_torchinductor.py

Lines changed: 45 additions & 0 deletions
@@ -72,6 +72,7 @@
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
+    SM90OrLater,
     TEST_CUDNN,
     tf32_on_and_off,
     with_tf32_off,
@@ -13741,6 +13742,50 @@ def forward(
         )
         torch._inductor.aot_compile(traced, inputs)

+    @skipCUDAIf(not SM90OrLater, "Requires sm90")
+    @requires_gpu()
+    @config.patch(implicit_fallbacks=True)
+    def test_grouped_mm(self):
+        @torch.compile(fullgraph=True)
+        def f(a, b, offs, out_dtype):
+            return torch._grouped_mm(
+                a, b.transpose(-2, -1), offs=offs, out_dtype=out_dtype
+            )
+
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        m, n, k, n_groups = 16, 32, 16, 4
+        a_ref = torch.randn(m * n_groups, k, device=device, dtype=dtype)[:, :k]
+
+        b_ref = torch.randn(
+            n_groups,
+            n,
+            k,
+            device=device,
+            dtype=dtype,
+        )[::1, :, :k]
+
+        offs = torch.arange(
+            k, n_groups * k + 1, k, device=device, dtype=torch.int32
+        )
+
+        a_ref.requires_grad_(True)
+        b_ref.requires_grad_(True)
+
+        a_test = a_ref.clone().detach().requires_grad_()
+        b_test = b_ref.clone().detach().requires_grad_()
+
+        out_ref = f(a_ref, b_ref, offs, out_dtype=torch.bfloat16)
+        out_ref.sum().backward()
+
+        out_test = f(a_test, b_test, offs=offs, out_dtype=torch.bfloat16)
+        out_test.sum().backward()
+
+        self.assertEqual(out_ref, out_test)
+        self.assertEqual(a_ref.grad, a_test.grad)
+        self.assertEqual(b_ref.grad, b_test.grad)
+
     def test_optimize_indexing_assert(self):
         def has_indirect(code, tl_fn: str):
             self.assertTrue(
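Notes on the test (editorial, not part of the diff): the SM90 skip reflects that the grouped-GEMM kernel targets Hopper-class GPUs, and config.patch(implicit_fallbacks=True) lets Inductor fall back to the ATen kernel for an op it has no dedicated lowering for. The offs tensor appears to hold cumulative end indices of each group along a's first dimension (every 16 rows here; m == k == 16 in this test, so either reading gives the same numbers). A rough eager-mode reading of what f computes under these shapes, as an illustration only and not a reference implementation from the commit:

    # Illustrative eager-mode equivalent (assumption: offs holds cumulative
    # end indices of each group along a's first dimension).
    import torch

    def grouped_mm_loop(a, b_t, offs):
        outs, start = [], 0
        for i, end in enumerate(offs.tolist()):
            # group i: rows a[start:end] (m, k) times b_t[i] (k, n) -> (m, n)
            outs.append(a[start:end] @ b_t[i])
            start = end
        return torch.cat(outs, dim=0)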

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
@@ -1514,6 +1514,7 @@
         "torch._fused_sdp_choice",
         "torch._fw_primal_copy",
         "torch._grid_sampler_2d_cpu_fallback",
+        "torch._grouped_mm",
         "torch._has_compatible_shallow_copy_type",
         "torch._histogramdd_bin_edges",
         "torch._histogramdd_from_bin_cts",
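Editorial note: adding "torch._grouped_mm" to this allowlist of torch.* C bindings tells Dynamo to keep the call in the captured FX graph rather than graph-breaking on an unrecognized binding. One way to confirm that behavior, a sketch only (assumes a CUDA device that supports the op):

    # Sketch only: checks that the op traces without a graph break.
    import torch
    import torch._dynamo

    def f(a, b, offs):
        return torch._grouped_mm(
            a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16
        )

    a = torch.randn(64, 16, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 32, 16, device="cuda", dtype=torch.bfloat16)
    offs = torch.arange(16, 65, 16, device="cuda", dtype=torch.int32)

    report = torch._dynamo.explain(f)(a, b, offs)
    print(report.graph_break_count)  # expected: 0 with this trace_rules entry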

0 commit comments