[dynamo] Initial support for `nonstrict_trace` by StrongerXi · Pull Request #146367 · pytorch/pytorch · GitHub

[dynamo] Initial support for nonstrict_trace #146367


Closed
wants to merge 15 commits into from

Conversation

StrongerXi
Contributor
@StrongerXi StrongerXi commented Feb 4, 2025

Stack from ghstack (oldest at bottom):

Context

Note: mark_traceable got renamed to nonstrict_trace after
offline discussion. The reasons are (1) it aligns with torch.export's
nonstrict notion, and (2) it is more definitive about the intended behavior.

  1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
  2. [Dynamo graph representation with torch._higher_order_ops.flat_apply](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

Summary

This patch adds a torch._dynamo.nonstrict_trace decorator, which
currently is an enhanced version of torch._dynamo.allow_in_graph (see
docstring for their differences). Specifically, this patch focuses on
the UI and functionality prototyping/plumbing.

The main enhancement is supporting more input types, and the
implementation challenge lies in reconstructing the input objects from
Dynamo VariableTracker (while accounting for buffered side-effects and
guards). This patch takes a middle ground (a simple implementation with a
bit of user labor; a usage sketch follows the list below), by

  1. asking the user to provide pytree registration for non-proxy-able
    input types,
  2. letting Dynamo trace through pytree_flatten (which accounts for
    buffered side-effects and guards automatically),
  3. and passing in the TreeSpec as a graph attribute constant into
    torch._higher_order_ops.flat_apply (which unflattens the inputs and
    invokes the underlying function).
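
To make the intended workflow concrete, here is a minimal usage sketch. It is not taken from the PR's tests; the Point class, its pytree registration, and trace_me are illustrative assumptions.

import torch
import torch._dynamo
import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# Step 1: register the non-proxy-able input type so Dynamo can flatten/unflatten it.
pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),                 # flatten: (children, context)
    lambda children, _context: Point(*children),  # unflatten
)

# Step 2: mark the function; its body is traced non-strictly rather than by Dynamo.
@torch._dynamo.nonstrict_trace
def trace_me(p):
    return p.x * p.y

# Step 3: the decorated function can be called from compiled code with a `Point`
# argument; Dynamo flattens the inputs and emits a `flat_apply` call in the graph.
@torch.compile(fullgraph=True)
def fn(x, y):
    return trace_me(Point(x, y)) + 1

fn(torch.randn(3), torch.randn(3))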

Next Steps

In subsequent patches, we will try to support the following:

  • annotating class methods
  • reads of global tensors
  • inputs that contain pytree.register_constant-ed instances
  • functions as inputs
  • more output types (e.g., any pytree-registered type)
  • torch.nn.Module instances as inputs

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

[ghstack-poisoned]
@StrongerXi StrongerXi requested a review from zou3519 as a code owner February 4, 2025 00:05
pytorch-bot bot commented Feb 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146367

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7f34954 with merge base 6061664:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

StrongerXi added a commit that referenced this pull request Feb 4, 2025
ghstack-source-id: 50c635d
Pull Request resolved: #146367
@@ -97,7 +97,7 @@ def impl(func, in_spec, *flat_args):
    assert (
        isinstance(out, torch.Tensor)
        or isinstance(out, (tuple, list))
        and all(isinstance(x, torch.Tensor) for x in out)
Contributor Author
This allows me to test a bit more things.

Contributor
@zou3519 zou3519 left a comment

looks pretty good

Comment on lines 1021 to 1031
input_vt = TupleVariable.build(
    tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
# TODO handle exception here, in case user forgets to pytree register?
out_vt = variables.UserFunctionVariable(pytree.tree_flatten).call_function(
    tx, [input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
flat_args_vts, in_spec_vt = out_vt.items
assert isinstance(flat_args_vts, ListVariable)

Contributor

We should probably just call the flat_apply.to_graphable function, which will handle the case where the user forgets to pytree-register

Contributor Author

Good call, I'll do that, and I'll likely capture the RuntimeError here so we can generate a mark_traceable-specific error message for users.
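
A rough sketch of that idea (assumed shape, not the final implementation; to_graphable is the helper from torch._higher_order_ops.flat_apply mentioned above, while the wrapper name and error message below are made up):

from torch._higher_order_ops.flat_apply import to_graphable

def flatten_inputs_or_raise(args, kwargs, fn_name):
    try:
        # Flattens (args, kwargs) into graphable leaves plus an input TreeSpec.
        flat_args, in_spec = to_graphable((args, kwargs))
    except RuntimeError as e:
        raise RuntimeError(
            f"`nonstrict_trace`-ed function `{fn_name}` received an input whose type "
            "is not registered with `torch.utils._pytree`; please register it, e.g. "
            "via `pytree.register_pytree_node`."
        ) from e
    return flat_args, in_spec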

"""
TODO doc
"""
assert callable(fn), "mark_traceable expects a callable"
Contributor

I wonder if there's an easy way to tell that this function is a global function (vs some nested user function). IIRC some of the assumptions fall apart for nested user functions.

Contributor Author

Yep lemme codify/document these constraints a bit.

Comment on lines 1060 to 1062
# - Maybe we can have `flat_apply` return the output spec, so that
# Dynamo can unflatten and wrap the result.
#
Contributor

mm yeah I think flat_apply needs to return the output spec

    return x * d["a"]

def fn(x, d):
    d["a"] = 1
Contributor

should also try mutating a dict with a Tensor
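
Sketched, the suggested variant might look like this (assumed shape, not the PR's actual test; `func` is the `nonstrict_trace`-ed function from the snippet above):

def fn(x, d):
    d["a"] = x + 1   # Tensor-valued dict mutation that Dynamo must replay
    return func(x, d)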

Comment on lines 1308 to 1310
# Alternatives:
# 1. use `PyCodegen` to generate the bytecode, and invoke the function
# to reconstruct the python objects.
Contributor

PyCodegen seems like the most legit way to do this yes

Contributor Author

Agree, but I'm debating a bit -- the codegen doesn't feel like an easily reusable component for our purposes here.

Luckily we probably only need reconstruct_to_python_object for TreeSpec, which should be a much narrower space to support, and maybe we can get away with this as_python_constant-flavor impl. I'll think more.

@anijain2305 anijain2305 self-requested a review February 4, 2025 17:47
[ghstack-poisoned]
StrongerXi added a commit that referenced this pull request Feb 7, 2025
TODO:
1. add fx graph test (to check `flat_apply`)
2. add test for user-facing require-pytree-registration error
3. document constraints of `mark_traceable`
4. document why `reconstruct_to_python_object` impl is okay for input spec.

ghstack-source-id: 762f004
Pull Request resolved: #146367
[ghstack-poisoned]
StrongerXi added a commit that referenced this pull request Feb 11, 2025
ghstack-source-id: ae344be
Pull Request resolved: #146367
@StrongerXi StrongerXi changed the title from "[dynamo][EXPERIMENT] Prototype for mark_traceable" to "[dynamo] Initial support for mark_traceable" Feb 11, 2025
[ghstack-poisoned]
StrongerXi added a commit that referenced this pull request Feb 11, 2025
ghstack-source-id: 515ad42
Pull Request resolved: #146367
name = f"{base}_{i}"

raise AssertionError("unreachable")
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
Contributor Author
@StrongerXi StrongerXi Feb 11, 2025

This chunk is just refactoring name generation into this call.
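
For context, a sketch of what such a helper does (an assumed implementation, inferred from the excerpt above, not the actual Dynamo code):

import itertools

def get_unique_name_wrt(base, *namespaces):
    # Return `base` if it is unused in every given namespace, otherwise append
    # a numeric suffix until the name is unique.
    if all(base not in ns for ns in namespaces):
        return base
    for i in itertools.count():
        name = f"{base}_{i}"
        if all(name not in ns for ns in namespaces):
            return name
    raise AssertionError("unreachable")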

[ghstack-poisoned]
StrongerXi added a commit that referenced this pull request Feb 11, 2025
ghstack-source-id: 831063e
Pull Request resolved: #146367
Comment on lines 967 to 976
@dataclasses.dataclass(frozen=True, repr=False)
class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
    type: Any = dataclasses.field(default=None, init=False)
    context: Context = dataclasses.field(default=None, init=False)
    children_specs: list["TreeSpec"] = dataclasses.field(
        default_factory=list, init=False
    )

    def __post_init__(self) -> None:
        # Override `__post_init__` for `num_leaves` derivation.
Contributor Author
@StrongerXi StrongerXi Feb 11, 2025

Basically making LeafSpec a true dataclass whose __init__ signature accords with dataclasses.fields(leaf_spec_instance). This avoids overspecializing FrozenDataclassVariable.as_python_constant.

Contributor

Could you explain a bit more why we needed to change LeafSpec?

Contributor Author

Yep, so without this change, the implementation in FrozenDataclassVariable.as_python_constant (which is basically a more correct variation of FrozenDataclassVariable.as_proxy) wouldn't work for LeafSpec, because the method assumes a correlation between dataclasses.fields(obj) and obj.__class__.__init__:

args = []
kwargs = {}
for field in fields(self.value):
    if field.init:
        data = self.fields[field.name].as_python_constant()
        if getattr(field, "kw_only", False):
            kwargs[field.name] = data
        else:
            args.append(data)
                                                       
ctor = self.python_type()
return ctor(*args, **kwargs)

Specifically, in the old LeafSpec, __init__ takes no arguments, but fields(obj) would report the type, context and children_specs fields with init=True (inherited from TreeSpec), so the impl above would mistakenly pass too many positional args to __init__.

So we could either

  1. specialize FrozenDataclassVariable.as_python_constant to LeafSpec and TreeSpec by grabbing the relevant fields and invoking the constructor manually, or
  2. do what I did in this patch, which I chose because it also makes it easier to reason about the LeafSpec dataclass in general (with the dataclasses.fields method).

Lmk what you think, I have no strong preference here.
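
For anyone following along, a minimal standalone illustration of the mismatch (plain Python, not the PyTorch code):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Base:
    a: int
    b: int

class Child(Base):
    # Like the old LeafSpec: __init__ takes no arguments.
    def __init__(self) -> None:
        super().__init__(0, 0)

c = Child()
# fields(c) still reports `a` and `b` with init=True (inherited from Base), so a
# reconstruct-by-constructor approach would pass too many positional args:
args = [getattr(c, f.name) for f in dataclasses.fields(c) if f.init]
# Child(*args)  ->  TypeError: __init__() takes 1 positional argument but 3 were given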

Contributor

SGTM

[ghstack-poisoned]
@torch.compile(fullgraph=True, backend=backend)
def fn(x, y):
    t0 = x + 1
    t1 = func(x, y, t0)
Contributor Author

I'll update this example to include a pytree-registered input to make it less boring.

[ghstack-poisoned]
    return t0 * t2

x, y = torch.randn(10), torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
Contributor

Can we have a few self.assertExpectedInline tests? Easier to reason about when you have a correct answer in front of you.

Contributor Author

@zou3519 requested this and I added one to test_flat_apply.py as suggested. My main concern about those tests is that the captured graph is kinda an impl detail and subject to change.

Contributor

We should put all the assertExpectedInline tests into one file (test_flat_apply). The point of them is that it's easy to modify them if something does change (use the envvar)
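
For reference, a sketch of the expect-test shape being discussed. This is a fragment assumed to live inside a TestCase method, reusing the `fn`, `x`, `y` from the snippet above; the expected string is illustrative and would be regenerated with EXPECTTEST_ACCEPT=1 when the captured graph legitimately changes.

from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm

backend = EagerAndRecordGraphs()
torch.compile(fn, backend=backend, fullgraph=True)(x, y)
graph_str = normalize_gm(backend.graphs[0].print_readable(print_output=False))
self.assertExpectedInline(
    graph_str,
    """\
class GraphModule(torch.nn.Module):
    def forward(self, ...):
        ...
""",
)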

Contributor
@anijain2305 anijain2305 left a comment

This is good stuff. Looks very close to complete. We should make the docstring better for mark_traceable, but everything else is good.

Comment on lines +186 to +187
# This line allows us to reuse much of the `allow_in_graph` impl.
trace_rules._allowed_callable_ids.add(id(wrapped))
Contributor

id(obj) can be reused if obj is deleted. We can still do this, but need to install a weakref callback to remove the ID if obj is deleted.
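
A sketch of that weakref-callback fix (assumed shape; the actual follow-up is tracked in #147777):

import weakref

def add_allowed_id(wrapped, allowed_ids):
    key = id(wrapped)
    allowed_ids.add(key)
    # When `wrapped` is garbage collected, drop its id so a later object that
    # happens to reuse the same id isn't misclassified as allowed.
    weakref.finalize(wrapped, allowed_ids.discard, key)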

Contributor Author

Yep created #147777 to track this, I'll try fixing it for all the relevant decorators in a subsequent patch.

[ghstack-poisoned]
@StrongerXi StrongerXi changed the title from "[dynamo] Initial support for mark_traceable" to "[dynamo] Initial support for nonstrict_trace" Feb 25, 2025
@StrongerXi
Contributor Author

Address feedback and rename mark_traceable to nonstrict_trace

@@ -964,11 +964,16 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
return unflatten_fn(child_pytrees, self.context)


@dataclasses.dataclass(frozen=True, repr=False)
Contributor

Does this affect the inherited __eq__ in any way?

Contributor Author

@dataclasses.dataclass has eq=True by default, and would generate a LeafSpec.__eq__ that is semantically equivalent to TreeSpec.__eq__. Lemme use eq=False explicitly here so the inherited one is reused. In theory one could also provide a faster, type-check-only LeafSpec.__eq__ (rather than the generated __eq__, which compares individual fields).
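
A tiny illustration of that point (plain Python, not the pytree code): with eq=False the subclass keeps the inherited __eq__ instead of getting a generated one.

import dataclasses

@dataclasses.dataclass(frozen=True)
class Spec:
    kind: str

@dataclasses.dataclass(frozen=True, eq=False)
class Leaf(Spec):
    pass

assert Leaf.__eq__ is Spec.__eq__  # the inherited comparison is reused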

Contributor

thanks for checking

Contributor
@zou3519 zou3519 left a comment

This looks great. I had some last minor comments, please read

[ghstack-poisoned]
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #147572

pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
…46950)

This patch removes some duplicated name generation logic in Dynamo.

Pull Request resolved: #146950
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367
pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
As title, also see
1. new test `test_nonstrict_trace_on_method` for example.
2. newly added comments for why we need special treatment on methods.

Pull Request resolved: #147571
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950
pytorchmergebot pushed a commit that referenced this pull request Feb 26, 2025
…`-ed function (#147572)

As title. Without this patch we get the following error:

Tweaking the `allow_non_fake_inputs` flag on the tensor mode doesn't quite
work for AOTAutograd, which also needs to fake-tensor-propagate the
`nonstrict_trace`-ed function, but that's _after_ Dynamo has handled the
`nonstrict_trace` processing and put the `flat_apply(...)` node into the graph.

So we can't easily enable the `allow_non_fake_inputs` flag temporarily on the
current fake mode when AOTAutograd processes a `flat_apply` node from Dynamo's
`nonstrict_trace` handling. After discussing with zou3519, I decided to add a
global `FakeTensorTLS` that contains an `allow_non_fake_inputs_override` flag,
and to patch the `nonstrict_trace`-ed function to temporarily tweak this flag
during its execution.

Pull Request resolved: #147572
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950, #147571
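
Conceptually, the override works like the sketch below. FakeTensorTLS and allow_non_fake_inputs_override are named in the commit message; the helper names and attribute access here are assumptions rather than the actual PyTorch implementation.

from contextlib import contextmanager

@contextmanager
def _allow_non_fake_inputs(tls):
    # Temporarily flip the thread-local override, then restore it.
    prev = tls.allow_non_fake_inputs_override
    tls.allow_non_fake_inputs_override = True
    try:
        yield
    finally:
        tls.allow_non_fake_inputs_override = prev

def wrap_nonstrict_fn(fn, tls):
    def wrapper(*args, **kwargs):
        # Real (non-fake) tensor inputs are tolerated only while the
        # nonstrict_trace-ed function runs under fake-tensor propagation.
        with _allow_non_fake_inputs(tls):
            return fn(*args, **kwargs)
    return wrapper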
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Feb 27, 2025
X-link: pytorch/pytorch#146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714

Reviewed By: wdvr

Differential Revision: D70270803

fbshipit-source-id: d37fb2d7f79e2f7cc9e200e0c796f7087fe093bf
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
Pull Request resolved: #146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
Pull Request resolved: #146950
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
Pull Request resolved: #147571
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
Pull Request resolved: #147572
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950, #147571
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Pull Request resolved: pytorch#146367
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#146714
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Pull Request resolved: pytorch#146950
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#146714, pytorch#146367
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Pull Request resolved: pytorch#147571
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#146714, pytorch#146367, pytorch#146950
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Pull Request resolved: pytorch#147572
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#146714, pytorch#146367, pytorch#146950, pytorch#147571
@github-actions github-actions bot deleted the gh/StrongerXi/81/head branch March 30, 2025 02:17
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants