AOTI freezing: fix test issues and enable by default · pytorch/pytorch@058a891 · GitHub

Commit 058a891

AOTI freezing: fix test issues and enable by default
ghstack-source-id: d1bf7c3
Pull Request resolved: #149961
1 parent 84c905a commit 058a891

File tree

7 files changed: +51 -39 lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 1 deletion
@@ -2917,7 +2917,7 @@ def parse_args(args=None):
     )
     parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
     parser.add_argument(
-        "--freezing", action="store_true", help="turn on freezing", default=False
+        "--freezing", action="store_true", help="turn on freezing", default=None
     )
     parser.add_argument(
         "--inductor-config",

test/inductor/test_aot_inductor.py

Lines changed: 24 additions & 16 deletions
@@ -122,6 +122,10 @@
         raise


+def _is_cpu_freezing(self):
+    return (config.freezing is None or config.freezing) and self.device != GPU_TYPE
+
+
 class AOTInductorTestsTemplate:
     def test_simple(self):
         class Model(torch.nn.Module):
@@ -4137,16 +4141,16 @@ def forward(self, a):
         a = torch.randn(batch, M, K, device=self.device)
         example_inputs = (a,)

-        kernel_calls = (
-            [
+        is_cpu_freezing = _is_cpu_freezing(self)
+        if self.device == GPU_TYPE:
+            kernel_calls = [
                 ("triton_poi_fused_0", 1),
                 (f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
             ]
-            if self.device == GPU_TYPE
-            else [
-                ("aoti_torch_cpu_addmm_out", 2),
-            ]
-        )
+        elif is_cpu_freezing:
+            kernel_calls = [("cpp_fused_0", 1)]
+        else:
+            kernel_calls = [("aoti_torch_cpu_addmm_out", 2)]

         # test default debug printing all tensor values codegen
         with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
@@ -4170,7 +4174,9 @@ def forward(self, a):
             ).run(code)

         # test printing selected kernel's tensor values codegen
-        filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
+        filtered_kernel_name = (
+            "cpp_fused_0" if is_cpu_freezing else f"aoti_torch_{self.device}_addmm_out"
+        )
         with config.patch(
             {
                 "aot_inductor.debug_intermediate_value_printer": "2",
@@ -4181,7 +4187,7 @@ def forward(self, a):
                 AOTIRunnerUtil.legacy_compile, model, example_inputs
             )
             filtered_kernel_calls = [
-                (filtered_kernel_name, 2),
+                (filtered_kernel_name, 1 if is_cpu_freezing else 2),
             ]
             for kernel_call, count in filtered_kernel_calls:
                 FileCheck().check_count(
@@ -4226,17 +4232,18 @@ def forward(self, a):
         batch = 2
         a = torch.randn(batch, M, K, device=self.device)
         example_inputs = (a,)
-        kernel_calls = (
-            f"aoti_torch_{GPU_TYPE}_addmm_out"
-            if self.device == GPU_TYPE
-            else "aoti_torch_cpu_addmm_out"
+
+        kernel_call = (
+            "graph_1_cpp_fused_0"
+            if _is_cpu_freezing(self)
+            else f"aoti_torch_{self.device}_addmm_out"
         )
         with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}):
             _, code = run_and_get_cpp_code(
                 AOTIRunnerUtil.compile, model, example_inputs
             )
             shim_fn_codes = (
-                f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef<c10::IValue>());'
+                f'RECORD_FUNCTION("{kernel_call}", c10::ArrayRef<c10::IValue>());'
             )
             if enable_kernel_profile:
                 FileCheck().check(shim_fn_codes).run(code)
@@ -4486,14 +4493,15 @@ def forward(self, a, b, c):
         so_path, code = run_and_get_cpp_code(
             AOTIRunnerUtil.legacy_compile, model, example_inputs
         )
-        lowerbound_check = "u1 >= 1" if mark_unbacked else "u0 >= 2"
+        varname = f"u{int(mark_unbacked) + (2 if _is_cpu_freezing(self) else 0)}"
+        lowerbound_check = f"{varname} >= {1 if mark_unbacked else 2}"
         FileCheck().check_count(lowerbound_check, 1).run(code)

         compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
         compiled(*example_inputs)

         # Check the runtime assertion.
-        with self.assertRaisesRegex(Exception, ""):
+        with self.assertRaises(Exception):
             unexpected_inputs = (torch.ones(0, device=self.device), b, c)
             compiled(*unexpected_inputs)

torch/_inductor/codegen/cpp.py

Lines changed: 1 addition & 1 deletion
@@ -5262,7 +5262,7 @@ def codegen_group(self, name=None) -> str:
         prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
         code.writelines(
             [
-                f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
+                f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>());'
             ]
         )
         for old, new in self.args.aliases():
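Because this line lives inside an f-string, the old doubled braces rendered as a literal {} argument in the generated C++ (c10::ArrayRef<c10::IValue>({})); dropping them emits the bare default-constructed form that the kernel-profile test above now checks for. A small sketch of what the writeline produces, with illustrative prefix and kernel-name values:

prefix = "graph_1_"          # illustrative per-graph prefix built in codegen_group
kernel_name = "cpp_fused_0"  # illustrative fused CPU kernel name

line = f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>());'
print(line)
# RECORD_FUNCTION("graph_1_cpp_fused_0", c10::ArrayRef<c10::IValue>());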

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
@@ -3695,12 +3695,12 @@ def add_constexpr_arg(arg_name):
         if (
             len(non_constexpr_signature(signature)) == 4
         ):  # input, output and 2 args
-            tile_hint = "tile_hint=TileHint.SQUARE,"
+            tile_hint = " tile_hint=TileHint.SQUARE,"
         else:
-            tile_hint = "tile_hint=TileHint.DEFAULT,"
+            tile_hint = " tile_hint=TileHint.DEFAULT,"
         heuristics_line = f"""
 @triton_heuristics.{self._get_heuristic()}(
-    size_hints={size_hints!r}, {tile_hint}
+    size_hints={size_hints!r},{tile_hint}
     filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},

torch/_inductor/compile_fx.py

Lines changed: 18 additions & 15 deletions
@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from operator import attrgetter
@@ -1656,35 +1656,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
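The update() call above is where AOTI freezing becomes the default: an explicit config.freezing value still wins, but when it is left at None, freezing is enabled unless runtime constant folding was requested. A minimal standalone sketch of that resolution (the helper function is illustrative, not part of the patch):

def resolve_aoti_freezing(freezing, use_runtime_constant_folding):
    # Illustrative restatement of the freezing=... expression in config_patches.update():
    # an explicit True/False is honored; otherwise default to freezing unless the
    # caller asked for runtime constant folding instead.
    if freezing is not None:
        return freezing
    return not use_runtime_constant_folding

assert resolve_aoti_freezing(None, False) is True    # new default: freezing on
assert resolve_aoti_freezing(None, True) is False
assert resolve_aoti_freezing(False, False) is False  # user opt-out respected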
@@ -1789,7 +1788,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs = tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
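The newly imported nullcontext makes the fake-mode patch optional: the attribute is only patched when a fake mode actually exists, and a no-op context is used otherwise, so the freezing path no longer assumes fake_mode is set. A self-contained sketch of the same pattern, using a stand-in FakeMode class:

from contextlib import nullcontext
from unittest import mock

class FakeMode:
    # Stand-in for the real fake tensor mode; only the patched attribute matters here.
    allow_non_fake_inputs = False

def compile_under_patch(fake_mode):
    with (
        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        if fake_mode
        else nullcontext()
    ):
        # Inside the block the attribute is temporarily True when a fake mode exists.
        return fake_mode.allow_non_fake_inputs if fake_mode else "no fake mode"

assert compile_under_patch(FakeMode()) is True
assert compile_under_patch(None) == "no fake mode"  # no fake mode: no-op context, no crash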

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -861,7 +861,7 @@ def decide_compile_threads() -> int:
 # Freezing will attempt to inline weights as constants in optimization
 # and run constant folding and other optimizations on them. After freezing, weights
 # can no longer be updated.
-freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
+freezing: Optional[bool] = get_tristate_env("TORCHINDUCTOR_FREEZING")

 # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
 # of potentially keeping multiple copies of weights.
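With the annotation widened to Optional[bool], an unset TORCHINDUCTOR_FREEZING no longer forces freezing off; it leaves the decision to compile_fx_aot above. A rough sketch of what a tri-state environment lookup does (an illustrative approximation, not the repo's get_tristate_env implementation):

import os
from typing import Optional

def tristate_env(name: str) -> Optional[bool]:
    # "1" -> True, "0" -> False, unset or anything else -> None (no opinion).
    value = os.environ.get(name)
    if value == "1":
        return True
    if value == "0":
        return False
    return None

freezing = tristate_env("TORCHINDUCTOR_FREEZING")  # None unless the user set it explicitly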

torch/fx/graph_module.py

Lines changed: 3 additions & 2 deletions
@@ -535,8 +535,9 @@ def __init__(
         if self.graph._tracer_extras:
             self._tracer_extras = self.graph._tracer_extras

-        # Dictionary to store metadata
-        self.meta: dict[str, Any] = {}
+        # Dictionary to store metadata. Initialize with the root metadata, if present,
+        # to avoid losing information when doing fx transformations.
+        self.meta: dict[str, Any] = root.meta if isinstance(root, GraphModule) else {}
         self._replace_hooks: list[Callable] = []
         self._create_node_hooks: list[Callable] = []
         self._erase_node_hooks: list[Callable] = []
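This is what keeps module-level metadata such as dynamo_compile_id (read via model_.meta.get() in compile_fx_aot above) alive when an fx pass rebuilds a GraphModule from an existing one. A short example of the behavior after this change:

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = fx.symbolic_trace(M())
gm.meta["dynamo_compile_id"] = "example-id"  # metadata a later compile step relies on

# Rebuilding a GraphModule from an existing one (a common fx-transform pattern)
# now starts from the root's meta dict instead of an empty one.
new_gm = fx.GraphModule(gm, gm.graph)
assert new_gm.meta.get("dynamo_compile_id") == "example-id"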
