[Cutlass] Implement EVT example tensor creation (#150904) · pytorch/pytorch@a936d59 · GitHub

Commit a936d59

mlazos authored and pytorchmergebot committed
[Cutlass] Implement EVT example tensor creation (#150904)
This PR implements a translation layer from inductor IR to "example tensors", the expected arguments of the EVT tracer. These tensors store the name, shape, stride, and dtype of the underlying buffer and allow an AST-based Python parser to generate the EVT C++.

Previously merged:
* #150903
* #150346
* #150345
* #150344

Pull Request resolved: #150904
Approved by: https://github.com/eellison
1 parent dda0c95 commit a936d59
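
At a high level, each inductor buffer's layout (shape, stride, dtype) is summarized as a cutlass example Tensor that the EVT tracer consumes. The snippet below is a rough, hedged sketch of that mapping only; it assumes the cutlass Python package is importable and simply mirrors the CutlassTensor(shape=..., layout_tag=..., element=...) construction added in this diff.

# Hedged sketch (not part of the commit): what one "example tensor" looks like.
# Assumes the cutlass Python package is installed.
from cutlass.backend.evt.ir.tensor import Tensor as CutlassTensor
from cutlass_library import DataType, LayoutType

# An inductor buffer of shape (3, 4, 1) with row-major strides (4, 1, 0) and
# dtype float32 is captured symbolically: shape + layout tag + element type, no data.
example = CutlassTensor(
    shape=(3, 4, 1),
    layout_tag=LayoutType.RowMajor,  # chosen by inspecting the strides
    element=DataType.f32,            # torch.float32 maps to DataType.f32
)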

File tree

3 files changed: +119 -19 lines changed


test/inductor/test_cutlass_evt.py

Lines changed: 50 additions & 16 deletions
@@ -3,7 +3,10 @@
 
 import torch
 from torch._dynamo.test_case import TestCase
-from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
+from torch._inductor.codegen.cuda.cutlass_utils import (
+    torch_dtype_to_cutlass_type,
+    try_import_cutlass,
+)
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 
@@ -55,33 +58,64 @@
 class MockTileDescription:
     threadblock_shape = (128, 128, 8)
 
-def _create_mock_buffer_name_map(example_tensors):
-    class MockNode:
-        def __init__(self, name, stride, dtype):
-            self.name = name
-            self.dtype = dtype
-            self.stride = stride
+class MockNode:
+    def __init__(self, name, shape, stride, dtype):
+        self.name = name
+        self.dtype = dtype
+        self.shape = shape
+        self.stride = stride
 
-        def get_layout(self):
-            class MockLayout:
-                def __init__(self, stride, dtype):
-                    self.dtype = dtype
-                    self.stride = stride
+    def get_layout(self):
+        class MockLayout:
+            def __init__(self, shape, stride, dtype):
+                self.size = shape
+                self.stride = stride
+                self.dtype = dtype
 
-            return MockLayout(self.stride, self.dtype)
+        return MockLayout(self.shape, self.stride, self.dtype)
 
-        def get_name(self):
-            return self.name
+    def get_name(self):
+        return self.name
 
+def _create_mock_buffer_name_map(example_tensors):
     name_to_buffer = {}
     for name, tensor in example_tensors.items():
         if isinstance(tensor, CutlassTensor):
-            name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
+            name_to_buffer[name] = MockNode(
+                name, tensor.shape, tensor.stride, torch.float32
+            )
 
     return name_to_buffer
 
 
 class TestCutlassEVT(TestCase):
+    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
+    def test_example_tensor_creation(self):
+        from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
+            create_example_tensors,
+        )
+
+        row_major_buf0 = MockNode("buf0", (3, 4, 1), (4, 1, 0), torch.float32)
+        col_major_buf1 = MockNode("buf1", (3, 2, 1), (1, 3, 0), torch.float32)
+        read_names = ["buf0"]
+        write_names = ["buf1"]
+        buffer_renames = {"buf0": "acc"}
+        name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
+        result = create_example_tensors(
+            read_names, write_names, buffer_renames, name_to_buffer
+        )
+        self.assertEqual(result["acc"].shape, (3, 4, 1))
+        self.assertEqual(result["acc"].stride, (4, 1, 0))
+        self.assertEqual(
+            result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
+        )
+
+        self.assertEqual(result["buf1"].shape, (3, 2, 1))
+        self.assertEqual(result["buf1"].stride, (1, 3, 0))
+        self.assertEqual(
+            result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
+        )
+
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
     def test_evt_argument_codegen(self):
         epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
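
The stride tuples in test_example_tensor_creation are chosen so that buf0 is classified as row-major and buf1 as column-major. Below is a minimal standalone sketch of that classification; it only approximates torch._inductor.ir.is_contiguous_strides_for_shape (size-1 dimensions are skipped, since their stride never affects addressing).

# Hedged sketch (not the inductor implementation): the contiguity check used to
# pick RowMajor vs ColumnMajor in create_example_tensors.
def _is_contiguous(stride, shape):
    expected = 1
    for dim, s in zip(reversed(shape), reversed(stride)):
        if dim == 1:
            continue  # a size-1 dim's stride is irrelevant
        if s != expected:
            return False
        expected *= dim
    return True

# buf0: shape (3, 4, 1), stride (4, 1, 0) -> contiguous as-is, i.e. RowMajor
print(_is_contiguous((4, 1, 0), (3, 4, 1)))              # True
# buf1: shape (3, 2, 1), stride (1, 3, 0) -> contiguous when reversed, i.e. ColumnMajor
print(_is_contiguous((1, 3, 0)[::-1], (3, 2, 1)[::-1]))  # True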

torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py

Whitespace-only changes.

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 69 additions & 3 deletions
@@ -1,9 +1,13 @@
 from typing import Any, Union
 
-from torch._inductor.ir import ComputedBuffer, InputBuffer
+from torch._inductor.ir import (
+    ComputedBuffer,
+    InputBuffer,
+    is_contiguous_strides_for_shape,
+)
 from torch.utils._ordered_set import OrderedSet
 
-from ..cutlass_utils import try_import_cutlass
+from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
 
 
 EpilogueFunctor = Any  # EpilogueFunctor local class defined in _trace
@@ -19,6 +23,7 @@
     import ast
     import ctypes
     import textwrap
+    from typing import Union
 
     from cutlass.backend.c_types import (  # type: ignore[import-untyped, import-not-found]
         EmptyByte,
@@ -41,13 +46,74 @@
     from cutlass.backend.evt.ir.tensor import (  # type: ignore[import-untyped, import-not-found]
         Tensor as CutlassTensor,
     )
-    from cutlass_library import DataType, EpilogueScheduleType, TileDescription
+    from cutlass_library import (
+        DataType,
+        EpilogueScheduleType,
+        LayoutType,
+        TileDescription,
+    )
 
+    import torch
     from torch._inductor.codegen.cuda import cuda_env
     from torch._inductor.utils import IndentedBuffer
 
     _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values())  # type: ignore[var-annotated]
 
+    TORCH_TO_CUTLASS_DTYPE = {
+        torch.float32: DataType.f32,
+        torch.float16: DataType.f16,
+        torch.bfloat16: DataType.bf16,
+    }
+
+    def create_example_tensors(
+        read_names: list[str],
+        write_names: list[str],
+        buffer_renames: dict[str, str],
+        name_to_buffer: dict[str, Buffer],
+    ) -> dict[str, CutlassTensor]:
+        example_tensors = {}
+
+        def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
+            shape = buffer.get_layout().size
+            stride = buffer.get_layout().stride
+            assert all(isinstance(x, int) for x in shape), (
+                f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
+            )
+            assert all(isinstance(x, int) for x in stride), (
+                f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
+            )
+            shape = tuple(int(x) for x in shape)
+            stride = tuple(int(x) for x in stride)
+
+            is_row_major = is_contiguous_strides_for_shape(stride, shape)
+            is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
+
+            if not is_row_major and not is_column_major:
+                raise RuntimeError(
+                    f"Cannot create example tensor for {buffer.get_name()} with "
+                    f"non-contiguous layout, received stride: {stride} and shape: {shape}"
+                )
+
+            return CutlassTensor(
+                shape=shape,
+                layout_tag=LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor,
+                element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
+            )
+
+        for name in read_names + write_names:
+            key = name
+
+            if name in buffer_renames:
+                # Need to rewrite some special args (e.g. acc is a required arg name)
+                key = buffer_renames[name]
+
+            example_tensors[key] = cutlass_tensor_from_buffer(name_to_buffer[name])
+
+        return example_tensors
+
     def trace(
         fn_src: str,
         example_tensors: dict[str, CutlassTensor],
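
A hedged usage sketch of the new create_example_tensors helper, mirroring the test added above. The mock classes below are illustrative stand-ins for inductor ComputedBuffer/InputBuffer nodes (only get_name() and get_layout() are exercised), and running this requires the cutlass package, matching the skipIf guard in the test.

import torch
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
    create_example_tensors,
)

# Illustrative mocks only: they expose the minimal buffer surface the helper reads.
class MockLayout:
    def __init__(self, size, stride, dtype):
        self.size, self.stride, self.dtype = size, stride, dtype

class MockBuffer:
    def __init__(self, name, size, stride, dtype):
        self._name = name
        self._layout = MockLayout(size, stride, dtype)

    def get_name(self):
        return self._name

    def get_layout(self):
        return self._layout

buf0 = MockBuffer("buf0", (3, 4, 1), (4, 1, 0), torch.float32)  # row-major read
buf1 = MockBuffer("buf1", (3, 2, 1), (1, 3, 0), torch.float32)  # column-major write

# "buf0" is renamed to "acc" because the EVT tracer expects the accumulator
# argument under that name.
tensors = create_example_tensors(
    read_names=["buf0"],
    write_names=["buf1"],
    buffer_renames={"buf0": "acc"},
    name_to_buffer={"buf0": buf0, "buf1": buf1},
)
print(tensors["acc"].shape, tensors["buf1"].shape)  # (3, 4, 1) (3, 2, 1)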
