[hop] Support more output types for `flat_apply` #146714


Closed
StrongerXi wants to merge 11 commits from the gh/StrongerXi/83/head branch

Conversation

@StrongerXi (Contributor) commented Feb 7, 2025

Stack from ghstack (oldest at bottom):

This patch enables `flat_apply` to support certain non-Tensor output
types, such as containers and graphable types. This will in turn enable
the upcoming `mark_traceable` to support more output types.

The patch also exposes a `func_to_graphable` helper rather than having
users call the lower-level `pytree.flatten(ConstantFunction(...))`.
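For intuition, here is a minimal, self-contained sketch of the flatten/unflatten round trip that `flat_apply` performs, now with a container output. It uses only `torch.utils._pytree`; `fn` and its dict output are illustrative stand-ins, not code from this PR.

```python
import torch
import torch.utils._pytree as pytree

def fn(x, scale):
    # A dict output: the kind of container output this PR enables.
    return {"scaled": x * scale, "dim": x.dim()}

args = (torch.randn(3), 2.0)
# flat_apply conceptually receives flat leaves plus a TreeSpec constant...
flat_args, in_spec = pytree.tree_flatten(args)
# ...unflattens them and invokes the underlying function...
out = fn(*pytree.tree_unflatten(flat_args, in_spec))
# ...and flattens the container output so only graphable leaves cross
# the graph boundary; the caller reassembles it via the output spec.
flat_out, out_spec = pytree.tree_flatten(out)
result = pytree.tree_unflatten(flat_out, out_spec)
assert result["dim"] == 1 and result["scaled"].shape == (3,)
```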

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

pytorch-bot bot commented Feb 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146714

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 86a8231 with merge base 6061664:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

StrongerXi added 3 commits February 13, 2025 23:37
@pytorchmergebot (Collaborator) commented:

Starting merge as part of PR stack under #147572

pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
## Context
> **Note:** `mark_traceable` was renamed to `nonstrict_trace` after
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> `nonstrict` notion, and (2) it more clearly conveys the intended behavior.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary
This patch adds a `torch._dynamo.nonstrict_trace` decorator, which
currently is an enhanced version of `torch._dynamo.allow_in_graph` (see
docstring for their differences). Specifically, this patch focuses on
the UI and functionality prototyping/plumbing.

The main enhancement is supporting more input types; the implementation
challenge lies in reconstructing the input objects from Dynamo
`VariableTracker`s (while accounting for buffered side effects and
guards). This patch takes a middle ground (a simple implementation with
a bit of user labor, sketched in code after this list), by
1. asking the user to provide pytree registration for non-proxy-able
   input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for
   buffered side-effects and guards automatically),
3. and passing in the TreeSpec as a graph attribute constant into
   `torch._higher_order_ops.flat_apply` (which unflattens the inputs and
   invokes the underlying function).
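As a hedged illustration of steps 1-3 (not code from this PR), the sketch below registers a toy `Point` class with pytree and applies the decorator. `Point` is hypothetical, and the `torch._dynamo.nonstrict_trace` spelling follows this summary rather than a documented API.

```python
import torch
import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

# Step 1: user-provided pytree registration for the non-proxy-able type.
pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),                 # flatten -> (children, context)
    lambda children, _context: Point(*children),  # unflatten
)

@torch._dynamo.nonstrict_trace
def trig(p):
    return torch.sin(p.x) + torch.cos(p.y)

@torch.compile(fullgraph=True)
def f(p):
    # Steps 2 and 3 happen inside Dynamo: it traces the pytree flattening
    # and passes the TreeSpec as a constant into flat_apply.
    return trig(p) * 2

print(f(Point(torch.randn(3), torch.randn(3))))
```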

## Next Steps
In subsequent patches, we will try to support the following:
- annotating class methods
- reads of global tensors
- inputs that contain `pytree.register_constant`-ed instances
- functions as inputs
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` instances as inputs

Pull Request resolved: #146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
…46950)

This patch removes some duplicated name generation logic in Dynamo.

Pull Request resolved: #146950
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367
pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
As title; also see
1. the new test `test_nonstrict_trace_on_method` for an example, and
2. the newly added comments for why methods need special treatment.
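A hedged guess at what the method case looks like, inferred from the test name above; the `Model` class is illustrative and the decorator spelling assumes `torch._dynamo.nonstrict_trace` as in the earlier summary.

```python
import torch

class Model(torch.nn.Module):
    @torch._dynamo.nonstrict_trace
    def helper(self, x):
        # The bound `self` argument is what requires the special
        # treatment on methods mentioned above.
        return torch.sin(x)

    def forward(self, x):
        return self.helper(x) + 1

out = torch.compile(Model(), fullgraph=True)(torch.randn(3))
```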

Pull Request resolved: #147571
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950
pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
…`-ed function (#147572)

As title. Without this patch we get the following error:

Tweaking the `allow_non_fake_inputs` flag on tensor mode doesn't quite
work for AOTAutograd, which also needs to fake-tensor-propagate the
`nonstrict_trace`-ed function, but that's _after_ Dynamo has handled the
`nonstrict_trace` processing and put the `flat_apply(...)` node into the graph.

So we can't easily and temporarily enable the `allow_non_fake_inputs`
flag on the current fake mode when AOTAutograd processes a `flat_apply`
node from Dynamo's `nonstrict_trace` handling. After discussing with
zou3519, I decided to add a global `FakeTensorTLS` that contains an
`allow_non_fake_inputs_override` flag, and to patch the
`nonstrict_trace`-ed function to temporarily tweak this flag during its
execution (sketched below).
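An illustrative sketch of the TLS-override pattern described here, not the actual PyTorch implementation; the `FakeTensorTLS` and `allow_non_fake_inputs_override` names come from the text.

```python
import threading
from contextlib import contextmanager

class FakeTensorTLS(threading.local):
    # None means "defer to the fake mode's own allow_non_fake_inputs".
    def __init__(self):
        self.allow_non_fake_inputs_override = None

fake_tensor_tls = FakeTensorTLS()

@contextmanager
def override_allow_non_fake_inputs(value=True):
    # Temporarily tweak the global flag while the nonstrict_trace-ed
    # function runs, then restore the previous value.
    prev = fake_tensor_tls.allow_non_fake_inputs_override
    fake_tensor_tls.allow_non_fake_inputs_override = value
    try:
        yield
    finally:
        fake_tensor_tls.allow_non_fake_inputs_override = prev
```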

Pull Request resolved: #147572
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950, #147571
github-actions bot deleted the gh/StrongerXi/83/head branch March 30, 2025 02:17