@@ -28,6 +28,66 @@ def __init__(self):
     def __call__(self, subgraph, *operands, scheme):
         return super().__call__(subgraph, *operands, scheme=scheme)
 
+    def gen_schema(self, subgraph, *operands, scheme):
+        # Idea 1: using inspect.signature and sample inputs to generate a schema
+        # Idea 2: we still need to know how to call into subgraph/fn given the inputs.
+        # wrap_subgraphs gives two callables to call into subgraph.
+        from torch._higher_order_ops.schema import (
+            CFunctionSchemaGen,
+            HopArgumentInfoGen,
+        )
+        from torch._higher_order_ops.utils import (
+            check_input_alias_and_mutation_return_ouputs,
+        )
+
+        (
+            mutated_inp_idx,
+            inp_inp_alias,
+            inp_out_alias,
+            out_out_alias,
+            output,
+        ) = check_input_alias_and_mutation_return_ouputs(subgraph, operands)
+        assert (
+            len(inp_inp_alias) == 0
+            and len(inp_out_alias) == 0
+            and len(out_out_alias) == 0
+        ), f"Aliasing is not supported for HOP subgraph. {subgraph}"
+
+        args = [
+            HopArgumentInfoGen.from_example(
+                subgraph, name="subgraph", default_value=None, is_mutated=False
+            )
+        ]
+        for idx, arg in enumerate(operands):
+            example_value = arg
+            arg_name = f"operands{idx}"
+            args.append(
+                HopArgumentInfoGen.from_example(
+                    example_value=example_value,
+                    name=arg_name,
+                    default_value=None,
+                    is_mutated=idx in mutated_inp_idx,
+                )
+            )
+
+        args.append(
+            HopArgumentInfoGen.from_example(
+                example_value=scheme,
+                name="scheme",
+                default_value=scheme,
+                is_mutated=False,
+                kw_only=True,
+            )
+        )
+        output = HopArgumentInfoGen.from_example(
+            example_value=output,
+            name="output",
+            default_value=None,
+            is_mutated=False,
+            kw_only=False,
+        )
+        return CFunctionSchemaGen.from_hop_argument_info(str(self), args, output)
+
 
 invoke_quant_test = InvokeQuantTest()
 
@@ -93,7 +153,7 @@ def f(x, y):
         self.assertEqual(len(schemas), 1)
         self.assertExpectedInline(
             str(schemas[0]),
-            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> ((Tensor))""",  # noqa: B950
+            """invoke_quant_test(Any subgraph, Tensor operands0, Tensor operands1, *, str scheme="nf4") -> ((Tensor))""",  # noqa: B950
         )
 
     def test_schema_gen_pytree_in_out(self):
@@ -121,7 +181,7 @@ def f(x, y):
         self.assertEqual(len(schemas), 1)
         self.assertExpectedInline(
             str(schemas[0]),
-            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
+            """invoke_quant_test(Any subgraph, Tensor operands0, Tensor operands1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
         )
 
     def test_schema_gen_single_return_with_mutation(self):
@@ -135,15 +195,40 @@ def inner(x, y):
 
         backend = EagerAndRecordGraphs()
 
-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(x.clone(), y)
+        torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        return (getitem,)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            mul_: "f32[3, 3]" = l_y_.mul_(-1); mul_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            cos: "f32[3, 3]" = sin.cos(); sin = None
+            return (cos,)
+""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) operands0, Tensor(a2!) operands1, *, str scheme="nf4") -> ((Tensor))""",
+        )
 
     def test_schema_gen_pytree_in_out_with_mutation(self):
         def inner(x_y):
@@ -161,15 +246,46 @@ def inner(x_y):
 
         backend = EagerAndRecordGraphs()
 
-        @torch.compile(backend=backend, fullgraph=True)
         def f(x, y):
             return invoke_quant_test(inner, [x, y], scheme="nf4")
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(x.clone(), y)
+        torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
+        self.assertEqual(len(backend.graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        subgraph_0 = self.subgraph_0
+        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); subgraph_0 = l_x_ = l_y_ = None
+        getitem: "f32[3, 3]" = invoke_quant_test[0]
+        getitem_1: "f32[3, 3]" = invoke_quant_test[1]
+        getitem_2: "f32[3, 3]" = invoke_quant_test[2]
+        getitem_3: "f32[3, 3]" = invoke_quant_test[3]; invoke_quant_test = None
+        return (getitem, getitem_1, getitem_2, getitem_3)
+
+    class subgraph_0(torch.nn.Module):
+        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
+            add_: "f32[3, 3]" = l_x_.add_(1); add_ = None
+
+            matmul: "f32[3, 3]" = l_x_ @ l_y_
+            sin: "f32[3, 3]" = matmul.sin(); matmul = None
+            child: "f32[3, 3]" = sin.cos(); sin = None
+
+            child_1: "f32[3, 3]" = l_x_ + l_y_
+            child_2: "f32[3, 3]" = l_x_ - l_y_
+
+            child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+            return (child, child_1, child_2, child_3)
+""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
+            """invoke_quant_test(Any subgraph, Tensor(a1!) operands0, Tensor operands1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
+        )
 
     def test_none_input(self):
         def inner(x, y):
@@ -239,6 +355,44 @@ def forward(self, l_y_: "f32[3, 4]"):
 """,
         )
 
+    def test_auto_functionalize(self):
+        def inner(x, y):
+            x.add_(1)
+            return x + y
+
+        backend = AotEagerAndRecordGraphs()
+
+        def f(x, y):
+            return invoke_quant_test(inner, x, y, scheme="nf4")
+
+        x = torch.randn(3, 3, requires_grad=False)
+        x_clone = x.clone()
+        y = torch.randn(3, 3, requires_grad=True)
+        compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
+        # assert x is not mutated
+        self.assertEqual(x, x_clone)
+        self.assertEqual(compiled_out, x + y + 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertExpectedInline(
+            normalize_graph(backend.fw_graphs[0]),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
+        functiona_schema_0 = self.functiona_schema_0
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, operands1 = primals_2, scheme = 'nf4', _operands0_base_index = 0, _all_bases = [primals_1], _op_schema = functiona_schema_0); auto_functionalized_subgraph_0 = functiona_schema_0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+        return (getitem, primals_1, primals_2)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
+            add_: "f32[3, 3]" = torch.ops.aten.add_.Tensor(arg0_1, 1); arg0_1 = None
+
+            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_, arg1_1); add_ = arg1_1 = None
+            return (add,)
+""",  # noqa: B950
+        )
+
     @torch._dynamo.config.patch(assume_static_by_default=True)
     def test_aot_eager(self):
         def inner(x, y):
@@ -265,16 +419,17 @@ def f(x, y):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
-        subgraph0 = self.subgraph0
-        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph0, primals_1, primals_2, scheme = 'nf4'); subgraph0 = None
-        getitem: "f32[3, 3]" = invoke_quant_test[0]; invoke_quant_test = None
+        functiona_schema_0 = self.functiona_schema_0
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, operands0 = primals_1, operands1 = primals_2, scheme = 'nf4', _all_bases = [], _op_schema = functiona_schema_0); auto_functionalized_subgraph_0 = functiona_schema_0 = None
+        getitem: "f32[3, 3]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
         return (getitem, primals_1, primals_2)
 
-    class subgraph0(torch.nn.Module):
+    class auto_functionalized_subgraph_0(torch.nn.Module):
         def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
             mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1); arg0_1 = arg1_1 = None
-            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
-            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None
+            sin_: "f32[3, 3]" = torch.ops.aten.sin_.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin_); sin_ = None
             return (cos,)
 """,  # NOQA: B950
         )
@@ -285,20 +440,21 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", tangents_1: "f32[3, 3]"):
-        subgraph1 = self.subgraph1
-        invoke_quant_test_1 = torch.ops.higher_order.invoke_quant_test(subgraph1, primals_1, primals_2, tangents_1, scheme = 'nf4'); subgraph1 = primals_1 = primals_2 = tangents_1 = None
-        getitem_1: "f32[3, 3]" = invoke_quant_test_1[0]
-        getitem_2: "f32[3, 3]" = invoke_quant_test_1[1]; invoke_quant_test_1 = None
+        functiona_schema_1 = self.functiona_schema_1
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_1, operands0 = primals_1, operands1 = primals_2, operands2 = tangents_1, scheme = 'nf4', _all_bases = [], _op_schema = functiona_schema_1); auto_functionalized_subgraph_1 = primals_1 = primals_2 = tangents_1 = functiona_schema_1 = None
+        getitem_1: "f32[3, 3]" = auto_functionalized_v2_1[0]
+        getitem_2: "f32[3, 3]" = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
         return (getitem_1, getitem_2)
 
-    class subgraph1(torch.nn.Module):
+    class auto_functionalized_subgraph_1(torch.nn.Module):
         def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]"):
             mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
             clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
-            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None
-            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); cos = None
-            sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin); sin = None
-            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
+            sin_: "f32[3, 3]" = torch.ops.aten.sin_.default(mm); mm = None
+            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin_); cos = None
+            sin: "f32[3, 3]" = torch.ops.aten.sin.default(sin_); sin_ = None
+            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin); sin = None
             mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg); arg2_1 = neg = None
             cos_1: "f32[3, 3]" = torch.ops.aten.cos.default(clone); clone = None
             mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(mul, cos_1); mul = cos_1 = None
@@ -320,21 +476,22 @@ def inner2(x, y):
 
         x = torch.randn(3, 3)
         y = torch.randn(3, 3)
+        x_clone = x.clone()
+        y_clone = y.clone()
 
         @torch.compile(backend="eager", fullgraph=True)
         def f(inner, x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")
 
+        compiled_f = torch.compile(f, backend="eager", fullgraph=True)
+
         with self.assertRaisesRegex(
             RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
         ):
-            f(inner, x, y)
+            compiled_f(inner, x, y)
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Encountered input mutation during higher order op tracing for HOP",
-        ):
-            f(inner2, x, y)
+        compiled_out = compiled_f(inner2, x, y)
+        self.assertEqual(compiled_out, f(inner2, x_clone, y_clone))
 
     def test_eager_call(self):
         def inner(x, y):
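For reference, a minimal sketch (outside the diff) of how a mutating subgraph maps to the schema strings asserted above; it assumes `invoke_quant_test` is importable from this test module, and the rendered schema is quoted from the test expectations.

# Sketch under assumptions: invoke_quant_test and gen_schema come from the test
# module in the diff above; the schema string mirrors the expected-inline strings.
import torch

def subgraph(x, y):
    x.add_(1)  # mutates operands0, like the `inner` helpers in the mutation tests
    return (x @ y,)

x = torch.randn(3, 3)
y = torch.randn(3, 3)

# invoke_quant_test.gen_schema(subgraph, x, y, scheme="nf4") would render roughly as:
#   invoke_quant_test(Any subgraph, Tensor(a1!) operands0, Tensor operands1, *, str scheme="nf4") -> ((Tensor))
# i.e. the mutated operand carries an alias-set annotation, and `scheme` is
# keyword-only with its default taken from the example value; auto_functionalized_v2
# then consumes this schema via the _op_schema attribute seen in the aot-eager graphs.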