[dynamo] Initial support for `nonstrict_trace` (#146367) · pytorch/pytorch@301ce65 · GitHub

Commit 301ce65

StrongerXia authored and ditew01 committed
[dynamo] Initial support for nonstrict_trace (#146367)
## Context

> **Note:** `mark_traceable` got renamed to `nonstrict_trace` after
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> `nonstrict` notion, and (2) it's more definitive in behavior suggestion.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary

This patch adds a `torch._dynamo.nonstrict_trace` decorator, which currently is an enhanced version of `torch._dynamo.allow_in_graph` (see docstring for their differences). Specifically, this patch focuses on the UI and functionality prototyping/plumbing.

The main enhancement is supporting more input types, and the implementation challenge lies in reconstructing the input objects from Dynamo `VariableTracker` (while accounting for buffered side-effects and guards). This patch takes a middle ground (simple implementation with a bit of user labor), by

1. asking the user to provide pytree registration for non-proxy-able input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for buffered side-effects and guards automatically),
3. and passing in the TreeSpec as a graph attribute constant into `torch._higher_order_ops.flat_apply` (which unflattens the inputs and invokes the underlying function).

## Next Steps

In subsequent patches, we will try to support the following:
- annotating on class methods
- reads to global tensors
- inputs that contain `pytree.register_constant`-ed instances
- functions as inputs
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs

Pull Request resolved: #146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
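For readers skimming the diff, here is a minimal usage sketch distilled from the new test in `test/dynamo/test_flat_apply.py`; the `Point` class and `scale` function are illustrative, not part of the patch:

```python
import torch
import torch.utils._pytree as pytree

# A user-defined type that Dynamo cannot proxy directly.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# (1) Register the non-proxy-able type with pytree.
pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), ()),
    lambda xy, _: Point(xy[0], xy[1]),
)

# (2) Mark the function; its body is traced non-strictly, so even a
# graph break inside it would not split the compiled graph.
@torch._dynamo.nonstrict_trace
def scale(p):
    return p.x * p.y

# (3) Dynamo flattens the inputs via pytree and emits a single
# torch.ops.higher_order.flat_apply call for `scale` in the graph.
@torch.compile(fullgraph=True)
def fn(x, y):
    return scale(Point(x, y))

fn(torch.randn(4), torch.randn(4))
```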
1 parent 4bededa commit 301ce65

File tree

12 files changed: +779 -34 lines


test/dynamo/test_decorators.py

Lines changed: 433 additions & 0 deletions
Large diffs are not rendered by default.

test/dynamo/test_flat_apply.py

Lines changed: 67 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch
 import torch._dynamo.test_case
 import torch.utils._pytree as pytree
+from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
 from torch._higher_order_ops.flat_apply import (
     flat_apply,
     func_to_graphable,
@@ -83,6 +84,72 @@ def f(a, b):
         result = flat_apply(func_spec, in_spec, *flat_args)
         self.assertEqual(result, f(*args, **kwargs))
 
+    def test_nonstrict_trace_dynamo_graph(self):
+        class Point:
+            x: torch.Tensor
+            y: torch.Tensor
+
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        class PointTensor:
+            p: Point
+            t: torch.Tensor
+
+            def __init__(self, p, t):
+                self.p = p
+                self.t = t
+
+        torch.utils._pytree.register_pytree_node(
+            PointTensor,
+            lambda pt: ((pt.p, pt.t), ()),
+            lambda pt, _: PointTensor(pt[0], pt[1]),
+        )
+
+        torch.utils._pytree.register_pytree_node(
+            Point,
+            lambda p: ((p.x, p.y), ()),
+            lambda xy, _: Point(xy[0], xy[1]),
+        )
+
+        def trace_point(p):
+            torch._dynamo.graph_break()
+            return p.x * p.y
+
+        @torch._dynamo.nonstrict_trace
+        def trace_point_tensor(pt):
+            torch._dynamo.graph_break()
+            return pt.t + trace_point(pt.p)
+
+        backend = EagerAndRecordGraphs()
+
+        @torch.compile(fullgraph=True, backend=backend)
+        def fn(x, y):
+            p = Point(x, y)
+            t = x + y
+            pt = PointTensor(p, t)
+            res = trace_point_tensor(pt)
+            return res
+
+        fn(torch.randn(10), torch.randn(10))
+        self.assertExpectedInline(
+            normalize_gm(backend.graphs[0].print_readable(print_output=False)),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        t: "f32[10]" = l_x_ + l_y_
+
+        trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
+        trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
+        res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
+        return (res,)
+""",  # NOQA: B950
+        )
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_dynamo/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
     mark_static,
     mark_static_address,
     maybe_mark_dynamic,
+    nonstrict_trace,
     run,
     set_stance,
     substitute_in_graph,
@@ -63,6 +64,7 @@
     "maybe_mark_dynamic",
     "mark_static",
     "mark_static_address",
+    "nonstrict_trace",
     "optimize",
     "optimize_assert",
     "export",

torch/_dynamo/decorators.py

Lines changed: 34 additions & 0 deletions
@@ -160,6 +160,39 @@ def allow_in_graph(fn):
     return fn
 
 
+def nonstrict_trace(traceable_fn):
+    # Like `allow_in_graph`, but with the following enhancements/differences:
+    #
+    # 1. Supports user-defined class as inputs, as long as the class has been
+    #    registered with pytree.
+    # 2. Reads to global/captured tensors forces the underlying graph to treat
+    #    those tensors as constant, and we _assume_ they will not be updated. This
+    #    is similar to FX tracing.
+    # 3. In the resulting Dynamo graph, the call to a `nonstrict_trace`-ed function
+    #    will be represented as a call to `torch._higher_order_ops.flat_apply`,
+    #    which takes in the `nonstrict_trace`-ed function and pytree-flattened
+    #    inputs.
+    # 4. Only the returned function is traceable, and the original function will
+    #    not be. Moreover, `nonstrict_trace` can be used inside a `torch.compile`
+    #    region.
+    #
+    # NOTE: like `allow_in_graph`, aliasing information is neither preserved
+    # between inputs themselves, nor between inputs and outputs.
+    assert callable(traceable_fn), "nonstrict_trace expects a callable"
+
+    @functools.wraps(traceable_fn)
+    def wrapped(*args, **kwargs):
+        return traceable_fn(*args, **kwargs)
+
+    # This line allows us to reuse much of the `allow_in_graph` impl.
+    trace_rules._allowed_callable_ids.add(id(wrapped))
+
+    # This line allows us to diverge the impl from `allow_in_graph`.
+    trace_rules._nonstrict_trace_callable_ids.add(id(wrapped))
+
+    return wrapped
+
+
 def _disallow_in_graph_helper(throw_if_not_allowed):
     def inner(fn):
         if isinstance(fn, (list, tuple)):
@@ -176,6 +209,7 @@ def inner(fn):
                 "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
             )
         trace_rules._allowed_callable_ids.remove(id(fn))
+        trace_rules._nonstrict_trace_callable_ids.remove(id(fn))
        trace_rules._disallowed_callable_ids.add(id(fn))
        return fn
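Outside a compiled region the returned wrapper is transparent: it simply forwards to the original function while its id is registered with Dynamo's trace rules. A small sketch of that behavior (function name is illustrative):

```python
import torch

@torch._dynamo.nonstrict_trace
def add_one(x):
    return x + 1

# In eager mode the wrapper just calls the original function.
t = torch.ones(3)
assert torch.equal(add_one(t), t + 1)
```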

torch/_dynamo/output_graph.py

Lines changed: 35 additions & 29 deletions
@@ -115,6 +115,7 @@
     get_instruction_source_311,
     get_locals_to_steal,
     get_static_address_type,
+    get_unique_name_wrt,
     graph_break_reasons,
     increment_op_count,
     lazy_format_graph_code,
@@ -753,6 +754,17 @@ def module_key_name(*names):
 
         return name
 
+    def register_static_attr_and_return_proxy(
+        self, attr_prefix: str, attr_value: Any
+    ) -> fx.Proxy:
+        attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
+        # TODO `nn_modules` has been historically overloaded to store a lot more
+        # than just nn module objects, fix that.
+        self.nn_modules[attr_name] = attr_value
+        proxy = self.create_proxy("get_attr", attr_name, (), {})
+        set_example_value(proxy.node, attr_value)
+        return proxy
+
     def register_attr_or_module(
         self,
         target: Union[torch.nn.Module, torch.Tensor, Any],
@@ -864,36 +876,30 @@ def wrap_name(module_key):
                 return wrap_name(k)
 
         name = OutputGraph.module_key_name(*names)
+        name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
+        self.nn_modules[name] = target
+        if isinstance(target, torch.nn.Module):
 
-        base = name
-        for i in itertools.count():
-            if name not in self.nn_modules and name not in self.global_scope:
-                self.nn_modules[name] = target
-                if isinstance(target, torch.nn.Module):
-
-                    def register_leaf_name(leaf_name):
-                        assert self.param_name_to_source is not None
-                        new_source = ParamBufferSource(source, leaf_name)
-                        new_name = f"{name}.{leaf_name}"
-                        self.param_name_to_source[new_name] = new_source
-                        if isinstance(source, LocalSource):
-                            self.dynamo_flat_name_to_original_fqn[
-                                OutputGraph.module_key_name(new_source.name())
-                            ] = leaf_name
-
-                    # annoying, but there are cases when we do not have parameters
-                    # see test_nn_moduledict_contains
-                    if hasattr(target, "_parameters"):
-                        for leaf_name, _ in target.named_parameters():
-                            register_leaf_name(leaf_name)
-                    if hasattr(target, "_buffers"):
-                        for leaf_name, _ in target.named_buffers():
-                            register_leaf_name(leaf_name)
-
-                return wrap_name(name)
-            name = f"{base}_{i}"
-
-        raise AssertionError("unreachable")
+            def register_leaf_name(leaf_name):
+                assert self.param_name_to_source is not None
+                new_source = ParamBufferSource(source, leaf_name)
+                new_name = f"{name}.{leaf_name}"
+                self.param_name_to_source[new_name] = new_source
+                if isinstance(source, LocalSource):
+                    self.dynamo_flat_name_to_original_fqn[
+                        OutputGraph.module_key_name(new_source.name())
+                    ] = leaf_name
+
+            # annoying, but there are cases when we do not have parameters
+            # see test_nn_moduledict_contains
+            if hasattr(target, "_parameters"):
+                for leaf_name, _ in target.named_parameters():
+                    register_leaf_name(leaf_name)
+            if hasattr(target, "_buffers"):
+                for leaf_name, _ in target.named_buffers():
+                    register_leaf_name(leaf_name)
+
+            return wrap_name(name)
 
     def handle_aliases_for_stolen_lists(self, tx):
         # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
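`register_static_attr_and_return_proxy` is what installs constants such as the pytree `TreeSpec`s as attributes on the output graph, where they surface as `get_attr` nodes (see the expected graph in the test above). A rough way to observe this, reusing the `EagerAndRecordGraphs` test backend; the attribute names are chosen via `get_unique_name_wrt` and may vary:

```python
import torch
from torch._dynamo.testing import EagerAndRecordGraphs

@torch._dynamo.nonstrict_trace
def double(x):
    return x * 2

backend = EagerAndRecordGraphs()

@torch.compile(fullgraph=True, backend=backend)
def fn(x):
    return double(x) + 1

fn(torch.randn(4))
gm = backend.graphs[0]
# Each spec registered via register_static_attr_and_return_proxy shows up as a
# get_attr node whose target is an attribute on the captured graph module.
for node in gm.graph.nodes:
    if node.op == "get_attr":
        print(node.target, type(getattr(gm, node.target)).__name__)
```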

torch/_dynamo/trace_rules.py

Lines changed: 13 additions & 1 deletion
@@ -122,7 +122,7 @@
 
 
 """
-manual_torch_name_rule_map = {
+manual_torch_name_rule_map: dict[str, Any] = {
     "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
     "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
     "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -306,6 +306,7 @@
     "torch.jit._unwrap_optional": UserFunctionVariable,
     "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
     "torch._dynamo.mark_static": UserFunctionVariable,
+    "torch._dynamo.nonstrict_trace": UserFunctionVariable,
     "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
     "torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
     "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
@@ -2998,6 +2999,12 @@ def _disallowed_callable_ids() -> dict[int, str]:
     return rv
 
 
+@FunctionIdSet
+def _nonstrict_trace_callable_ids() -> dict[int, str]:
+    rv: dict[int, str] = {}
+    return rv
+
+
 @FunctionIdSet
 def _builtin_function_ids() -> dict[int, str]:
     # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
@@ -3103,6 +3110,11 @@ def is_callable_allowed(obj) -> bool:
     return id(obj) in _allowed_callable_ids
 
 
+def is_nonstrict_trace_callable(obj) -> bool:
+    _maybe_init_lazy_module(obj)
+    return id(obj) in _nonstrict_trace_callable_ids
+
+
 def is_callable_disallowed(obj) -> bool:
     _maybe_init_lazy_module(obj)
     return id(obj) in _disallowed_callable_ids
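A quick sketch of how the new id-set is consulted: the `nonstrict_trace` decorator adds the wrapper's id to both `_allowed_callable_ids` and `_nonstrict_trace_callable_ids`, so both predicates below should hold for a decorated function (the function itself is illustrative):

```python
import torch
from torch._dynamo import trace_rules

@torch._dynamo.nonstrict_trace
def f(x):
    return x.sin()

# Recognized as an allowed callable, and additionally as a nonstrict_trace one.
assert trace_rules.is_callable_allowed(f)
assert trace_rules.is_nonstrict_trace_callable(f)
```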

torch/_dynamo/utils.py

Lines changed: 21 additions & 0 deletions
@@ -2583,6 +2583,27 @@ def get_safe_global_name(tx, root, obj):
     return f"{root}_{id(obj)}_c{tx.output.compile_id}"
 
 
+def get_unique_name_wrt(prefix: str, *containers) -> str:
+    """
+    Return a name that starts with `prefix` and is not in any of the
+    `containers` (e.g., map, set).
+    """
+    name = prefix
+    for i in itertools.count():
+        found = False
+        for container in containers:
+            if name in container:
+                found = True
+                break
+
+        if not found:
+            return name
+        # else update and retry
+        name = f"{prefix}_{i}"
+
+    raise AssertionError("unreachable")
+
+
 def wrap_fake_exception(fn):
     try:
         return fn()
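`get_unique_name_wrt` is a small generalization of the inlined counter loop removed from `output_graph.py`; roughly:

```python
from torch._dynamo.utils import get_unique_name_wrt  # added by this patch

taken = {"spec", "spec_0"}
print(get_unique_name_wrt("spec", taken))        # "spec_1": first suffix not already taken
print(get_unique_name_wrt("other", taken, {}))   # "other": the bare prefix is free
```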

torch/_dynamo/variables/functions.py

Lines changed: 21 additions & 0 deletions
@@ -363,6 +363,27 @@ def call_function(
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
+        # Handle a `nonstrict_trace(fn)` call
+        if self.fn is torch._dynamo.nonstrict_trace:
+            bound = inspect.signature(self.fn).bind(*args, **kwargs)
+            fn_var = bound.args[0]
+            if not isinstance(fn_var, BaseUserFunctionVariable):
+                typ = fn_var.python_type()
+                unimplemented(
+                    f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
+                )
+
+            if not isinstance(fn_var, UserFunctionVariable):
+                fn_name = fn_var.get_name()
+                unimplemented(
+                    f"""
+Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
+"""  # NOQA: B950
+                )
+
+            fn = fn_var.fn
+            return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)
+
         if self.is_constant:
             return invoke_and_store_as_constant(
                 tx, self.fn, self.get_name(), args, kwargs
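This handler is what lets the decorator itself be applied inside a `torch.compile` region, as long as the target function is defined outside it (otherwise the `unimplemented` branch above fires). A hedged sketch of the intended pattern, based on the handler and the decorator's docstring; names are illustrative:

```python
import torch

def helper(x):  # defined outside the compiled region
    return x * 2

@torch.compile(fullgraph=True)
def fn(x):
    # Applying nonstrict_trace inside the compiled region: per the handler
    # above, Dynamo rewrites this into a TorchInGraphFunctionVariable with
    # nonstrict_traceable=True.
    traced = torch._dynamo.nonstrict_trace(helper)
    return traced(x)

fn(torch.randn(4))
```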

0 commit comments

Comments
 (0)
0