8000 [dynamo][compile-time] Cache whether a function is inlineable (#153192) · pytorch/pytorch@11c64b7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 11c64b7

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][compile-time] Cache whether a function is inlineable (#153192)
Pull Request resolved: #153192 Approved by: https://github.com/StrongerXi, https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: #153458
1 parent e2ce17c commit 11c64b7

File tree

4 files changed

+38
-16
lines changed

4 files changed

+38
-16
lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
add_loop_eager,compile_time_instruction_count,3035000000,0.015
1+
add_loop_eager,compile_time_instruction_count,3051000000,0.015
22

33

44

@@ -16,9 +16,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
1616

1717
add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
1818

19-
20-
21-
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1011000000,0.015
19+
basic_modules_ListOfLinears_eager,compile_time_instruction_count,974800000,0.015
2220

2321

2422

@@ -38,7 +36,7 @@ update_hint_regression,compile_time_instruction_count,1715000000,0.02
3836

3937

4038

41-
float_args,compile_time_instruction_count,439200000,0.015
39+
float_args,compile_time_instruction_count,444500000,0.015
4240

4341

4442

torch/_dynamo/bytecode_transformation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,7 @@ def clear_instruction_args(instructions):
13721372
inst.arg = None
13731373

13741374

1375+
@functools.lru_cache
13751376
def get_code_keys() -> list[str]:
13761377
# Python 3.11 changes to code keys are not fully documented.
13771378
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24

torch/_dynamo/symbolic_convert.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3792,13 +3792,6 @@ def build_inline_tracer(
37923792
args: list[VariableTracker],
37933793
kwargs,
37943794
):
3795-
if isinstance(func, SkipFunctionVariable):
3796-
unimplemented_v2(
3797-
gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
3798-
context=f"Attempted to inline a SkipFunctionVariable {func}",
3799-
explanation="Attempted to inline a function that was previously determined to be marked as intentionally skipped.",
3800-
hints=[],
3801-
)
38023795
assert isinstance(
38033796
func,
38043797
(
@@ -3808,8 +3801,35 @@ def build_inline_tracer(
38083801
LocalGeneratorObjectVariable,
38093802
),
38103803
)
3811-
result = InliningInstructionTranslator.check_inlineable(func)
3812-
assert result.skipped is False
3804+
code: types.CodeType = func.get_code()
3805+
result = None
3806+
tracing_ctx = parent.output.tracing_context
3807+
3808+
# Check if we have already identified this function to be inline-able.
3809+
# The exception is dont_skip_tracing flag which affects the inline
3810+
# behavior. If the flag is True, don't rely on previous results.
3811+
if not config.dont_skip_tracing and tracing_ctx:
3812+
if previous_result := tracing_ctx.previously_inlined_functions.get(
3813+
code, None
3814+
):
3815+
result = previous_result
3816+
3817+
if result is None:
3818+
if isinstance(func, SkipFunctionVariable):
3819+
unimplemented_v2(
3820+
gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
3821+
context=f"Attempted to inline a SkipFunctionVariable {func}",
3822+
explanation=(
3823+
"Attempted to inline a function that was previously determined to be marked as intentionally skipped."
3824+
),
3825+
hints=[],
3826+
)
3827+
result = InliningInstructionTranslator.check_inlineable(func)
3828+
assert result.skipped is False
3829+
3830+
if not config.dont_skip_tracing and tracing_ctx:
3831+
tracing_ctx.previously_inlined_functions[code] = result
3832+
38133833
try:
38143834
sub_locals = func.bind_args(parent, args, kwargs)
38153835
except TypeError as e:
@@ -3832,7 +3852,6 @@ def build_inline_tracer(
38323852
hints=[*graph_break_hints.DYNAMO_BUG],
38333853
)
38343854

3835-
code: types.CodeType = func.get_code()
38363855
if code.co_name in ("__setitem__", "__setattr__") and not (
38373856
args and isinstance(args[0], variables.UserDefinedObjectVariable)
38383857
):
@@ -3851,9 +3870,11 @@ def build_inline_tracer(
38513870
if sys.version_info >= (3, 11):
38523871
cur_inst = parent.current_instruction
38533872
parent_code = parent.f_code
3854-
header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)
38553873

38563874
def get_trace_call_log_str():
3875+
header = parent.get_line_of_code_header(
3876+
lineno=cur_inst.positions.lineno
3877+
)
38573878
line = get_instruction_source_311(parent_code, cur_inst).rstrip()
38583879
return f"TRACE inlined call {code.co_name} from {header}\n{line}"
38593880

torch/_guards.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ def __init__(self, fake_mode):
825825
self.guards_context = GuardsContext()
826826
self.module_context = ModuleContext()
827827
self.global_context = GlobalContext()
828+
self.previously_inlined_functions = dict()
828829
self.fake_mode = fake_mode
829830
self.frame_summary_stack = []
830831
# This is morally part of frame_summary_stack, but it is kept separate
@@ -870,6 +871,7 @@ def clear(self):
870871
# Look at the note in output_graph.py in function `save_global_state`
871872
# for the context on clearing global context.
872873
self.global_context.global_state = {}
874+
self.previously_inlined_functions.clear()
873875

874876
@staticmethod
875877
@contextmanager

0 commit comments

Comments
 (0)
0