Commit 604837c · pytorch/pytorch
Update auto-tuning support for _scaled_grouped_mm
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix the non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases where the group size along the K dimension is not a multiple of the block size along K
6. Implement meta registration

ghstack-source-id: 5bbfaae
Pull Request resolved: #150944
1 parent a3123dd commit 604837c
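
As context for the shape combinations in point 2, here is a minimal sketch of a grouped scaled-MM call in which the first operand is 2d and the second is 3d, so offsets are mandatory. The fp8 dtypes, the column-major layout of `b`, and the scale shapes are assumptions chosen for illustration; only the "offsets are required when an operand is 2d" rule is taken from the Blas.cpp check shown below.

import torch

# Illustrative sketch, not taken from the commit. Assumes a CUDA device with
# fp8 support; shapes, layouts and scale conventions are assumptions.
G, M_total, K, N = 4, 256, 64, 128
a = torch.randn(M_total, K, device="cuda").to(torch.float8_e4m3fn)                  # 2d operand
b = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)   # 3d operand, column-major in the last two dims (assumption)
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)          # cumulative group ends along M
scale_a = torch.ones(M_total, device="cuda", dtype=torch.float32)                   # per-row scale for the 2d operand (assumption)
scale_b = torch.ones(G, N, device="cuda", dtype=torch.float32)                      # per-group scale for the 3d operand (assumption)
out = torch._scaled_grouped_mm(a, b, scale_a, scale_b, offs=offs)
print(out.shape)  # expected (M_total, N) under these assumptions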

File tree: 7 files changed (+469, -178 lines)


aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 4 additions & 3 deletions
@@ -1533,7 +1533,7 @@ namespace {
       "D, arg ",
       arg_idx);
   TORCH_CHECK(
-      scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
+      scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
   TORCH_CHECK(
       scale.size(0) == mat.size(dim) * scale_multiplier,
       "scale must have the same length as mat for arg ",
@@ -1546,8 +1546,8 @@ namespace {
       "D for arg ",
       arg_idx);
   TORCH_CHECK(
-      scale.stride(1),
-      "scale_a must be contiguous in the last dimension for arg ",
+      scale.stride(1) == 1,
+      "scale must be contiguous in the last dimension for arg ",
       arg_idx);
   TORCH_CHECK(
       scale.size(0) == mat.size(0),
@@ -1611,6 +1611,7 @@ bool use_fast_accum) {
 
 
   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
   TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
 
   if (offs.has_value()) {
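
The second hunk fixes a check that was previously vacuous: `scale.stride(1)` is an integer, so the old `TORCH_CHECK` only failed when that stride happened to be zero. A small sketch (the tensor shape here is mine, for illustration) of a scale tensor that the old truthiness check would have accepted but the corrected `== 1` comparison rejects:

import torch

# A (8, 16) scale obtained by slicing every other column of a (8, 32) tensor:
# its last-dimension stride is 2, which is truthy but not contiguous.
scale = torch.randn(8, 32, dtype=torch.float32)[:, ::2]
print(scale.stride())          # (32, 2)
print(bool(scale.stride(1)))   # True  -> old check would pass
print(scale.stride(1) == 1)    # False -> new check fails, as intended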

torch/_inductor/graph.py

Lines changed: 0 additions & 1 deletion
@@ -204,7 +204,6 @@ def mark_nodes_dislike_padding(
             aten.convolution,
             aten.convolution_backward,
             aten._scaled_mm,
-            aten._scaled_grouped_mm,
         ]
     )
     # what's a better way to collect the reduction ops?
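
Dropping aten._scaled_grouped_mm from the dislike-padding set lets Inductor pad this op's inputs again, which is only reasonable because the updated templates accept strided operands (point 1 of the commit message). A rough sketch of the kind of non-contiguous view that padding produces; the shapes, device and dtype are assumptions for illustration:

import torch

# Assumes a CUDA device with fp8 support. A row-major view into a larger,
# padded buffer has stride(0) > size(1), i.e. it is strided but not contiguous.
M, K, K_padded = 64, 48, 64
buf = torch.randn(M, K_padded, device="cuda").to(torch.float8_e4m3fn)
a = buf[:, :K]                 # shape (64, 48), stride (64, 1)
assert not a.is_contiguous()
# With this commit, the grouped-MM templates can consume such views directly
# instead of requiring contiguous inputs.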

torch/_inductor/kernel/mm_common.py

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
 
 
 @SymbolicGridFn
-def persistent_grouped_mm_grid(m, n, meta):
+def persistent_grouped_mm_grid(*args):
+    meta = args[-1]
     return (meta["NUM_SMS"], 1, 1)
 
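Taking *args and reading the meta dict from the last position keeps the grid function agnostic to how many size arguments the different 2d/3d template variants pass ahead of it; for a persistent kernel the grid is simply one program per SM. A standalone sketch (the leading arguments and the NUM_SMS value are assumptions, not Inductor internals):

# Illustrative only: the same grid function works whatever sizes precede the
# launch meta, because only meta["NUM_SMS"] determines the persistent grid.
def persistent_grouped_mm_grid(*args):
    meta = args[-1]
    return (meta["NUM_SMS"], 1, 1)

meta = {"NUM_SMS": 132}                                   # e.g. an H100; value is an assumption
print(persistent_grouped_mm_grid(1024, 2048, meta))       # 2d/2d-style call -> (132, 1, 1)
print(persistent_grouped_mm_grid(8, 1024, 2048, meta))    # 3d/2d-style call -> (132, 1, 1)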
