-
Notifications
You must be signed in to change notification settings - Fork 24.4k
[export][reland] Convert autocast to HOO #132677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: Reland of D60206382. Suggested in pytorch#128394. If there's an autocast context manager, the predispatch (strict) graph can look something like: ``` class <lambda>(torch.nn.Module): def forward(self, x: "f32[1]"): ... _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None) mm: "f32[8, 8]" = torch.ops.aten.mm.default(rand, rand_1); rand = rand_1 = None _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = None return (mm_1,) ``` But the operator `torch.amp.autocast_mode._enter_autocast` is not a valid ATen op. We remove these nodes by turning autocast into a higher order operator and make a submodule for the blocks between `_enter_autocast` and `_exit_autocast`. Some potential followup improvement: 1) Merge some of the duplicated logic with `replace_set_grad_with_hop_pass.py` 2) Check the current autocast status (any enabled? dtype?) and not create a submodule if the autocast args matches current autocast status. Test Plan: CI ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_autocast" buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_set_grad" ``` Verified that now we can export the llama model in gh issue 128394 and the gemma model in gh issue 131829 without error. Differential Revision: D60770038
This pull request was exported from Phabricator. Differential Revision: D60770038 |
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary: Reland of D60206382. Suggested in #128394. If there's an autocast context manager, the predispatch (strict) graph can look something like: ``` class <lambda>(torch.nn.Module): def forward(self, x: "f32[1]"): ... _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None) mm: "f32[8, 8]" = torch.ops.aten.mm.default(rand, rand_1); rand = rand_1 = None _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = None return (mm_1,) ``` But the operator `torch.amp.autocast_mode._enter_autocast` is not a valid ATen op. We remove these nodes by turning autocast into a higher order operator and make a submodule for the blocks between `_enter_autocast` and `_exit_autocast`. Some potential followup improvement: 1) Merge some of the duplicated logic with `replace_set_grad_with_hop_pass.py` 2) Check the current autocast status (any enabled? dtype?) and not create a submodule if the autocast args matches current autocast status. Test Plan: CI ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_autocast" buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_set_grad" ``` Verified that now we can export the llama model in gh issue 128394 and the gemma model in gh issue 131829 without error. Differential Revision: D60770038 Pull Request resolved: #132677 Approved by: https://github.com/angelayi
Summary: Reland of D60206382. Suggested in #128394. If there's an autocast context manager, the predispatch (strict) graph can look something like: ``` class <lambda>(torch.nn.Module): def forward(self, x: "f32[1]"): ... _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None) mm: "f32[8, 8]" = torch.ops.aten.mm.default(rand, rand_1); rand = rand_1 = None _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = None return (mm_1,) ``` But the operator `torch.amp.autocast_mode._enter_autocast` is not a valid ATen op. We remove these nodes by turning autocast into a higher order operator and make a submodule for the blocks between `_enter_autocast` and `_exit_autocast`. Some potential followup improvement: 1) Merge some of the duplicated logic with `replace_set_grad_with_hop_pass.py` 2) Check the current autocast status (any enabled? dtype?) and not create a submodule if the autocast args matches current autocast status. Test Plan: CI ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_autocast" buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r "test_predispatch_set_grad" ``` Verified that now we can export the llama model in gh issue 128394 and the gemma model in gh issue 131829 without error. Differential Revision: D60770038 Pull Request resolved: #132677 Approved by: https://github.com/angelayi
def _replace_with_hop_helper( | ||
node: torch.fx.Node, | ||
enter_block_node: torch.fx.Node, | ||
node_filter: Callable, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yushangdi I'm wondering what is the purpose of that node_filter
?
As far as I can see it is not used at all, making it a dead argument. Was the idea to use it eventually?
If not, should it be removed? It looks like the usages would become significantly simpler if they don't have to pass in that filter function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, yeah I think it can be removed.
@yushangdi I'm wondering what is the purpose if that
node_filter
?As far as I can see it is not used at all, making it a dead argument. Was the idea to use it eventually?
If not, should it be removed? It looks like the usages would become significantly simpler if they don't have to pass in that filter function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR for removing it: #141983
Summary:
Reland of D60206382.
Suggested in #128394.
If there's an autocast context manager, the predispatch (strict) graph can look something like:
But the operator
torch.amp.autocast_mode._enter_autocast
is not a valid ATen op. We remove these nodes by turning autocast into a higher order operator and make a submodule for the blocks between_enter_autocast
and_exit_autocast
.Some potential followup improvement:
replace_set_grad_with_hop_pass.py
Test Plan:
CI
Verified that now we can export the llama model in gh issue 128394 and the gemma model in gh issue 131829 without error.
Differential Revision: D60770038