[dynamo] Initial support for nonstrict_trace
#146367
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146367
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7f34954 with merge base 6061664.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
@@ -97,7 +97,7 @@ def impl(func, in_spec, *flat_args):
    assert (
        isinstance(out, torch.Tensor)
        or isinstance(out, (tuple, list))
        and all(isinstance(x, torch.Tensor) for x in out)
```
This allows me to test a few more things.
Looks pretty good.
torch/_dynamo/variables/torch.py
Outdated
```python
input_vt = TupleVariable.build(
    tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
# TODO handle exception here, in case user forgets to pytree register?
out_vt = variables.UserFunctionVariable(pytree.tree_flatten).call_function(
    tx, [input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2
flat_args_vts, in_spec_vt = out_vt.items
assert isinstance(flat_args_vts, ListVariable)
```
We should probably just call the `flat_apply.to_graphable` function, which will handle the case where the user forgets to pytree-register the type.
Good call, I'll do that, and I'll likely capture the `RuntimeError` here so we can generate a `mark_traceable`-specific error message for users.
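For illustration, a rough sketch of that error-handling idea (the helper name and the simplified "graphable" check are hypothetical; the real implementation would go through `flat_apply`'s helpers):

```python
import torch
import torch.utils._pytree as pytree

# Types that can be passed straight into the graph as leaves (simplified).
_GRAPHABLE_LEAF_TYPES = (torch.Tensor, int, float, bool, str, type(None))


def _flatten_or_explain(args, kwargs):
    flat_args, in_spec = pytree.tree_flatten((args, kwargs))
    for leaf in flat_args:
        if not isinstance(leaf, _GRAPHABLE_LEAF_TYPES):
            raise RuntimeError(
                f"mark_traceable received an input leaf of type {type(leaf).__name__}; "
                "register the type with torch.utils._pytree so it can be flattened."
            )
    return flat_args, in_spec
```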
torch/_dynamo/decorators.py
Outdated
""" | ||
TODO doc | ||
""" | ||
assert callable(fn), "mark_traceable expects a callable" |
I wonder if there's an easy way to tell that this function is a global function (vs. some nested user function). IIRC some of the assumptions fall apart for nested user functions.
Yep lemme codify/document these constraints a bit.
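A small illustration (not Dynamo's actual check) of one way to tell a nested function apart from a module-level one:

```python
def _looks_like_nested_function(fn):
    # Functions defined inside another function carry "<locals>" in their
    # qualified name, e.g. "outer.<locals>.inner"; module-level ones do not.
    return "<locals>" in getattr(fn, "__qualname__", "")


def outer():
    def inner():
        pass
    return inner


print(_looks_like_nested_function(outer))    # False
print(_looks_like_nested_function(outer()))  # True
```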
torch/_dynamo/variables/torch.py
Outdated
```python
# - Maybe we can have `flat_apply` return the output spec, so that
#   Dynamo can unflatten and wrap the result.
#
```
Mm, yeah, I think `flat_apply` needs to return the output spec.
test/dynamo/test_decorators.py
Outdated
return x * d["a"] | ||
|
||
8000 def fn(x, d): | ||
d["a"] = 1 |
Should also try mutating a dict with a Tensor.
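Something along these lines (a standalone sketch with the eager backend, not the test as written in the PR), where the new dict value is itself a Tensor:

```python
import torch


def fn(x, d):
    d["a"] = x + 1          # mutate the dict with a Tensor value
    return x * d["a"]


x, d = torch.randn(4), {"a": torch.zeros(4)}
expected = x * (x + 1)
opt_fn = torch.compile(fn, fullgraph=True, backend="eager")
assert torch.allclose(opt_fn(x, d), expected)
assert torch.allclose(d["a"], x + 1)  # the mutation must be visible after the call
```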
```python
# Alternatives:
# 1. use `PyCodegen` to generate the bytecode, and invoke the function
#    to reconstruct the python objects.
```
PyCodegen seems like the most legit way to do this, yes.
Agree, but I'm debating a bit -- the codegen doesn't feel like an easily reusable component for our purposes here. Luckily we probably only need `reconstruct_to_python_object` for `TreeSpec`, which should be a much narrower space to support, and maybe we can get away with this `as_python_constant`-flavored impl. I'll think more.
name = f"{base}_{i}" | ||
|
||
raise AssertionError("unreachable") | ||
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope) |
This chunk is just refactoring name generation into this call.
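For readers unfamiliar with the helper, a hedged sketch of what a "unique name with respect to some scopes" utility conceptually does (the real Dynamo helper may differ in signature and sanitization details):

```python
import itertools
import re


def get_unique_name_wrt(prefix, *scopes):
    # Sanitize, then return the first candidate not present in any scope
    # (each scope is a dict-like namespace, e.g. nn_modules or global_scope).
    name = re.sub(r"[^a-zA-Z0-9_]", "_", prefix)
    if not any(name in scope for scope in scopes):
        return name
    for i in itertools.count():
        candidate = f"{name}_{i}"
        if not any(candidate in scope for scope in scopes):
            return candidate


print(get_unique_name_wrt("my_mod", {"my_mod": object()}, {}))  # my_mod_0
```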
torch/utils/_pytree.py
Outdated
```python
@dataclasses.dataclass(frozen=True, repr=False)
class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
    type: Any = dataclasses.field(default=None, init=False)
    context: Context = dataclasses.field(default=None, init=False)
    children_specs: list["TreeSpec"] = dataclasses.field(
        default_factory=list, init=False
    )

    def __post_init__(self) -> None:
        # Override `__post_init__` for `num_leaves` derivation.
```
Basically making `LeafSpec` a true dataclass whose `__init__` signature accords with `dataclasses.fields(leaf_spec_instance)`. This avoids overspecializing `FrozenDataclassVariable.as_python_constant`.
Could you explain a bit more why we needed to change LeafSpec?
Yep, so without this change, the implementation in `FrozenDataclassVariable.as_python_constant` (which is basically a more correct variation of `FrozenDataclassVariable.as_proxy`) wouldn't work for `LeafSpec`, because the method assumes a correlation between `dataclasses.fields(obj)` and `obj.__class__.__init__`:

```python
args = []
kwargs = {}
for field in fields(self.value):
    if field.init:
        data = self.fields[field.name].as_python_constant()
        if getattr(field, "kw_only", False):
            kwargs[field.name] = data
        else:
            args.append(data)
ctor = self.python_type()
return ctor(*args, **kwargs)
```

Specifically, in the old `LeafSpec`, `__init__` takes no arguments, but `fields(obj)` would give the `type`, `context` and `children_specs` fields with `init=True` (inherited from `TreeSpec`), so the impl above would mistakenly pass too many positional args to `__init__`.

So we could either

1. specialize `FrozenDataclassVariable.as_python_constant` to `LeafSpec` and `TreeSpec` by grabbing the relevant fields and invoking the constructor manually, or
2. do what I did in this patch, which I chose because it also makes it easier to reason about the `LeafSpec` dataclass in general (with the `dataclasses.fields` method).

Lmk what you think, I have no strong preference here.
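For anyone following along, a self-contained illustration of that fields-vs-`__init__` mismatch, using toy classes rather than the real pytree ones:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Spec:
    type: object
    context: object
    children: list


# Old style: __init__ takes no args, yet dataclasses.fields() still reports
# the inherited init=True fields, so "rebuild from fields" passes too many args.
class OldLeaf(Spec):
    def __init__(self) -> None:
        super().__init__(None, None, [])


print([f.name for f in dataclasses.fields(OldLeaf()) if f.init])
# ['type', 'context', 'children'], but OldLeaf(None, None, []) raises TypeError


# New style: declare the fields with init=False so fields() and __init__ agree.
@dataclasses.dataclass(frozen=True)
class NewLeaf(Spec):
    type: object = dataclasses.field(default=None, init=False)
    context: object = dataclasses.field(default=None, init=False)
    children: list = dataclasses.field(default_factory=list, init=False)


assert [f.name for f in dataclasses.fields(NewLeaf()) if f.init] == []
```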
SGTM
test/dynamo/test_flat_apply.py
Outdated
```python
@torch.compile(fullgraph=True, backend=backend)
def fn(x, y):
    t0 = x + 1
    t1 = func(x, y, t0)
```
I'll update this example to include a pytree-registered input to make it less boring.
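For example, a hedged sketch of such a pytree-registered input (a toy `Point` container; the test in the PR may use a different type):

```python
import torch
import torch.utils._pytree as pytree


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


# Register Point so pytree can flatten it into tensor leaves plus a TreeSpec.
pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),                 # flatten: (children, context)
    lambda children, _context: Point(*children),  # unflatten
)

leaves, spec = pytree.tree_flatten(Point(torch.ones(2), torch.zeros(2)))
print(len(leaves), spec)  # 2 tensor leaves, and a spec describing Point
```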
```python
    return t0 * t2


x, y = torch.randn(10), torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
```
Can we have a few `self.assertExpectedInline` tests? Easier to reason about when you have a correct answer in front of you.
@zou3519 requested this and I added one to `test_flat_apply.py` as suggested. My main concern about those tests is that the captured graph is kinda an impl detail and subject to change.
We should put all the assertExpectedInline tests into one file (test_flat_apply). The point of them is that it's easy to modify them if something does change (use the envvar)
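For reference, a minimal sketch of the `assertExpectedInline` workflow (built on the standalone `expecttest` package that PyTorch's test suite uses; the actual tests check the captured graph rather than a leaf count like this toy example):

```python
import expecttest
import torch.utils._pytree as pytree


class PytreeSmokeTest(expecttest.TestCase):
    def test_num_leaves(self):
        _, spec = pytree.tree_flatten([1, 2])
        # Re-running with EXPECTTEST_ACCEPT=1 rewrites the inline expected
        # string in place whenever the actual output changes.
        self.assertExpectedInline(str(spec.num_leaves), """2""")
```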
This is good stuff. Looks very close to complete. We should make the docstring better for `mark_traceable`, but everything else is good.
```python
# This line allows us to reuse much of the `allow_in_graph` impl.
trace_rules._allowed_callable_ids.add(id(wrapped))
```
`id(obj)` can be reused if `obj` is deleted. We can still do this, but we need to install a weakref callback to remove the ID if `obj` is deleted.
Yep, created #147777 to track this; I'll try fixing it for all the relevant decorators in a subsequent patch.
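For reference, a minimal sketch (hypothetical names) of the weakref-callback approach suggested above:

```python
import weakref

_allowed_callable_ids = set()


def _register_allowed(fn):
    _allowed_callable_ids.add(id(fn))
    # Drop the recorded id once `fn` is garbage collected, so a recycled id
    # from an unrelated object can't be misidentified as allowed.
    weakref.finalize(fn, _allowed_callable_ids.discard, id(fn))
    return fn
```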
Changed the title from `mark_traceable` to `nonstrict_trace`.
Address feedback and rename
torch/utils/_pytree.py
Outdated
```python
@@ -964,11 +964,16 @@ def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        return unflatten_fn(child_pytrees, self.context)


@dataclasses.dataclass(frozen=True, repr=False)
```
Does this affect the inherited `__eq__` in any way?
`@dataclasses.dataclass` has `eq=True` by default, and will create a `LeafSpec.__eq__` that is semantically equivalent to `TreeSpec.__eq__`. Lemme use `eq=False` explicitly here to demand reuse. In theory, one could also provide a faster `LeafSpec.__eq__` that's type-based only (rather than comparing individual fields, as the generated `__eq__` does).
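A tiny illustration (toy classes, not the pytree ones) of what `eq=False` does here:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Base:
    a: int = 0


@dataclasses.dataclass(frozen=True, eq=False)
class Child(Base):
    pass


assert Child.__eq__ is Base.__eq__   # no new __eq__ was generated for Child
assert Child() == Child()            # comparison still works via the inherited one
```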
thanks for checking
This looks great. I had some last minor comments, please read
Starting merge as part of PR stack under #147572
…46950) This patch removes some duplicated name generation logic in Dynamo. Pull Request resolved: #146950 Approved by: https://github.com/zou3519 ghstack dependencies: #146714, #146367
As title, also see 1. new test `test_nonstrict_trace_on_method` for example. 2. newly added comments for why we need special treatment on methods. Pull Request resolved: #147571 Approved by: https://github.com/zou3519 ghstack dependencies: #146714, #146367, #146950
…`-ed function (#147572) As title. Without this patch we get the following error: Tweaking the `allow_non_fake_inputs` flag on tensor mode doesn't quite work for AOTAutograd, which also needs to fake-tensor-propagate the `nonstrict_trace`-ed function, but that's _after_ Dynamo has handled the `nonstrict_trace` processing and put the `flat_apply(...)` node into the graph. So we can't easily temporarily enable the `allow_non_fake_inputs` flag on the current fake mode when AOTAutograd processes a `flat_apply` node from Dynamo's `nonstrict_trace` handling. After discussing with zou3519, I decided to add a global `FakeTensorTLS` that contains an `allow_non_fake_inputs_override` flag, and patch the `nonstrict_trace`-ed function to temporarily tweak this flag during its execution. Pull Request resolved: #147572 Approved by: https://github.com/zou3519 ghstack dependencies: #146714, #146367, #146950, #147571
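A hedged sketch of that TLS-override pattern; the names mirror the description above (`FakeTensorTLS`, `allow_non_fake_inputs_override`) but this is an illustration, not the actual `torch._subclasses` code:

```python
import threading
from contextlib import contextmanager


class _FakeTensorTLS(threading.local):
    # None means "defer to the fake mode's own allow_non_fake_inputs flag".
    allow_non_fake_inputs_override = None


fake_tensor_tls = _FakeTensorTLS()


@contextmanager
def _allow_non_fake_inputs():
    prev = fake_tensor_tls.allow_non_fake_inputs_override
    fake_tensor_tls.allow_non_fake_inputs_override = True
    try:
        yield  # run the nonstrict_trace-ed function with the override active
    finally:
        fake_tensor_tls.allow_non_fake_inputs_override = prev
```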
Stack from ghstack (oldest at bottom):
- `nonstrict_trace`-ed function #147572
- `nonstrict_trace` on class method #147571
- `get_unique_name_wrt` helper when applicable #146950
- `nonstrict_trace` #146367
- `flat_apply` #146714

Context
1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)
Summary
This patch adds a `torch._dynamo.nonstrict_trace` decorator, which currently is an enhanced version of `torch._dynamo.allow_in_graph` (see the docstring for their differences). Specifically, this patch focuses on the UI and functionality prototyping/plumbing.

The main enhancement is supporting more input types, and the implementation challenge lies in reconstructing the input objects from Dynamo `VariableTracker` (while accounting for buffered side-effects and guards). This patch takes a middle ground (simple implementation with a bit of user labor), by

1. asking the user to provide pytree registration for non-proxy-able input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for buffered side-effects and guards automatically),
3. and passing in the TreeSpec as a graph attribute constant into `torch._higher_order_ops.flat_apply` (which unflattens the inputs and invokes the underlying function).
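For concreteness, a minimal usage sketch (assuming an eager backend and plain tensor/float inputs; the exact constraints are spelled out in the decorator's docstring):

```python
import torch
import torch._dynamo


# Hedged example: `renorm` is an arbitrary user function; applying the decorator
# at module level tells Dynamo to capture its call non-strictly via flat_apply.
@torch._dynamo.nonstrict_trace
def renorm(x, scale):
    return x / (x.norm() + scale)


@torch.compile(fullgraph=True, backend="eager")
def fn(x):
    return renorm(x, 1e-6) + 1


print(fn(torch.randn(4)))
```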
Next Steps
In subsequent patches, we will try to support the following:
- annotating on class method
- reads to global tensors
- inputs that contain `pytree.register_constant`-ed instances
- function as input
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames