AOTI freezing: fix test issues and enable by default · pytorch/pytorch@8ee87e0 · GitHub

Commit 8ee87e0

AOTI freezing: fix test issues and enable by default
ghstack-source-id: 81c50be
Pull Request resolved: #149961
1 parent 1efd132 commit 8ee87e0

File tree

7 files changed: +51 -39 lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -2917,7 +2917,7 @@ def parse_args(args=None):
     )
     parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
     parser.add_argument(
-        "--freezing", action="store_true", help="turn on freezing", default=False
+        "--freezing", action="store_true", help="turn on freezing", default=None
     )
     parser.add_argument(
         "--inductor-config",
```

test/inductor/test_aot_inductor.py

Lines changed: 24 additions & 16 deletions

```diff
@@ -122,6 +122,10 @@
     raise


+def _is_cpu_freezing(self):
+    return (config.freezing is None or config.freezing) and self.device != GPU_TYPE
+
+
 class AOTInductorTestsTemplate:
     def test_simple(self):
         class Model(torch.nn.Module):
@@ -4157,16 +4161,16 @@ def forward(self, a):
         a = torch.randn(batch, M, K, device=self.device)
         example_inputs = (a,)
 
-        kernel_calls = (
-            [
+        is_cpu_freezing = _is_cpu_freezing(self)
+        if self.device == GPU_TYPE:
+            kernel_calls = [
                 ("triton_poi_fused_0", 1),
                 (f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
             ]
-            if self.device == GPU_TYPE
-            else [
-                ("aoti_torch_cpu_addmm_out", 2),
-            ]
-        )
+        elif is_cpu_freezing:
+            kernel_calls = [("cpp_fused_0", 1)]
+        else:
+            kernel_calls = [("aoti_torch_cpu_addmm_out", 2)]
 
         # test default debug printing all tensor values codegen
         with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
@@ -4190,7 +4194,9 @@ def forward(self, a):
         ).run(code)
 
         # test printing selected kernel's tensor values codegen
-        filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
+        filtered_kernel_name = (
+            "cpp_fused_0" if is_cpu_freezing else f"aoti_torch_{self.device}_addmm_out"
+        )
         with config.patch(
             {
                 "aot_inductor.debug_intermediate_value_printer": "2",
@@ -4201,7 +4207,7 @@ def forward(self, a):
             AOTIRunnerUtil.legacy_compile, model, example_inputs
         )
         filtered_kernel_calls = [
-            (filtered_kernel_name, 2),
+            (filtered_kernel_name, 1 if is_cpu_freezing else 2),
         ]
         for kernel_call, count in filtered_kernel_calls:
             FileCheck().check_count(
@@ -4246,17 +4252,18 @@ def forward(self, a):
         batch = 2
         a = torch.randn(batch, M, K, device=self.device)
         example_inputs = (a,)
-        kernel_calls = (
-            f"aoti_torch_{GPU_TYPE}_addmm_out"
-            if self.device == GPU_TYPE
-            else "aoti_torch_cpu_addmm_out"
+
+        kernel_call = (
+            "graph_1_cpp_fused_0"
+            if _is_cpu_freezing(self)
+            else f"aoti_torch_{self.device}_addmm_out"
         )
         with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}):
             _, code = run_and_get_cpp_code(
                 AOTIRunnerUtil.compile, model, example_inputs
             )
             shim_fn_codes = (
-                f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef<c10::IValue>());'
+                f'RECORD_FUNCTION("{kernel_call}", c10::ArrayRef<c10::IValue>());'
             )
             if enable_kernel_profile:
                 FileCheck().check(shim_fn_codes).run(code)
@@ -4506,14 +4513,15 @@ def forward(self, a, b, c):
         so_path, code = run_and_get_cpp_code(
             AOTIRunnerUtil.legacy_compile, model, example_inputs
         )
-        lowerbound_check = "u1 >= 1" if mark_unbacked else "u0 >= 2"
+        varname = f"u{int(mark_unbacked) + (2 if _is_cpu_freezing(self) else 0)}"
+        lowerbound_check = f"{varname} >= {1 if mark_unbacked else 2}"
         FileCheck().check_count(lowerbound_check, 1).run(code)
 
         compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
         compiled(*example_inputs)
 
         # Check the runtime assertion.
-        with self.assertRaisesRegex(Exception, ""):
+        with self.assertRaises(Exception):
             unexpected_inputs = (torch.ones(0, device=self.device), b, c)
             compiled(*unexpected_inputs)
```
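
The new `_is_cpu_freezing` helper mirrors the tristate semantics introduced in this commit: `config.freezing is None` now means "freeze by default", and the CPU-only condition reflects that frozen constants change which kernels appear in CPU-generated code. A standalone restatement of the predicate, as a hypothetical free function with the config value and device passed explicitly (GPU_TYPE is assumed to be "cuda" here for illustration):

```python
from typing import Optional

# Hypothetical restatement of _is_cpu_freezing; the real helper reads
# config.freezing and the test's self.device.
def is_cpu_freezing(freezing: Optional[bool], device: str) -> bool:
    return (freezing is None or freezing) and device != "cuda"

assert is_cpu_freezing(None, "cpu")       # unset: freezing is now on by default
assert is_cpu_freezing(True, "cpu")       # explicit opt-in
assert not is_cpu_freezing(False, "cpu")  # explicit opt-out
assert not is_cpu_freezing(None, "cuda")  # GPU expectations are unchanged
```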

torch/_inductor/codegen/cpp.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5270,7 +5270,7 @@ def codegen_group(self, name=None) -> str:
             prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
             code.writelines(
                 [
-                    f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
+                    f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>());'
                 ]
             )
             for old, new in self.args.aliases():
```
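
The shim call is rendered from a Python f-string, where `{{}}` escapes to a literal `{}`; the old codegen therefore emitted `c10::ArrayRef<c10::IValue>({})`, which no longer matches the `()` form the profiling test above checks for. A quick demonstration of the escaping:

```python
# f-string brace escaping: "{{}}" renders as a literal "{}".
kernel = "graph_1_cpp_fused_0"
old = f'RECORD_FUNCTION("{kernel}", c10::ArrayRef<c10::IValue>({{}}));'
new = f'RECORD_FUNCTION("{kernel}", c10::ArrayRef<c10::IValue>());'
assert "({})" in old   # old output carried an extra brace pair
assert "()" in new     # new output matches the test's expectation
```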

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -3695,12 +3695,12 @@ def add_constexpr_arg(arg_name):
         if (
             len(non_constexpr_signature(signature)) == 4
         ):  # input, output and 2 args
-            tile_hint = "tile_hint=TileHint.SQUARE,"
+            tile_hint = " tile_hint=TileHint.SQUARE,"
         else:
-            tile_hint = "tile_hint=TileHint.DEFAULT,"
+            tile_hint = " tile_hint=TileHint.DEFAULT,"
         heuristics_line = f"""
 @triton_heuristics.{self._get_heuristic()}(
-    size_hints={size_hints!r}, {tile_hint}
+    size_hints={size_hints!r},{tile_hint}
     filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},
```
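
Moving the separating space out of the template and into `tile_hint` itself keeps the emitted decorator line byte-identical when a hint is present; the likely motivation (an assumption, since the empty-hint code path is not shown in this diff) is to avoid a dangling space when `tile_hint` is the empty string:

```python
# Sketch of the whitespace move; size_hints stands in for the real value.
size_hints = {"x": 1024}
old_hint = "tile_hint=TileHint.SQUARE,"
new_hint = " tile_hint=TileHint.SQUARE,"
assert f"size_hints={size_hints!r}, {old_hint}" == f"size_hints={size_hints!r},{new_hint}"
# With an empty hint, only the new form avoids a trailing space:
assert f"size_hints={size_hints!r}, ".endswith(" ")
assert not f"size_hints={size_hints!r},".endswith(" ")
```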

torch/_inductor/compile_fx.py

Lines changed: 18 additions & 15 deletions

```diff
@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from operator import attrgetter
@@ -1691,35 +1691,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_
 
     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)
 
-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}
 
-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )
 
-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)
 
     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
@@ -1824,7 +1823,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs = tracing_context.fw_metadata.static_input_indices
 
-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
```
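
`fw_compiler_freezing` can now be reached without a fake mode, so the patch is applied conditionally: `nullcontext()` stands in when there is nothing to patch. A self-contained sketch of the pattern, with a toy object in place of the real fake mode:

```python
from contextlib import nullcontext
from unittest import mock

class ToyFakeMode:  # stand-in for the real FakeTensorMode
    allow_non_fake_inputs = False

for fake_mode in (ToyFakeMode(), None):
    with (
        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        if fake_mode
        else nullcontext()
    ):
        if fake_mode:
            assert fake_mode.allow_non_fake_inputs  # patched only inside the block
```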

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -861,7 +861,7 @@ def decide_compile_threads() -> int:
 # Freezing will attempt to inline weights as constants in optimization
 # and run constant folding and other optimizations on them. After freezing, weights
 # can no longer be updated.
-freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
+freezing: Optional[bool] = get_tristate_env("TORCHINDUCTOR_FREEZING")

 # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
 # of potentially keeping multiple copies of weights.
```
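
`get_tristate_env` is an existing helper in this module; roughly, it returns `None` when the variable is unset, so only an explicit `TORCHINDUCTOR_FREEZING=0` or `=1` pins the behavior. A hedged approximation of that contract:

```python
import os
from typing import Optional

# Approximation of get_tristate_env's contract (not the real implementation):
def tristate_env(name: str) -> Optional[bool]:
    value = os.environ.get(name)
    if value == "1":
        return True
    if value == "0":
        return False
    return None  # unset (or unrecognized) leaves the decision to the caller
```

`compile_fx_aot` then resolves the `None` case to `not config.aot_inductor.use_runtime_constant_folding`, which is what makes freezing the AOTI default.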

torch/fx/graph_module.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -535,8 +535,9 @@ def __init__(
         if self.graph._tracer_extras:
             self._tracer_extras = self.graph._tracer_extras

-        # Dictionary to store metadata
-        self.meta: dict[str, Any] = {}
+        # Dictionary to store metadata. Initialize with the root metadata, if present,
+        # to avoid losing information when doing fx transformations.
+        self.meta: dict[str, Any] = root.meta if isinstance(root, GraphModule) else {}
         self._replace_hooks: list[Callable] = []
         self._create_node_hooks: list[Callable] = []
         self._erase_node_hooks: list[Callable] = []
```
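
This matters to the commit because `compile_fx_aot` reads `model_.meta.get("dynamo_compile_id")`, and fx transformations routinely rebuild a `GraphModule` from an existing one. A minimal sketch of the behavior once this change is in place:

```python
import torch
from torch.fx import GraphModule, symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(M())
gm.meta["dynamo_compile_id"] = 42

# Rebuilding from a GraphModule root now carries the metadata along.
rebuilt = GraphModule(gm, gm.graph)
assert rebuilt.meta.get("dynamo_compile_id") == 42
```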
