[export][reland] Convert autocast to HOO · pytorch/pytorch@60ee575

Commit 60ee575

yushangdi authored and facebook-github-bot committed
[export][reland] Convert autocast to HOO
Summary: Reland of D60206382. Suggested in #128394.

If there's an autocast context manager, the pre-dispatch (strict) graph can look something like:

```
class <lambda>(torch.nn.Module):
    def forward(self, x: "f32[1]"):
        ...
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
        mm: "f32[8, 8]" = torch.ops.aten.mm.default(rand, rand_1); rand = rand_1 = None
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = None
        return (mm_1,)
```

But the operator `torch.amp.autocast_mode._enter_autocast` is not a valid ATen op. We remove these nodes by turning autocast into a higher-order operator and making a submodule for the blocks between `_enter_autocast` and `_exit_autocast`.

Some potential follow-up improvements:
1) Merge some of the duplicated logic with `replace_set_grad_with_hop_pass.py`.
2) Check the current autocast status (anything enabled? which dtype?) and skip creating a submodule if the autocast args match the current autocast status.

Test Plan: CI

```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_autocast"
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_set_grad"
```

Verified that we can now export the llama model in GitHub issue #128394 and the gemma model in GitHub issue #131829 without error.

Differential Revision: D60770038
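For context, a minimal sketch (not part of this commit; the module, shapes, and autocast device below are illustrative) of how one can produce and inspect such a pre-dispatch graph with `torch.export`. The exact printout depends on the PyTorch version, but after this change the autocast region should appear as a submodule called through a higher-order wrap op rather than as `_enter_autocast`/`_exit_autocast` nodes:

```python
# Hedged sketch: module and shapes are made up for illustration; only public
# torch APIs (torch.autocast, torch.export.export) are used.
import torch

class M(torch.nn.Module):
    def forward(self, a, b):
        # The autocast context manager is what used to produce the
        # _enter_autocast/_exit_autocast nodes in the pre-dispatch graph.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            return torch.mm(a, b)

ep = torch.export.export(M(), (torch.randn(8, 8), torch.randn(8, 8)))
# Inspect how the autocast region is represented in the exported graph.
print(ep.graph_module.code)
```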
1 parent a672f6c commit 60ee575

8 files changed: +623, -99 lines

test/export/test_export.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -5454,7 +5454,8 @@ def forward(self, b_pred, b_t, x, y):
 def forward(self, b_t, x, y):
     submod_3 = self.submod_1
     add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
-    return (add_1,)""",
+    getitem = add_1[0]; add_1 = None
+    return (getitem,)""",
         )

         self.assertExpectedInline(
@@ -5464,7 +5465,7 @@ def forward(self, x, b_t, y):
     sub = torch.ops.aten.sub.Tensor(x, 1); x = None
     add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
     add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
-    return add_1""",
+    return (add_1,)""",
         )

     def test_predispatch_grad_wrappers(self):
@@ -6021,7 +6022,7 @@ def forward(self, x):
     lt = item < 6
     _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None
     foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
-    return foo_unbacked""",
+    return (foo_unbacked,)""",
         )
         ep_aot = ep_pre.run_decompositions()
         self.assertExpectedInline(
```
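As a side note, the expected returns above became one-element tuples with a `getitem` because a region pulled into a higher-order wrap op is a submodule whose outputs are packed into a tuple. A hedged sketch (the module and names are illustrative, not taken from the test suite) of the kind of pattern these expected-inline tests exercise:

```python
# Hedged sketch, not the actual test: a module mixing code outside and inside a
# no_grad region, exported so the grad-mode block can be inspected.
import torch

class M(torch.nn.Module):
    def forward(self, x, b_t, y):
        sub = x - 1
        with torch.no_grad():
            add = sub + b_t
        return add + y

ep = torch.export.export(M(), (torch.randn(4), torch.randn(4), torch.randn(4)))
# The printed code shows how the no_grad block is represented; exact node names
# and whether a submodule is created depend on the PyTorch version.
print(ep.graph_module.code)
```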
