@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import unittest
+import unittest.mock as mock

import torch
import torch._dynamo.test_case
@@ -11,6 +12,10 @@
    normalize_gm,
)
from torch._higher_order_ops.schema import find_hop_schema
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
from torch.testing._internal.inductor_utils import HAS_CUDA


@@ -135,17 +140,47 @@ def inner(x, y):

        backend = EagerAndRecordGraphs()

-        @torch.compile(backend=backend, fullgraph=True)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
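+        # Patch supports_input_mutation so Dynamo traces the mutating subgraph
+        # instead of raising the input-mutation error.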
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
        ):
-            f(x.clone(), y)
+            torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_

-    def test_schema_gen_pytree_in_out_with_mutation(self):
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        return (getitem,)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            mul_: "f32[3, 3]" = l_y_.mul_(-1); mul_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            cos: "f32[3, 3]" = sin.cos(); sin = None
+            return (cos,)
+""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor(a2!) arg1, *, str scheme="nf4") -> ((Tensor))""",
+        )
+
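+    # Exercise schema generation under both the eager and aot_eager recording backends.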
+    @parametrize("backend", ["eager", "aot_eager"])
+    def test_schema_gen_pytree_in_out_with_mutation(self, backend):
        def inner(x_y):
            x, y = x_y
            x.add_(1)
@@ -159,17 +194,88 @@ def inner(x_y):
        x = torch.randn(3, 3, requires_grad=False)
        y = torch.randn(3, 3, requires_grad=True)

-        backend = EagerAndRecordGraphs()
+        if backend == "eager":
+            bk = EagerAndRecordGraphs()
+        else:
+            assert backend == "aot_eager"
+            bk = AotEagerAndRecordGraphs()

-        @torch.compile(backend=backend, fullgraph=True)
        def f(x, y):
            return invoke_quant_test(inner, [x, y], scheme="nf4")

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
        ):
-            f(x.clone(), y)
+            torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
+
+        if backend == "eager":
+            self.assertEqual(len(bk.graphs), 1)
+            self.assertExpectedInline(
+                normalize_graph(bk.graphs[0]),
+                """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]
+        getitem_1: "f32[3, 3]" = invoke_quant_test[1]
+        getitem_2: "f32[3, 3]" = invoke_quant_test[2]
+        getitem_3: "f32[3, 3]" = invoke_quant_test[3]; invoke_quant_test = None
+        return (getitem, getitem_1, getitem_2, getitem_3)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            child: "f32[3, 3]" = sin.cos(); sin = None
+
+            child_1: "f32[3, 3]" = l_x_ + l_y_
+            child_2: "f32[3, 3]" = l_x_ - l_y_
+
+            child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            return (child, child_1, child_2, child_3)
+""",  # noqa: B950
+            )
+            self.assertExpectedInline(
+                str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]),
+                """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
+            )
+        elif backend == "aot_eager":
+            self.assertEqual(len(bk.fw_graphs), 1)
+            self.assertExpectedInline(
+                normalize_graph(bk.fw_graphs[0]),
+                """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]
+        getitem_2: "f32[3, 3]" = auto_functionalized_v2[2]
+        getitem_3: "f32[3, 3]" = auto_functionalized_v2[3]
+        getitem_4: "f32[3, 3]" = auto_functionalized_v2[4]; auto_functionalized_v2 = None
+        return (getitem, getitem_1, getitem_2, getitem_3, primals_1, primals_2, getitem_4)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
+            mm: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1)
+            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None
+            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1)
+            sub: "f32[3, 3]" = torch.ops.aten.sub.Tensor(add, arg1_1)
+            mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1); arg1_1 = None
+            copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (cos, add_1, sub, mm_1)
+""",  # noqa: B950
+            )

    def test_none_input(self):
        def inner(x, y):
@@ -239,6 +345,49 @@ def forward(self, l_y_: "f32[3, 4]"):
""",
        )

+    def test_auto_functionalize(self):
+        def inner(x, y):
+            x.add_(1)
+            return x + y
+
+        backend = AotEagerAndRecordGraphs()
+
+        def f(x, y):
+            return invoke_quant_test(inner, x, y, scheme="nf4")
+
+        x = torch.randn(3, 3, requires_grad=False)
+        x_clone = x.clone()
+        y = torch.randn(3, 3, requires_grad=True)
+        with mock.patch(
+            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+            True,
+        ):
+            compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
+        # assert x is not mutated
+        self.assertEqual(x, x_clone)
+        self.assertEqual(compiled_out, x + y + 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.fw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
+        return (getitem, primals_1, primals_2, getitem_1)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
+            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1); arg1_1 = None
+            copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (add_1,)
+""",  # noqa: B950
+        )
+
    @torch._dynamo.config.patch(assume_static_by_default=True)
    def test_aot_eager(self):
        def inner(x, y):
@@ -353,6 +502,9 @@ def inner(x, y):
            invoke_quant_test(result, x, y, scheme="nf4")


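+# Generate per-backend variants for the @parametrize-decorated tests in BaseHOPTest.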
+instantiate_parametrized_tests(BaseHOPTest)
+
+
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests