10000 [Cutlass] Implement EVT example tensor creation · pytorch/pytorch@7df4876 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7df4876

Browse files
committed
[Cutlass] Implement EVT example tensor creation
Updates to example tensor creation. ghstack-source-id: c1d2768. Pull Request resolved: #150904
1 parent 75c71ab commit 7df4876

File tree

3 files changed

+119
-19
lines changed

3 files changed

+119
-19
lines changed

test/inductor/test_cutlass_evt.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
import torch
55
from torch._dynamo.test_case import TestCase
6-
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
6+
from torch._inductor.codegen.cuda.cutlass_utils import (
7+
torch_dtype_to_cutlass_type,
8+
try_import_cutlass,
9+
)
710
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
811

912

@@ -55,33 +58,64 @@
5558
class MockTileDescription:
5659
threadblock_shape = (128, 128, 8)
5760

58-
def _create_mock_buffer_name_map(example_tensors):
59-
class MockNode:
60-
def __init__(self, name, stride, dtype):
61-
self.name = name
62-
self.dtype = dtype
63-
self.stride = stride
61+
class MockNode:
62+
def __init__(self, name, shape, stride, dtype):
63+
self.name = name
64+
self.dtype = dtype
65+
self.shape = shape
66+
self.stride = stride
6467

65-
def get_layout(self):
66-
class MockLayout:
67-
def __init__(self, stride, dtype):
68-
self.dtype = dtype
69-
self.stride = stride
68+
def get_layout(self):
69+
class MockLayout:
70+
def __init__(self, shape, stride, dtype):
71+
self.size = shape
72+
self.stride = stride
73+
self.dtype = dtype
7074

71-
return MockLayout(self.stride, self.dtype)
75+
return MockLayout(self.shape, self.stride, self.dtype)
7276

73-
def get_name(self):
74-
return self.name
77+
def get_name(self):
78+
return self.name
7579

80+
def _create_mock_buffer_name_map(example_tensors):
7681
name_to_buffer = {}
7782
for name, tensor in example_tensors.items():
7883
if isinstance(tensor, CutlassTensor):
79-
name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
84+
name_to_buffer[name] = MockNode(
85+
name, tensor.shape, tensor.stride, torch.float32
86+
)
8087

8188
return name_to_buffer
8289

8390

8491
class TestCutlassEVT(TestCase):
92+
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_example_tensor_creation(self):
    """Check shape, stride and element type of tensors built from mock buffers.

    ``buf0`` is read and renamed to ``acc``; ``buf1`` is written and keeps
    its own name in the resulting mapping.
    """
    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
        create_example_tensors,
    )

    buffers = {
        "buf0": MockNode("buf0", (3, 4, 1), (4, 1, 0), torch.float32),
        "buf1": MockNode("buf1", (3, 2, 1), (1, 3, 0), torch.float32),
    }
    result = create_example_tensors(["buf0"], ["buf1"], {"buf0": "acc"}, buffers)

    # (key in result, expected shape, expected stride)
    expectations = (
        ("acc", (3, 4, 1), (4, 1, 0)),
        ("buf1", (3, 2, 1), (1, 3, 0)),
    )
    for key, expected_shape, expected_stride in expectations:
        tensor = result[key]
        self.assertEqual(tensor.shape, expected_shape)
        self.assertEqual(tensor.stride, expected_stride)
        self.assertEqual(tensor.element, torch_dtype_to_cutlass_type(torch.float32))
118+
85119
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
86120
def test_evt_argument_codegen(self):
87121
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)

torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py

Whitespace-only changes.

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from typing import Any, Union
22

3-
from torch._inductor.ir import ComputedBuffer, InputBuffer
3+
from torch._inductor.ir import (
4+
ComputedBuffer,
5+
InputBuffer,
6+
is_contiguous_strides_for_shape,
7+
)
48
from torch.utils._ordered_set import OrderedSet
59

6-
from ..cutlass_utils import try_import_cutlass
10+
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
711

812

913
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
@@ -19,6 +23,7 @@
1923
import ast
2024
import ctypes
2125
import textwrap
26+
from typing import Union
2227

2328
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
2429
EmptyByte,
@@ -41,13 +46,74 @@
4146
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
4247
Tensor as CutlassTensor,
4348
)
44-
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
49+
from cutlass_library import (
50+
DataType,
51+
EpilogueScheduleType,
52+
LayoutType,
53+
TileDescription,
54+
)
4555

56+
import torch
4657
from torch._inductor.codegen.cuda import cuda_env
4758
from torch._inductor.utils import IndentedBuffer
4859

4960
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
5061

62+
TORCH_TO_CUTLASS_DTYPE = {
63+
torch.float32: DataType.f32,
64+
torch.float16: DataType.f16,
65+
torch.bfloat16: DataType.bf16,
66+
}
67+
68+
def create_example_tensors(
    read_names: list[str],
    write_names: list[str],
    buffer_renames: dict[str, str],
    name_to_buffer: dict[str, Buffer],
) -> dict[str, CutlassTensor]:
    """Create one example ``CutlassTensor`` per buffer that is read or written.

    Args:
        read_names: Names of the buffers that are read.
        write_names: Names of the buffers that are written.
        buffer_renames: Maps a buffer name to the key it must appear under in
            the result (e.g. ``acc`` is a required arg name).
        name_to_buffer: Maps every name in ``read_names`` and ``write_names``
            to its buffer; the buffer's layout supplies size, stride and dtype.

    Returns:
        A dict from (possibly renamed) buffer name to its example tensor.
        Names are processed reads first, then writes; if two names map to the
        same key the later one wins.

    Raises:
        AssertionError: If a buffer's shape or stride has a non-``int`` entry.
        RuntimeError: If a buffer's layout is neither row-major nor
            column-major contiguous.
        KeyError: If a name is missing from ``name_to_buffer``.
    """

    def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
        # Fetch the layout once; size, stride and dtype are all read from it.
        layout = buffer.get_layout()
        shape = layout.size
        stride = layout.stride
        # Validate shape and stride separately so each message matches the
        # sequence that was actually checked.
        assert all(isinstance(x, int) for x in shape), (
            f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
        )
        assert all(isinstance(x, int) for x in stride), (
            f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
        )
        shape = tuple(int(x) for x in shape)
        stride = tuple(int(x) for x in stride)

        # Row-major: strides are contiguous for the shape as given.
        # Column-major: the reversed strides are contiguous for the reversed shape.
        is_row_major = is_contiguous_strides_for_shape(stride, shape)
        is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])

        if not is_row_major and not is_column_major:
            raise RuntimeError(
                f"Cannot create example tensor for {buffer.get_name()} with "
                f"non-contiguous layout, received stride: {stride} and shape: {shape}"
            )

        # Row-major takes precedence when a layout satisfies both checks.
        return CutlassTensor(
            shape=shape,
            layout_tag=(
                LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor
            ),
            element=torch_dtype_to_cutlass_type(layout.dtype),
        )

    example_tensors = {}
    for name in read_names + write_names:
        # Need to rewrite some special args (e.g. acc is a required arg name)
        key = buffer_renames.get(name, name)
        example_tensors[key] = cutlass_tensor_from_buffer(name_to_buffer[name])

    return example_tensors
116+
51117
def trace(
52118
fn_src: str,
53119
example_tensors: dict[str, CutlassTensor],

0 commit comments

Comments
 (0)
0