[NOT FOR MERGE] Re-implement #148962 for benchmarking · pytorch/pytorch@0420ee1 · GitHub

Commit 0420ee1

[NOT FOR MERGE] Re-implement #148962 for benchmarking
ghstack-source-id: 02b4287
Pull Request resolved: #149961
1 parent 0c139fa commit 0420ee1

File tree

2 files changed: +19 -16 lines changed


torch/_inductor/compile_fx.py

Lines changed: 18 additions & 15 deletions
@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
@@ -1530,35 +1530,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
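
This hunk changes how compile_fx_aot resolves freezing: an explicit config.freezing value wins, and when it is unset the function defaults to freezing unless aot_inductor.use_runtime_constant_folding is enabled. A minimal standalone sketch of that decision, for illustration only (resolve_aot_freezing is a hypothetical helper, not part of the diff):

    from typing import Optional

    def resolve_aot_freezing(
        freezing: Optional[bool],  # config.freezing; None means unset
        use_runtime_constant_folding: bool,
    ) -> bool:
        # An explicit setting (e.g. TORCHINDUCTOR_FREEZING=0/1) always wins.
        if freezing is not None:
            return freezing
        # Otherwise freeze by default unless runtime constant folding is on,
        # mirroring the config_patches.update(...) call in the hunk above.
        return not use_runtime_constant_folding

    assert resolve_aot_freezing(None, False) is True    # unset -> freeze
    assert resolve_aot_freezing(False, False) is False  # explicit off wins
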
@@ -1664,7 +1663,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs += tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
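
This last hunk makes fw_compiler_freezing tolerate fake_mode being None: mock.patch.object is applied only when a fake mode exists, with contextlib.nullcontext() as the no-op fallback (hence the new import in the first hunk). A self-contained sketch of this conditional-context-manager pattern, with illustrative names (Job and run_with_optional_patch are not from the diff):

    from contextlib import nullcontext
    from unittest import mock

    class Job:
        allow_non_fake_inputs = False

    def run_with_optional_patch(job, fn):
        # Patch the attribute only when there is an object to patch;
        # otherwise use a context manager that does nothing.
        with (
            mock.patch.object(job, "allow_non_fake_inputs", True)
            if job is not None
            else nullcontext()
        ):
            return fn()

    run_with_optional_patch(Job(), lambda: None)  # attribute patched inside the block
    run_with_optional_patch(None, lambda: None)   # no-op context, body still runs
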

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -844,7 +844,7 @@ def decide_compile_threads() -> int:
 # Freezing will attempt to inline weights as constants in optimization
 # and run constant folding and other optimizations on them. After freezing, weights
 # can no longer be updated.
-freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
+freezing: Optional[bool] = get_tristate_env("TORCHINDUCTOR_FREEZING")

 # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
 # of potentially keeping multiple copies of weights.
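
This config change relies on get_tristate_env, an existing helper in torch/_inductor/config.py, so that TORCHINDUCTOR_FREEZING can distinguish "explicitly on", "explicitly off", and "unset". A rough approximation of that tri-state behavior, sketched from its usage here rather than copied from upstream:

    import os
    from typing import Optional

    def get_tristate_env(name: str, default: Optional[bool] = None) -> Optional[bool]:
        # "1" -> True, "0" -> False, unset/other -> default (None by default).
        # Returning None for "unset" lets compile_fx_aot pick its own default.
        value = os.environ.get(name)
        if value == "1":
            return True
        if value == "0":
            return False
        return default
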
