[NOT FOR MERGE] Re-implement #148962 for benchmarking · pytorch/pytorch@956f86d · GitHub

Commit 956f86d

[NOT FOR MERGE] Re-implement #148962 for benchmarking
ghstack-source-id: ac1a806
Pull Request resolved: #149961

1 parent d7bee66 · commit 956f86d

File tree

4 files changed: +23 −20 lines
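
Taken together, the four diffs re-implement #148962: inductor's freezing setting becomes tristate (on / off / unset), the benchmark harness and compile_fx_aot learn to resolve the unset case, and two small adjustments (generated-kernel whitespace in triton.py, a None-safe fake_mode patch in compile_fx.py) ride along.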

benchmarks/dynamo/common.py

+1 −1

@@ -2909,7 +2909,7 @@ def parse_args(args=None):
     )
     parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
     parser.add_argument(
-        "--freezing", action="store_true", help="turn on freezing", default=False
+        "--freezing", action="store_true", help="turn on freezing", default=None
     )
     parser.add_argument(
         "--inductor-config",

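Changing the --freezing default from False to None lets the harness distinguish "flag omitted" from "explicitly off" and defer to the library-level tristate default (see the config.py hunk below). A minimal sketch of the argparse behavior, standalone rather than the actual harness:

import argparse

# store_true normally defaults to False; overriding the default with None
# preserves a third "unset" state that downstream config can resolve.
parser = argparse.ArgumentParser()
parser.add_argument("--freezing", action="store_true", default=None)

assert parser.parse_args([]).freezing is None               # omitted: unset
assert parser.parse_args(["--freezing"]).freezing is True   # explicitly on
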
torch/_inductor/codegen/triton.py

+3 −3

@@ -3689,12 +3689,12 @@ def add_constexpr_arg(arg_name):
             if (
                 len(non_constexpr_signature(signature)) == 4
             ):  # input, output and 2 args
-                tile_hint = "tile_hint=TileHint.SQUARE,"
+                tile_hint = " tile_hint=TileHint.SQUARE,"
             else:
-                tile_hint = "tile_hint=TileHint.DEFAULT,"
+                tile_hint = " tile_hint=TileHint.DEFAULT,"
         heuristics_line = f"""
 @triton_heuristics.{self._get_heuristic()}(
-    size_hints={size_hints!r}, {tile_hint}
+    size_hints={size_hints!r},{tile_hint}
     filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},

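Here the separating space moves out of the f-string template and into the tile_hint literals, so the rendered decorator line is unchanged whenever a hint is present; presumably this avoids a dangling trailing space on code paths (not shown in this hunk) where tile_hint ends up empty. A quick illustration of the rendering, with a hypothetical size_hints value:

size_hints = {"x": 16, "y": 16}
hint = " tile_hint=TileHint.SQUARE,"   # new form carries its own leading space

old = f"size_hints={size_hints!r}, {hint.lstrip()}"  # old: space in template
new = f"size_hints={size_hints!r},{hint}"            # new: space in the hint
assert old == new  # identical output when a hint is present

empty_old = f"size_hints={size_hints!r}, {''}"  # trailing space if hint empty
empty_new = f"size_hints={size_hints!r},{''}"   # clean
assert empty_old.endswith(" ") and not empty_new.endswith(" ")
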
torch/_inductor/compile_fx.py

+18 −15

@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union

@@ -1553,35 +1553,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)

@@ -1687,7 +1686,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs += tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
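
compile_fx_aot now folds cpp_wrapper and a resolved freezing value into config_patches in one update (the annotation also widens to dict[str, Any], matching actual usage, and the walrus form merely tightens the output_path lookup). With freezing tristate, an unset value defaults to the opposite of aot_inductor.use_runtime_constant_folding. A standalone restatement of that resolution rule, using mocked-up names rather than the real config module:

from typing import Optional

def resolve_aot_freezing(
    freezing: Optional[bool], use_runtime_constant_folding: bool
) -> bool:
    # An explicit setting always wins; otherwise enable freezing unless
    # runtime constant folding was requested instead.
    if freezing is not None:
        return freezing
    return not use_runtime_constant_folding

assert resolve_aot_freezing(None, False) is True    # AOT freezes by default
assert resolve_aot_freezing(None, True) is False    # defer to runtime folding
assert resolve_aot_freezing(False, True) is False   # explicit opt-out respected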

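The fw_compiler_freezing hunk guards against fake_mode being None: mock.patch.object would raise on a None target, so a no-op nullcontext is substituted. The pattern in isolation, with a toy class standing in for the real fake mode:

from contextlib import nullcontext
from unittest import mock

class FakeMode:
    allow_non_fake_inputs = False

for fake_mode in (FakeMode(), None):
    with (
        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        if fake_mode
        else nullcontext()
    ):
        # Body runs in both cases; no AttributeError when fake_mode is None.
        pass
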
torch/_inductor/config.py

+1 −1

@@ -847,7 +847,7 @@ def decide_compile_threads() -> int:
 # Freezing will attempt to inline weights as constants in optimization
 # and run constant folding and other optimizations on them. After freezing, weights
 # can no longer be updated.
-freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
+freezing: Optional[bool] = get_tristate_env("TORCHINDUCTOR_FREEZING")

 # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
 # of potentially keeping multiple copies of weights.
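
config.freezing is now Optional[bool]: True or False only when TORCHINDUCTOR_FREEZING is set explicitly, and None otherwise, which is what lets compile_fx_aot pick its own default above. get_tristate_env is an existing helper in torch/_inductor/config.py; the sketch below is an assumed reconstruction of its semantics, not its actual body:

import os
from typing import Optional

def get_tristate_env_sketch(name: str) -> Optional[bool]:
    # "1" -> True, "0" -> False, unset or anything else -> None
    value = os.environ.get(name)
    if value == "1":
        return True
    if value == "0":
        return False
    return None

os.environ.pop("TORCHINDUCTOR_FREEZING", None)
assert get_tristate_env_sketch("TORCHINDUCTOR_FREEZING") is None

os.environ["TORCHINDUCTOR_FREEZING"] = "1"
assert get_tristate_env_sketch("TORCHINDUCTOR_FREEZING") is True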
