[dynamo] Initial support for `nonstrict_trace` by StrongerXi · Pull Request #146367 · pytorch/pytorch · GitHub

[dynamo] Initial support for nonstrict_trace #146367


Closed · wants to merge 15 commits
433 changes: 433 additions & 0 deletions test/dynamo/test_decorators.py

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions test/dynamo/test_flat_apply.py
@@ -4,6 +4,7 @@
import torch
import torch._dynamo.test_case
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops.flat_apply import (
flat_apply,
func_to_graphable,
@@ -83,6 +84,72 @@ def f(a, b):
result = flat_apply(func_spec, in_spec, *flat_args)
self.assertEqual(result, f(*args, **kwargs))

def test_nonstrict_trace_dynamo_graph(self):
class Point:
x: torch.Tensor
y: torch.Tensor

def __init__(self, x, y):
self.x = x
self.y = y

class PointTensor:
p: Point
t: torch.Tensor

def __init__(self, p, t):
self.p = p
self.t = t

torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)

torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)

def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y

@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)

backend = EagerAndRecordGraphs()

@torch.compile(fullgraph=True, backend=backend)
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res

fn(torch.randn(10), torch.randn(10))
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
l_x_ = L_x_
l_y_ = L_y_

t: "f32[10]" = l_x_ + l_y_

trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
return (res,)
""", # NOQA: B950
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
2 changes: 2 additions & 0 deletions torch/_dynamo/__init__.py
@@ -26,6 +26,7 @@
mark_static,
mark_static_address,
maybe_mark_dynamic,
nonstrict_trace,
run,
set_stance,
substitute_in_graph,
@@ -63,6 +64,7 @@
"maybe_mark_dynamic",
"mark_static",
"mark_static_address",
"nonstrict_trace",
"optimize",
"optimize_assert",
"export",
34 changes: 34 additions & 0 deletions torch/_dynamo/decorators.py
@@ -160,6 +160,39 @@ def allow_in_graph(fn):
return fn


def nonstrict_trace(traceable_fn):
# Like `allow_in_graph`, but with the following enhancements/differences:
#
# 1. Supports user-defined classes as inputs, as long as the class has been
# registered with pytree.
# 2. Reads of global/captured tensors force the underlying graph to treat
# those tensors as constants, and we _assume_ they will not be updated. This
# is similar to FX tracing.
# 3. In the resulting Dynamo graph, the call to a `nonstrict_trace`-ed function
# will be represented as a call to `torch._higher_order_ops.flat_apply`,
# which takes in the `nonstrict_trace`-ed function and pytree-flattened
# inputs.
# 4. Only the returned function is traceable, and the original function will
# not be. Moreover, `nonstrict_trace` can be used inside a `torch.compile`
# region.
#
# NOTE: like `allow_in_graph`, aliasing information is neither preserved
# between inputs themselves, nor between inputs and outputs.
assert callable(traceable_fn), "nonstrict_trace expects a callable"

@functools.wraps(traceable_fn)
def wrapped(*args, **kwargs):
return traceable_fn(*args, **kwargs)

# This line allows us to reuse much of the `allow_in_graph` impl.
trace_rules._allowed_callable_ids.add(id(wrapped))
Comment on lines +187 to +188

Contributor: id(obj) can be reused if obj is deleted. We can still do this, but need to install a weakref callback to remove the ID if obj is deleted.

Contributor Author: Yep, created #147777 to track this; I'll try fixing it for all the relevant decorators in a subsequent patch.
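A minimal sketch of the weakref-based cleanup suggested above (hypothetical helper, not part of this PR; `id_set` stands in for the id registry and is assumed to behave like a plain set):

import weakref

def _add_id_with_cleanup(fn, id_set):
    # Record the callable's id, and drop it automatically once `fn` is
    # garbage-collected, so a recycled id cannot be misclassified later.
    fn_id = id(fn)
    id_set.add(fn_id)
    weakref.finalize(fn, id_set.discard, fn_id)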


# This line allows us to diverge the impl from `allow_in_graph`.
trace_rules._nonstrict_trace_callable_ids.add(id(wrapped))

return wrapped
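As a small illustration of point (2) in the comment above (a sketch, not part of this diff; `y`, `add_global`, and `fn` are made-up names): a global tensor read inside a `nonstrict_trace`-ed function gets baked into the graph as a constant.

import torch

y = torch.ones(3)  # global tensor read by the traced function

@torch._dynamo.nonstrict_trace
def add_global(x):
    # `y` is captured non-strictly, so the compiled graph treats it as a
    # constant and assumes it will not be updated afterwards.
    return x + y

@torch.compile(fullgraph=True)
def fn(x):
    return add_global(x)

fn(torch.randn(3))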


def _disallow_in_graph_helper(throw_if_not_allowed):
def inner(fn):
if isinstance(fn, (list, tuple)):
@@ -176,6 +209,7 @@ def inner(fn):
"Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
)
trace_rules._allowed_callable_ids.remove(id(fn))
trace_rules._nonstrict_trace_callable_ids.remove(id(fn))
trace_rules._disallowed_callable_ids.add(id(fn))
return fn

64 changes: 35 additions & 29 deletions torch/_dynamo/output_graph.py
@@ -115,6 +115,7 @@
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_unique_name_wrt,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@@ -748,6 +749,17 @@ def module_key_name(*names):

return name

def register_static_attr_and_return_proxy(
self, attr_prefix: str, attr_value: Any
) -> fx.Proxy:
attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
# TODO `nn_modules` has been historically overloaded to store a lot more
# than just nn module objects, fix that.
self.nn_modules[attr_name] = attr_value
proxy = self.create_proxy("get_attr", attr_name, (), {})
set_example_value(proxy.node, attr_value)
return proxy
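For orientation, a hypothetical call-site sketch (inferred from the expected graph in the test above, not shown in this diff): this helper is what produces `get_attr` proxies such as `trace_point_tensor_spec`.

# `output_graph` is an OutputGraph instance; `input_spec` is a pytree TreeSpec
spec_proxy = output_graph.register_static_attr_and_return_proxy(
    "trace_point_tensor_spec", input_spec
)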

def register_attr_or_module(
self,
target: Union[torch.nn.Module, torch.Tensor, Any],
@@ -859,36 +871,30 @@ def wrap_name(module_key):
return wrap_name(k)

name = OutputGraph.module_key_name(*names)
name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
Contributor Author (@StrongerXi, Feb 11, 2025): This chunk is just refactoring name generation into this call.

self.nn_modules[name] = target
if isinstance(target, torch.nn.Module):

base = name
for i in itertools.count():
if name not in self.nn_modules and name not in self.global_scope:
self.nn_modules[name] = target
if isinstance(target, torch.nn.Module):

def register_leaf_name(leaf_name):
assert self.param_name_to_source is not None
new_source = ParamBufferSource(source, leaf_name)
new_name = f"{name}.{leaf_name}"
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
] = leaf_name

# annoying, but there are cases when we do not have parameters
# see test_nn_moduledict_contains
if hasattr(target, "_parameters"):
for leaf_name, _ in target.named_parameters():
register_leaf_name(leaf_name)
if hasattr(target, "_buffers"):
for leaf_name, _ in target.named_buffers():
register_leaf_name(leaf_name)

return wrap_name(name)
name = f"{base}_{i}"

raise AssertionError("unreachable")
def register_leaf_name(leaf_name):
assert self.param_name_to_source is not None
new_source = ParamBufferSource(source, leaf_name)
new_name = f"{name}.{leaf_name}"
self.param_name_to_source[new_name] = new_source
if isinstance(source, LocalSource):
self.dynamo_flat_name_to_original_fqn[
OutputGraph.module_key_name(new_source.name())
] = leaf_name

# annoying, but there are cases when we do not have parameters
# see test_nn_moduledict_contains
if hasattr(target, "_parameters"):
for leaf_name, _ in target.named_parameters():
register_leaf_name(leaf_name)
if hasattr(target, "_buffers"):
for leaf_name, _ in target.named_buffers():
register_leaf_name(leaf_name)

return wrap_name(name)

def handle_aliases_for_stolen_lists(self, tx):
# If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
14 changes: 13 additions & 1 deletion torch/_dynamo/trace_rules.py
@@ -122,7 +122,7 @@


"""
manual_torch_name_rule_map = {
manual_torch_name_rule_map: dict[str, Any] = {
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -306,6 +306,7 @@
"torch.jit._unwrap_optional": UserFunctionVariable,
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
"torch._dynamo.mark_static": UserFunctionVariable,
"torch._dynamo.nonstrict_trace": UserFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
@@ -2998,6 +2999,12 @@ def _disallowed_callable_ids() -> dict[int, str]:
return rv


@FunctionIdSet
def _nonstrict_trace_callable_ids() -> dict[int, str]:
rv: dict[int, str] = {}
return rv


@FunctionIdSet
def _builtin_function_ids() -> dict[int, str]:
# See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids
@@ -3103,6 +3110,11 @@ def is_callable_allowed(obj) -> bool:
return id(obj) in _allowed_callable_ids


def is_nonstrict_trace_callable(obj) -> bool:
_maybe_init_lazy_module(obj)
return id(obj) in _nonstrict_trace_callable_ids


def is_callable_disallowed(obj) -> bool:
_maybe_init_lazy_module(obj)
return id(obj) in _disallowed_callable_ids
21 changes: 21 additions & 0 deletions torch/_dynamo/utils.py
@@ -2574,6 +2574,27 @@ def get_safe_global_name(tx, root, obj):
return f"{root}_{id(obj)}_c{tx.output.compile_id}"


def get_unique_name_wrt(prefix: str, *containers) -> str:
"""
Return a name that starts with `prefix` and is not in any of the
`containers` (e.g., map, set).
"""
name = prefix
for i in itertools.count():
found = False
for container in containers:
if name in container:
found = True
break

if not found:
return name
# else update and retry
name = f"{prefix}_{i}"

raise AssertionError("unreachable")
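For illustration (not part of the diff), the collision handling works like this:

# hypothetical usage
get_unique_name_wrt("attr", {"attr": 0}, {"attr_0"})  # -> "attr_1"
get_unique_name_wrt("fresh", {"attr": 0})             # -> "fresh"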


def wrap_fake_exception(fn):
try:
return fn()
21 changes: 21 additions & 0 deletions torch/_dynamo/variables/functions.py
@@ -363,6 +363,27 @@ def call_function(
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle a `nonstrict_trace(fn)` call
if self.fn is torch._dynamo.nonstrict_trace:
bound = inspect.signature(self.fn).bind(*args, **kwargs)
fn_var = bound.args[0]
if not isinstance(fn_var, BaseUserFunctionVariable):
typ = fn_var.python_type()
unimplemented(
f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
)

if not isinstance(fn_var, UserFunctionVariable):
fn_name = fn_var.get_name()
unimplemented(
f"""
Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region.
""" # NOQA: B950
)

fn = fn_var.fn
return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)

if self.is_constant:
return invoke_and_store_as_constant(
tx, self.fn, self.get_name(), args, kwargs
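For context, a sketch (not from this PR) of the case the second `unimplemented` branch above guards against: applying `nonstrict_trace` to a function defined inside the `torch.compile` region.

import torch

@torch.compile(fullgraph=True)
def fn(x):
    @torch._dynamo.nonstrict_trace
    def inner(t):
        # Defined inside the compiled region, which Dynamo currently rejects.
        return t + 1

    return inner(x)

fn(torch.randn(3))  # expected to hit the `unimplemented` message above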