@@ -1,23 +1,36 @@
 # Owner(s): ["module: inductor"]
 import functools
 import unittest
+from unittest import mock
+from unittest.mock import MagicMock
 
 import torch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._inductor.codegen.subgraph import SubgraphTemplate
 from torch._inductor.decomposition import select_decomp_table
-from torch._inductor.ir import Buffer, FixedLayout
+from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout
 from torch._inductor.lowering import register_lowering
-from torch._inductor.select_algorithm import (
-    AlgorithmSelectorCache,
-    autotune_select_algorithm,
-)
+from torch._inductor.select_algorithm import autotune_select_algorithm
 from torch._inductor.test_case import run_tests, TestCase
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
 
 
+def decomposeK(a, b, kPartitions):
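+    # Decompose-K: split the shared K dimension into B = k // kPartitions
+    # batches, bmm the batches with fp32 accumulation, then sum over the
+    # batch dimension. E.g. with the shapes used below, a: (32, 4096) and
+    # kPartitions=256: B = 16, a_reshaped: (16, 32, 256), b_reshaped:
+    # (16, 256, 32), bmm -> (16, 32, 32), sum(dim=0) -> (32, 32).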
+    m = a.shape[0]
+    n = b.shape[1]
+    k = a.shape[1]
+
+    B = k // kPartitions
+    a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
+    b_reshaped = b.reshape(B, kPartitions, n)
+    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
+    result_fp32 = result.to(torch.float32)
+    reduced_buf = torch.sum(result_fp32, 0)
+    return reduced_buf.to(a.dtype)
+
+
 class TestSubgraphChoice(TestCase):
     def setUp(self):
         super().setUp()
@@ -34,6 +47,8 @@ def test_subgraph_decompose_k(self):
         from torch._inductor.kernel.mm import aten_mm
         from torch._inductor.kernel.mm_common import mm_args
 
+        mat1_shape, mat2_shape = (32, 4096), (4096, 32)
+
         @torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
         def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
             return a @ b
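         # The custom op wraps a plain matmul so the test can attach its own
         # Inductor lowering to it below.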
@@ -42,28 +57,12 @@ def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         def _(a, b):
             return a @ b
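         # (The register_fake registration above gives the custom op a meta
         # implementation so it can be traced with fake tensors during compile.)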
 
-        def decomposeK(a, b, kPartitions):
-            m = a.shape[0]
-            n = b.shape[1]
-            k = a.shape[1]
-
-            B = k // kPartitions
-            a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
-            b_reshaped = b.reshape(B, kPartitions, n)
-            result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
-            result_fp32 = result.to(torch.float32)
-            reduced_buf = torch.sum(result_fp32, 0)
-            return reduced_buf.to(a.dtype)
-
-        mat1_shape, mat2_shape = (32, 4096), (4096, 32)
-
         @register_lowering(torch.ops.mylib.matmul_decompose)
         def _(a, b):
             _, _, _, layout, mat1, mat2 = mm_args(a, b)
 
             choices = [aten_mm.bind((mat1, mat2), layout)]
 
-            # TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
             kPartitions = 256
             with enable_python_dispatcher():
                 decompositions = select_decomp_table()
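                 # select_decomp_table() collects Inductor's decompositions;
                 # they are handed to make_fx in the (unchanged) SubgraphTemplate
                 # construction that follows.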
@@ -77,15 +76,10 @@ def _(a, b):
                 ),
             )
 
-            mat1_tensor, mat2_tensor = (
-                AlgorithmSelectorCache.benchmark_example_value(mat1),
-                AlgorithmSelectorCache.benchmark_example_value(mat2),
-            )
             decompose_k_subgraph_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
                 layout=layout,
-                example_inputs=[mat1_tensor, mat2_tensor],
             )
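             # Note: example_inputs is no longer passed in; maybe_append_choice
             # now appears to generate its own benchmark inputs (the freeze-layout
             # test below reads them back via choice.example_inputs).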
 
             # Test benchmarking against aten
@@ -112,8 +106,88 @@ def func(mat1, mat2):
         res = compiled_func(a_in, b_in)
 
         # Check same results of compiled result and regular torch.mm
-        # Relax precision as decomposeK does first accumulation in fp16
-        torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
+        torch.testing.assert_close(res, a_in @ b_in, atol=1e-2, rtol=1e-2)
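+        # (Tolerance tightened from 1e-1: decomposeK accumulates in fp32 via
+        # out_dtype above, so the old fp16-based slack seems unnecessary.)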
+
+    @skipIfXpu
+    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+    def test_subgraph_freeze_layout(self):
+        from torch._inductor.kernel.mm_common import mm_args
+
+        M, N, K = (4, 128, 14240)
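+        # K is chosen so mat1's stride presumably needs padding under
+        # max-autotune; the stride assertions below rely on this.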
+        a_in = torch.randn(
+            (M, K), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+        b_in = torch.randn(
+            (K, N), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+
+        @torch.library.custom_op("mylib::matmul_decompose_padding", mutates_args={})
+        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            return a @ b
+
+        @matmul_decompose.register_fake
+        def _(a, b):
+            return a @ b
+
+        @register_lowering(torch.ops.mylib.matmul_decompose_padding)
+        def _(a, b):
+            _, _, _, layout, mat1, mat2 = mm_args(a, b)
+            mat1_layout = mat1.layout
+            assert isinstance(mat1_layout, FlexibleLayout)
+            mat1_stride = mat1_layout.stride
+
+            choices = []
+
+            kPartitions = 2
+            with enable_python_dispatcher():
+                decompositions = select_decomp_table()
+
+                decompose_k_subgraph_template = SubgraphTemplate(
+                    name="decompose_k_mm",
+                    make_fx_graph=make_fx(
+                        functools.partial(decomposeK, kPartitions=kPartitions),
+                        decompositions,
+                    ),
+                )
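+                # make_fx traces decomposeK (with kPartitions bound) into an
+                # FX graph under Inductor's decomposition table; the template
+                # wraps that graph so it can be offered as an autotune choice.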
+
+            decompose_k_subgraph_template.maybe_append_choice(
+                choices,
+                input_nodes=(mat1, mat2),
+                layout=layout,
+            )
+
+            choice = choices[0]
+            assert isinstance(mat1.layout, FixedLayout)
+
+            # Creating the subgraph choice should have frozen the layout
+            # We ensure padding, so the stride should differ
+            assert mat1.layout.stride != mat1_stride
+
+            for example_stride, layout_stride in zip(
+                choice.example_inputs[0].stride(), mat1.layout.stride
+            ):
+                # Example inputs should have the same stride as the current layout
+                assert example_stride == layout_stride
+
+            return autotune_select_algorithm(
+                "test_subgraph_choice", choices, [a, b], layout
+            )
+
+        def func(mat1, mat2):
+            return torch.ops.mylib.matmul_decompose_padding((mat1 + 1.0), mat2)
+
+        with mock.patch("torch._inductor.ir.V.get_current_node") as get_node_mock:
+            node_mock = MagicMock()
+            node_mock.meta = {"dislike_padding": False}
+            get_node_mock.return_value = node_mock
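+            # Patching V.get_current_node makes the IR layer see a node whose
+            # meta allows padding (dislike_padding=False), so Inductor is free
+            # to pad mat1's layout during lowering.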
+
+            compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
+
+            res = compiled_func(a_in, b_in)
+
+            # Check that the compiled result matches regular torch.mm
+            # Relax precision, as decomposeK accumulates partial bmm results
+            # before the final reduction
+            torch.testing.assert_close(res, (a_in + 1.0) @ b_in, atol=1e-2, rtol=1e-2)
 
 
 if __name__ == "__main__":