Improve cache key graph printing performance (#151928) · pytorch/pytorch@7a0781e

Commit 7a0781e

aorenste authored and pytorchmergebot committed
Improve cache key graph printing performance (#151928)
Teach the graph printer to allow overriding how SymTypes (`SymInt`, `SymFloat`, `SymBool`) are printed, and use that hook to reuse the fast SymNode printing from `torch._inductor.utils.sympy_str()` so that computing the cache key is faster. On my computer the repro from #151823 goes from 480s -> 80s (still terrible... but better).

Fixes #151823

Pull Request resolved: #151928
Approved by: https://github.com/laithsakka
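The mechanism is a swappable printer hook on `torch.fx.graph.CodeGen` plus a context manager that installs an override and restores the previous printer on exit (see the `torch/fx/graph.py` diff below). A minimal, self-contained sketch of that pattern, using standalone names rather than the PyTorch classes themselves:

```python
from contextlib import contextmanager
from typing import Callable, Iterator


class Printer:
    # Class-level hook: how symbolic values are rendered in the printed graph.
    # Defaults to repr(); swapped temporarily by the context manager below.
    _sym_repr: Callable[[object], str] = lambda x: repr(x)


@contextmanager
def override_sym_repr(override: Callable[[object], str]) -> Iterator[None]:
    saved = Printer._sym_repr
    try:
        Printer._sym_repr = override
        yield
    finally:
        # Restore the previous printer even if printing raised.
        Printer._sym_repr = saved


# Only code running inside the `with` block sees the override.
with override_sym_repr(lambda s: f"fast<{s}>"):
    print(Printer._sym_repr("s0 + 2"))  # fast<s0 + 2>
print(Printer._sym_repr("s0 + 2"))      # s0 + 2
```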
1 parent 7dd9d51 commit 7a0781e

File tree

4 files changed: +72 -22 lines

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py  (+4 -1)
torch/_inductor/compile_fx.py  (+19 -10)
torch/fx/graph.py  (+18 -2)
torch/fx/graph_module.py  (+31 -9)

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 4 additions & 1 deletion
@@ -186,7 +186,10 @@ def aot_dispatch_base(
     aot_forward_graph_str = None
     if aot_config.cache_info is not None:
         aot_forward_graph_str = fw_module.print_readable(
-            print_output=False, include_stride=True, include_device=True
+            print_output=False,
+            include_stride=True,
+            include_device=True,
+            fast_sympy_print=True,
         )
 
     fakified_out_wrapper = FakifiedOutWrapper()
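The string returned by `print_readable` here ends up in AOTAutograd's cache key. A rough, self-contained sketch of turning a printed graph into a key; the hashing below is purely illustrative and not the actual cache-key format:

```python
import hashlib

import torch
import torch.fx


def f(x):
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(f)

# print_output=False returns the text instead of printing it. On a build that
# contains this commit, fast_sympy_print=True could be passed as well.
graph_str = gm.print_readable(
    print_output=False, include_stride=True, include_device=True
)
cache_key = hashlib.sha256(graph_str.encode()).hexdigest()
print(cache_key[:16])
```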

torch/_inductor/compile_fx.py

Lines changed: 19 additions & 10 deletions
@@ -844,9 +844,11 @@ def _compile_fx_inner(
         assert mb_compiled_graph is None
         log.debug(
             "FX cache bypass reason: %s",
-            cache_info.get("cache_bypass_reason", "unknown")
-            if cache_info is not None
-            else "FX cache disabled or key generation failed",
+            (
+                cache_info.get("cache_bypass_reason", "unknown")
+                if cache_info is not None
+                else "FX cache disabled or key generation failed"
+            ),
         )
         mb_compiled_graph = fx_codegen_and_compile(
             gm, example_inputs, inputs_to_check, **graph_kwargs

@@ -1167,8 +1169,15 @@ def codegen_and_compile(
                 colored=True,
             ),
         )
+
+        # We're printing the graph to be used as a cache key - so a
+        # printer which is a little less readable but faster is
+        # appropriate.
         inductor_post_grad_graph_str = gm.print_readable(
-            print_output=False, include_stride=True, include_device=True
+            print_output=False,
+            include_stride=True,
+            include_device=True,
+            fast_sympy_print=True,
         )
         trace_structured(
             "inductor_post_grad_graph",

@@ -1268,12 +1277,12 @@ def codegen_and_compile(
             is_inference=is_inference,
             is_backward=is_backward,
             const_output_index=const_output_index,
-            const_wrapper_code=const_wrapper_code.value
-            if const_wrapper_code
-            else None,
-            const_kernel_code=const_kernel_code.value
-            if const_kernel_code
-            else None,
+            const_wrapper_code=(
+                const_wrapper_code.value if const_wrapper_code else None
+            ),
+            const_kernel_code=(
+                const_kernel_code.value if const_kernel_code else None
+            ),
             const_module=const_graph,
             inputs_to_check=inputs_to_check,
         )
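The comment in the second hunk ("a printer which is a little less readable but faster is appropriate") captures the whole trade-off: for a cache key only stability matters, not readability. An illustrative micro-benchmark of that idea, not the PR's own measurement; it assumes `torch._inductor.utils.sympy_str` (referenced in the commit message) accepts an arbitrary sympy expression:

```python
import time

import sympy
from torch._inductor.utils import sympy_str

# Build a moderately large symbolic expression, loosely standing in for the
# shape expressions that show up in dynamic-shape graphs.
syms = sympy.symbols("s0:40")
expr = sum(a * b for a, b in zip(syms, syms[1:]))

t0 = time.perf_counter()
for _ in range(1000):
    sympy.sstr(expr)  # sympy's default string printer
t1 = time.perf_counter()

t2 = time.perf_counter()
for _ in range(1000):
    sympy_str(expr)  # inductor's fast printer reused by this PR
t3 = time.perf_counter()

print(f"sympy.sstr: {t1 - t0:.3f}s   sympy_str: {t3 - t2:.3f}s")
```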

torch/fx/graph.py

Lines changed: 18 additions & 2 deletions
@@ -12,7 +12,7 @@
 import typing
 import warnings
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING

@@ -317,6 +317,9 @@ def _parse_stack_trace(stack_trace: str):
 
 @compatibility(is_backward_compatible=False)
 class CodeGen:
+    # This is an override hook so we can customize the SymNode printer.
+    _sym_repr: Callable[["torch.types.PySymType"], str] = lambda x: repr(x)
+
     def __init__(self):
         self._body_transformer: Optional[TransformCodeFunc] = None
         self._func_name: str = "forward"

@@ -609,7 +612,8 @@ def emit_node(node: Node):
                     f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
                 )
             elif isinstance(meta_val, py_sym_types):
-                maybe_type_annotation = f': "Sym({meta_val})"'
+                val_str = CodeGen._sym_repr(meta_val)
+                maybe_type_annotation = f': "Sym({val_str})"'
             elif isinstance(meta_val, TensorMetadata):
                 maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'

@@ -1907,6 +1911,18 @@ def on_generate_code_context_manager():
         return on_generate_code_context_manager()
 
 
+@contextmanager
+def _override_sym_repr(
+    override: Callable[["torch.types.PySymType"], str]
+) -> Iterator[None]:
+    tmp = CodeGen._sym_repr
+    try:
+        CodeGen._sym_repr = override
+        yield
+    finally:
+        CodeGen._sym_repr = tmp
+
+
 def _identity(x):
     return x
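A hedged end-to-end sketch of exercising the new hook directly. It assumes a PyTorch build that contains this commit (for `_override_sym_repr`); `make_fx(..., tracing_mode="symbolic")` is used only to obtain a graph whose node metadata actually contains SymInts:

```python
import torch
from torch._inductor.utils import sympy_str
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _override_sym_repr


def f(x):
    # x.shape[0] becomes a SymInt under symbolic tracing.
    return x + x.shape[0]


gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4))

# Default printing: Sym(...) annotations use repr() of the SymNode.
default_str = gm.print_readable(print_output=False)

# Same graph, with annotations rendered by the fast inductor printer instead.
with _override_sym_repr(lambda s: sympy_str(s.node.expr)):
    fast_str = gm.print_readable(print_output=False)

# For a trivial expression like s0 both printers may emit identical text;
# the difference shows up on large shape expressions.
print(default_str == fast_str)
```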

torch/fx/graph_module.py

Lines changed: 31 additions & 9 deletions
@@ -17,7 +17,14 @@
 from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
 
 from ._compatibility import compatibility
-from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
+from .graph import (
+    _custom_builtins,
+    _is_from_torch,
+    _override_sym_repr,
+    _PyTreeCodeGen,
+    Graph,
+    PythonCode,
+)
 
 
 __all__ = [

@@ -927,18 +934,33 @@ def print_readable(
         include_stride=False,
         include_device=False,
         colored=False,
+        *,
+        # If `fast_sympy_print` is True then we use a sympy printer which is faster
+        # but may result in less-readable output.
+        fast_sympy_print: bool = False,
     ):
         """
         Return the Python code generated for current GraphModule and its children GraphModules
         """
-        return _print_readable(
-            self,
-            self._get_name(),
-            print_output,
-            include_stride,
-            include_device,
-            colored,
-        )
+        ctx_mgr = contextlib.ExitStack()
+        with ctx_mgr:
+            if fast_sympy_print:
+                from torch._inductor.utils import sympy_str
+
+                def fast_repr(expr: torch.types.PySymType) -> str:
+                    return sympy_str(expr.node.expr)
+
+                ctx_mgr.enter_context(_override_sym_repr(fast_repr))
+
+            r = _print_readable(
+                self,
+                self._get_name(),
+                print_output,
+                include_stride,
+                include_device,
+                colored,
+            )
+            return r
 
     def __str__(self) -> str:
         orig_str = super().__str__()
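The `contextlib.ExitStack` here lets the override be installed only when `fast_sympy_print=True` while keeping a single code path for the actual printing. A generic, self-contained sketch of that idiom (all names below are illustrative):

```python
import contextlib


@contextlib.contextmanager
def announce(tag: str):
    print(f"enter {tag}")
    try:
        yield
    finally:
        print(f"exit {tag}")


def render(fast: bool = False) -> str:
    stack = contextlib.ExitStack()
    with stack:
        if fast:
            # Entered only when requested; exited automatically with the stack.
            stack.enter_context(announce("fast printer"))
        return "printed graph"


print(render(fast=False))  # just "printed graph"
print(render(fast=True))   # enter/exit messages wrap the call
```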
