[ONNX] Inline prim::PythonOp for Autograd Function Export by shubhambhokare1 · Pull Request #74765 · pytorch/pytorch · GitHub

[ONNX] Inline prim::PythonOp for Autograd Function Export #74765


Conversation

@shubhambhokare1 (Collaborator) commented Mar 25, 2022

Add a flag (inline_autograd) to enable inline export of models consisting of autograd functions. Currently, this flag should only be used in TrainingMode.EVAL, not for training.

An example:

If a model containing an autograd.Function is defined as follows:

```
class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result
```

Then the model is exported as

```
graph(%0 : Float):
  %1 : Float = ^AutogradFunc(%0)
  return (%1)
```

If inline_autograd is set to True, this will be exported as

```
graph(%0 : Float):
  %1 : Float = onnx::Exp(%0)
  %2 : Float = onnx::Log(%1)
  return (%2)
```

If one of the ops within the autograd function is not supported, that particular node is exported as-is, mirroring ONNX_FALLTHROUGH mode.
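
For illustration, below is a minimal sketch of the kind of model this affects and how it might be exported in eval mode. The Model wrapper, input shape, and file name are illustrative assumptions (not from this PR's diff), and the exact way the inline_autograd flag is surfaced in the export API is not shown here, since it isn't spelled out in this description:

```
import torch


class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result
    # backward is omitted here, matching the example above; it is not
    # needed for export in TrainingMode.EVAL.


class Model(torch.nn.Module):
    def forward(self, x):
        # autograd.Functions are called through .apply, which is what the
        # tracer records as prim::PythonOp (shown as ^AutogradFunc above).
        return AutogradFunc.apply(x)


# Exporting in eval mode; with the inlining behaviour described in this PR
# enabled, the ^AutogradFunc node is replaced by onnx::Exp + onnx::Log.
torch.onnx.export(
    Model(),
    torch.randn(2, 3),
    "model.onnx",
    training=torch.onnx.TrainingMode.EVAL,
)
```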

Fixes: #61813

@facebook-github-bot (Contributor) commented Mar 25, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 4f9ce2c (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 25, 2022
@shubhambhokare1 shubhambhokare1 changed the title [ONNX] API to inline prim::PythonOp for Autograd Function Export [ONNX] [DO NOT REVIEW] API to inline prim::PythonOp for Autograd Function Export Mar 25, 2022
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/autograd-subgraph branch from f9ae28e to 562e715 Compare March 28, 2022 17:27
@shubhambhokare1 shubhambhokare1 changed the title [ONNX] [DO NOT REVIEW] API to inline prim::PythonOp for Autograd Function Export [ONNX] API to inline prim::PythonOp for Autograd Function Export Mar 28, 2022
@shubhambhokare1 shubhambhokare1 marked this pull request as ready for review March 28, 2022 19:36
@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 28, 2022
@BowenBao (Collaborator) left a comment


Thanks Shubham! This will make for a much better user experience when exporting autograd Functions.

@BowenBao (Collaborator)

Please rebase with master again to resolve the conflict.

@BowenBao BowenBao self-assigned this Mar 30, 2022
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/autograd-subgraph branch from 687300f to 6b4be40 Compare April 27, 2022 07:38
@BowenBao (Collaborator) commented May 3, 2022

Please rebase with master to resolve the conflict.

@soulitzer soulitzer removed their request for review May 12, 2022 20:30
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/autograd-subgraph branch from 11ae6e2 to 81c2eb4 Compare June 2, 2022 20:45
@shubhambhokare1 shubhambhokare1 requested a review from BowenBao June 6, 2022 18:16
@@ -590,17 +626,26 @@ static void _trace_post_record(
auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
node = unpacked;
}

std::vector<torch::jit::Node*> subgraph_trace_outputs;
Collaborator

ping

@BowenBao (Collaborator) commented Jun 7, 2022

Please rebase with master, then check the ONNX CI; the failures may be related.

       5 failed
         - test/onnx/test_utility_funs.py:1209 TestUtilityFuns_opset14.test_autograd_module_name
         - test/onnx/test_utility_funs.py:1178 TestUtilityFuns_opset14.test_autograd_onnx_fallthrough
         - test/onnx/test_utility_funs.py:1209 TestUtilityFuns_opset15.test_autograd_module_name
         - test/onnx/test_utility_funs.py:1178 TestUtilityFuns_opset15.test_autograd_onnx_fallthrough
         - test/onnx/test_utility_funs.py:489 TestUtilityFuns_opset15.test_constant_fold_div

@shubhambhokare1 shubhambhokare1 changed the title [ONNX] API to inline prim::PythonOp for Autograd Function Export [ONNX] Inline prim::PythonOp for Autograd Function Export Jun 8, 2022
@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/autograd-subgraph branch from 08876fa to c0c4af5 Compare June 9, 2022 18:08
namespace torch {
namespace jit {

void convertSubgraphToSubBlock(Block* block) {
Collaborator

nit: wonder if we can create the subblock directly for the autograd node? It might affect other existing jit passes, though.

Collaborator Author

Will create an issue to optimize this conversion.

Collaborator

Please comment with a link to the issue if possible...

@shubhambhokare1 shubhambhokare1 force-pushed the sbhokare/autograd-subgraph branch from 8310353 to 4f9ce2c Compare July 25, 2022 19:07
@shubhambhokare1 (Collaborator, Author)

@albanD Rebased; all CI is green except the Meta Internal-Only Changes Check.

@shubhambhokare1 (Collaborator, Author)

Any updates @albanD?

@BowenBao (Collaborator) commented Aug 3, 2022

@malfet please help import and land... I think this is dragging on too long, and @shubhambhokare1 might need another rebase now after it was rebased last time for merge.

@facebook-github-bot (Contributor)

@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@malfet (Contributor) commented Aug 3, 2022

@pytorchbot merge -f "Internal changes are OK"

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Aug 4, 2022
…74765)

Summary:
Add a flag (inline_autograd) to enable inline export of models consisting of autograd functions. Currently, this flag should only be used in TrainingMode.EVAL, not for training.

An example:

If a model containing ``autograd.Function`` is as follows
```
class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result
```
Then the model is exported as
```
                graph(%0 : Float):
                  %1 : Float = ^AutogradFunc(%0)
                  return (%1)
```
If inline_autograd is set to True, this will be exported as
```
                graph(%0 : Float):
                  %1 : Float = onnx::Exp(%0)
                  %2 : Float = onnx::Log(%1)
                  return (%2)
```

If one of the ops within the autograd function is not supported, that particular node is exported as-is, mirroring ONNX_FALLTHROUGH mode.

Fixes: #61813

Pull Request resolved: #74765
Approved by: https://github.com/BowenBao, https://github.com/malfet

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/95d873855e6b5a7b44e102d3aec81d6db3215c0f

Original Phabricator Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Reviewed By: george-qi, kit1980

Differential Revision: D37738323

fbshipit-source-id: 03ff75a809403b134c2a545952706cbeac8d0065
Labels
- cla signed
- Merged
- oncall: jit (Add this issue/PR to JIT oncall triage queue)
- onnx-needs-import (This PR is related to ONNX, but touches files outside of merge rule patterns, and hence needs import)
- open source
- release notes: onnx (torch.onnx related changes that should show up in the release notes)
- topic: new features (topic category)
- triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ONNX export fails for trivial torch.autograd.Function
8 participants