|
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from torch._dynamo.test_case import TestCase
|
6 |
| -from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass |
| 6 | +from torch._inductor.codegen.cuda.cutlass_utils import ( |
| 7 | + torch_dtype_to_cutlass_type, |
| 8 | + try_import_cutlass, |
| 9 | +) |
7 | 10 | from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
8 | 11 |
|
9 | 12 |
|
|
55 | 58 | class MockTileDescription:
|
56 | 59 | threadblock_shape = (128, 128, 8)
|
57 | 60 |
|
58 |
| - def _create_mock_buffer_name_map(example_tensors): |
59 |
| - class MockNode: |
60 |
| - def __init__(self, name, stride, dtype): |
61 |
| - self.name = name |
62 |
| - self.dtype = dtype |
63 |
| - self.stride = stride |
| 61 | + class MockNode: |
| 62 | + def __init__(self, name, shape, stride, dtype): |
| 63 | + self.name = name |
| 64 | + self.dtype = dtype |
| 65 | + self.shape = shape |
| 66 | + self.stride = stride |
64 | 67 |
|
65 |
| - def get_layout(self): |
66 |
| - class MockLayout: |
67 |
| - def __init__(self, stride, dtype): |
68 |
| - self.dtype = dtype |
69 |
| - self.stride = stride |
| 68 | + def get_layout(self): |
| 69 | + class MockLayout: |
| 70 | + def __init__(self, shape, stride, dtype): |
| 71 | + self.size = shape |
| 72 | + self.stride = stride |
| 73 | + self.dtype = dtype |
70 | 74 |
|
71 |
| - return MockLayout(self.stride, self.dtype) |
| 75 | + return MockLayout(self.shape, self.stride, self.dtype) |
72 | 76 |
|
73 |
| - def get_name(self): |
74 |
| - return self.name |
| 77 | + def get_name(self): |
| 78 | + return self.name |
75 | 79 |
|
| 80 | + def _create_mock_buffer_name_map(example_tensors): |
76 | 81 | name_to_buffer = {}
|
77 | 82 | for name, tensor in example_tensors.items():
|
78 | 83 | if isinstance(tensor, CutlassTensor):
|
79 |
| - name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32) |
| 84 | + name_to_buffer[name] = MockNode( |
| 85 | + name, tensor.shape, tensor.stride, torch.float32 |
| 86 | + ) |
80 | 87 |
|
81 | 88 | return name_to_buffer
|
82 | 89 |
|
83 | 90 |
|
84 | 91 | class TestCutlassEVT(TestCase):
|
| 92 | + @unittest.skipIf(not try_import_cutlass(), "requires cutlass") |
| 93 | + def test_example_tensor_creation(self): |
| 94 | + from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( |
| 95 | + create_example_tensors, |
| 96 | + ) |
| 97 | + |
| 98 | + row_major_buf0 = MockNode("buf0", (3, 4, 1), (4, 1, 0), torch.float32) |
| 99 | + col_major_buf1 = MockNode("buf1", (3, 2, 1), (1, 3, 0), torch.float32) |
| 100 | + read_names = ["buf0"] |
| 101 | + write_names = ["buf1"] |
| 102 | + buffer_renames = {"buf0": "acc"} |
| 103 | + name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1} |
| 104 | + result = create_example_tensors( |
| 105 | + read_names, write_names, buffer_renames, name_to_buffer |
| 106 | + ) |
| 107 | + self.assertEqual(result["acc"].shape, (3, 4, 1)) |
| 108 | + self.assertEqual(result["acc"].stride, (4, 1, 0)) |
| 109 | + self.assertEqual( |
| 110 | + result["acc"].element, torch_dtype_to_cutlass_type(torch.float32) |
| 111 | + ) |
| 112 | + |
| 113 | + self.assertEqual(result["buf1"].shape, (3, 2, 1)) |
| 114 | + self.assertEqual(result["buf1"].stride, (1, 3, 0)) |
| 115 | + self.assertEqual( |
| 116 | + result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32) |
| 117 | + ) |
| 118 | + |
85 | 119 | @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
86 | 120 | def test_evt_argument_codegen(self):
|
87 | 121 | epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
|
|
0 commit comments