10000 [Dynamo] Introduce hook receiving list of traced code objects by jbschlosser · Pull Request #153622 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Dynamo] Introduce hook receiving list of traced code objects #153622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions test/dynamo/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import dataclasses
import os
import pprint
import sys
from unittest import mock
Expand Down Expand Up @@ -141,6 +142,63 @@ def break_it2(x):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)

def test_frame_traced_hook(self):
from utils import add, break_it

traced_file_sets = []

def get_traced_files(s):
nonlocal traced_file_sets
traced_file_sets.append(s)

utils_path = os.path.join(os.path.dirname(__file__), "utils.py")

# === no inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_files})
def fn(x):
return x * 2

x = torch.randn(3)
traced_file_sets = []
fn(x)
# expect hook to be called once with this file
self.assertEqual(traced_file_sets, [{__file__}])

# === successful inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_files})
def fn(x):
return add(x) * 2

x = torch.randn(3)
traced_file_sets = []
fn(x)
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# expect hook to be called once with both this file and file of inlined func
self.assertEqual(traced_file_sets, [{__file__, utils_path}])

# === graph break occurs during inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_files})
def fn(x):
y = break_it(x)
return y * 2

x = torch.randn(3)
traced_file_sets = []
fn(x)
# expect hook to be called twice; once for this file one for file of inlined func
self.assertEqual(traced_file_sets, [{__file__}, {utils_path}])

# === empty graph ===
@torch.compile(options={"frame_traced_fn": get_traced_files})
def fn(x):
return x

x = torch.randn(3)
traced_file_sets = []
fn(x)
# hook is not expected to be called at all for an empty graph
self.assertEqual(traced_file_sets, [])


class TestModel(torch.nn.Module):
def __init__(self):
Expand Down
4 changes: 4 additions & 0 deletions test/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def add(x):
return x + 1


def break_it(x):
return x.sum().item()


def create_dummy_module_and_function():
module = types.ModuleType("dummy_module")
module.__spec__ = importlib.machinery.ModuleSpec(
Expand Down
5 changes: 5 additions & 0 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,10 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)

frame_traced_fn = None
if options and isinstance(options, dict):
frame_traced_fn = options.pop("frame_traced_fn", None)

if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
Expand All @@ -2608,6 +2612,7 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)(model) # type: ignore[return-value]


Expand Down
18 changes: 11 additions & 7 deletions torch/_dynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,13 +936,17 @@ def count_args(code: CodeType) -> int:
annotation_str,
)

if not output.is_empty_graph() and hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
if not output.is_empty_graph():
if hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
if hooks.frame_traced_fn is not None:
output.tracing_context.traced_files.add(output.co_fields["co_filename"])
hooks.frame_traced_fn(output.tracing_context.traced_files)

return wrap_guarded_code(guarded_code)

Expand Down
2 changes: 2 additions & 0 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def _optimize(
guard_export_fn=None,
guard_fail_fn=None,
guard_filter_fn=None,
frame_traced_fn=None,
disable=False,
dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:
Expand Down Expand Up @@ -1033,6 +1034,7 @@ def toy_example(a, b): ...
guard_export_fn=guard_export_fn,
guard_fail_fn=guard_fail_fn,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)
torch._C._log_api_usage_once("torch._dynamo.optimize")
if (
Expand Down
4 changes: 3 additions & 1 deletion torch/_dynamo/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
The Hooks class manages two types of hook functions:
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input

- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph.
This hook will be passed the set of filenames containing traced code
These hooks enable customization of guard export and failure handling behaviors.
"""

Expand All @@ -23,3 +24,4 @@ class Hooks:
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
frame_traced_fn: Optional[Callable[[set[str]], None]] = None
1 change: 1 addition & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,7 @@ def inline_call_(self):
parent.inconsistent_side_effects |= self.inconsistent_side_effects

log.debug("DONE INLINING %s", code)
self.output.tracing_context.traced_files.add(code.co_filename)

if config.enable_faithful_generator_behavior or (
isinstance(self, InliningGeneratorInstructionTranslator)
Expand Down
2 changes: 2 additions & 0 deletions torch/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,8 @@ def __init__(self, fake_mode):
# see note: [Returning Fake Tensors on First AOT Autograd Call]
self.fakify_first_call = False
self.hop_dispatch_set_cache = HopDispatchSetCache()
# set of filenames for inlined functions
self.traced_files: set[str] = set()

def clear(self):
# Look at the note in output_graph.py in function `save_global_state`
Expand Down
Loading
0