Add user annotation for FX graph cache key (#159318) · pytorch/pytorch@f1fb57d

Commit f1fb57d

shengfukevin authored and pytorchmergebot committed
Add user annotation for FX graph cache key (#159318)
Summary: The AI system co-design team requested a user annotation for the FX graph cache key in PyTorch Kineto traces and Execution traces. With this annotation, they can tell which FX graph each kernel belongs to.

Test Plan: buck2 run mode/opt caffe2/test:test_profiler_cuda -- profiler.test_execution_trace.TestExecutionTraceCUDA

Rollback Plan:

Differential Revision: D79019069

Pull Request resolved: #159318

Approved by: https://github.com/sraikund16, https://github.com/jansel
1 parent 6d0f456 commit f1fb57d
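
For context, a minimal sketch of how the new annotation can be observed once this commit is in place. The function fn, the tensor shapes, and the CPU-only profile are illustrative assumptions, not part of the commit; torch.compile, profile, and ProfilerActivity are the standard public APIs.

# Sketch (illustrative, not from this commit): run a compiled function
# under the profiler and look for the new annotation.
import torch
from torch.profiler import ProfilerActivity, profile

@torch.compile
def fn(x, y):
    return (x @ y).relu()

x, y = torch.randn(10, 10), torch.randn(10, 10)
fn(x, y)  # warm up so compilation happens before profiling

with profile(activities=[ProfilerActivity.CPU]) as prof:
    fn(x, y)

# With this commit, each run of an inductor-compiled graph is wrapped in
# an event named "## Call CompiledFxGraph <fx graph cache key> ##".
for evt in prof.events():
    if "Call CompiledFxGraph" in evt.name:
        print(evt.name)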

File tree

3 files changed (+10, -2 lines)

test/dynamo/test_profiler.py

Lines changed: 1 addition & 1 deletion

The match is tightened from "Compiled" to "Torch-Compiled" so that the new "## Call CompiledFxGraph <cache key> ##" events, whose names also contain "Compiled", do not leak into this assertion about dynamo's "Torch-Compiled Region" annotations.

@@ -181,7 +181,7 @@ def fn(x, y):
             torch.randn(10, 15),
         )

-        annotations = [e.name for e in prof.events() if "Compiled" in e.name]
+        annotations = [e.name for e in prof.events() if "Torch-Compiled" in e.name]
         self.assertEqual(
             annotations,
             [
test/profiler/test_execution_trace.py

Lines changed: 4 additions & 0 deletions

@@ -404,6 +404,7 @@ def fn(a, b, c):

         nodes = self.get_execution_trace_root(fp.name)
         found_captured_triton_kernel_node = False
+        found_call_compiled_fx_graph = False
         for n in nodes:
             assert "name" in n
             if "triton_" in n["name"]:
@@ -412,7 +413,10 @@ def fn(a, b, c):
                 found_captured_triton_kernel_node = True
                 assert len(n["inputs"]["values"]) > 0
                 assert len(n["outputs"]["values"]) == 0
+            elif "Call CompiledFxGraph" in n["name"]:
+                found_call_compiled_fx_graph = True
         assert found_captured_triton_kernel_node
+        assert found_call_compiled_fx_graph

     @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
     @unittest.skipIf(
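
The test above exercises the Execution Trace path. A condensed sketch of the same check outside the test harness, assuming the trace JSON exposes a top-level "nodes" list (as the helper get_execution_trace_root implies); the output path and the compiled function are placeholders.

import json

import torch
from torch.profiler import ExecutionTraceObserver

@torch.compile
def fn(a, b):
    return torch.relu(a + b)

a, b = torch.randn(4, 4), torch.randn(4, 4)
fn(a, b)  # compile before observing

et = ExecutionTraceObserver()
et.register_callback("/tmp/et.json")  # placeholder output path
et.start()
fn(a, b)
et.stop()
et.unregister_callback()

with open("/tmp/et.json") as f:
    nodes = json.load(f)["nodes"]
# With this commit, one node should carry the FX graph cache key.
assert any("Call CompiledFxGraph" in n["name"] for n in nodes)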

torch/_inductor/output_code.py

Lines changed: 5 additions & 1 deletion

@@ -48,6 +48,7 @@
     output_node,
     set_tracing_context_output_strides,
 )
+from torch.autograd.profiler import record_function
 from torch.utils._ordered_set import OrderedSet

 from . import config
@@ -581,7 +582,10 @@ def __del__(self) -> None:
     def __call__(self, inputs: Sequence[Any]) -> Any:
         assert self.current_callable is not None
         try:
-            return self.current_callable(inputs)
+            with record_function(
+                f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
+            ):
+                return self.current_callable(inputs)
         finally:
             get_runtime_metrics_context().finish()
             AutotuneCacheBundler.end_compile()
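
The substance of the change is the record_function context manager around the compiled callable. As a standalone illustration of the same pattern, run_with_annotation and its arguments are hypothetical names for exposition; record_function is the real torch.autograd.profiler API used by the commit.

from typing import Any, Callable, Sequence

from torch.autograd.profiler import record_function

def run_with_annotation(
    cache_key: str,
    compiled: Callable[[Sequence[Any]], Any],
    inputs: Sequence[Any],
) -> Any:
    # Kernels launched inside the "with" block are recorded as children of
    # this annotation, which is what lets trace consumers map kernels back
    # to the FX graph identified by cache_key.
    with record_function(f"## Call CompiledFxGraph {cache_key} ##"):
        return compiled(inputs)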
