[NOT FOR MERGE] Re-implement #148962 for benchmarking · pytorch/pytorch@7d4370b · GitHub

Commit 7d4370b

[NOT FOR MERGE] Re-implement #148962 for benchmarking
ghstack-source-id: ad12068
Pull Request resolved: #149961
1 parent ac6ab54 commit 7d4370b

File tree: 4 files changed, +23 -20 lines

benchmarks/dynamo/common.py

Lines changed: 1 addition & 1 deletion
@@ -2910,7 +2910,7 @@ def parse_args(args=None):
     )
     parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
     parser.add_argument(
-        "--freezing", action="store_true", help="turn on freezing", default=False
+        "--freezing", action="store_true", help="turn on freezing", default=None
     )
     parser.add_argument(
         "--inductor-config",

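With action="store_true", the old default=False made an omitted flag indistinguishable from an explicit "off"; default=None turns the flag tri-state, so the harness can defer to Inductor's own freezing default (see the config.py change below). A minimal standalone sketch of that argparse behavior, not taken from the PR:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--freezing", action="store_true", default=None)

assert parser.parse_args([]).freezing is None               # flag omitted: undecided
assert parser.parse_args(["--freezing"]).freezing is True   # flag given: forced on
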
torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
@@ -3700,12 +3700,12 @@ def add_constexpr_arg(arg_name):
         if (
             len(non_constexpr_signature(signature)) == 4
         ):  # input, output and 2 args
-            tile_hint = "tile_hint=TileHint.SQUARE,"
+            tile_hint = " tile_hint=TileHint.SQUARE,"
         else:
-            tile_hint = "tile_hint=TileHint.DEFAULT,"
+            tile_hint = " tile_hint=TileHint.DEFAULT,"
         heuristics_line = f"""
 @triton_heuristics.{self._get_heuristic()}(
-    size_hints={size_hints!r}, {tile_hint}
+    size_hints={size_hints!r},{tile_hint}
     filename=__file__,
     triton_meta={triton_meta!r},
     inductor_meta={inductor_meta!r},
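
This whitespace shuffle preserves the rendered kernel header: the separating space moves out of the f-string template and into the tile_hint string itself. A quick sketch (plain Python, not from the PR; the motivation is presumably that a tile_hint set to the empty string elsewhere now leaves no stray space after the comma):

size_hints = {"x": 1024}

for tile_hint in (" tile_hint=TileHint.SQUARE,", ""):
    print(f"    size_hints={size_hints!r},{tile_hint}")
# prints:
#     size_hints={'x': 1024}, tile_hint=TileHint.SQUARE,
#     size_hints={'x': 1024},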

torch/_inductor/compile_fx.py

Lines changed: 18 additions & 15 deletions
@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union

@@ -1579,35 +1579,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)

@@ -1713,7 +1712,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs += tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
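
Two behavioral notes on this file. First, compile_fx_aot now resolves freezing for AOTInductor: an explicit config.freezing (now tri-state, see config.py below) wins, and when it is unset, freezing defaults to the opposite of aot_inductor.use_runtime_constant_folding, presumably because inlining weights as compile-time constants and folding them at runtime are alternative strategies. Second, the nullcontext guard in fw_compiler_freezing avoids calling mock.patch.object on a missing fake_mode. A self-contained sketch of the resolution rule (hypothetical helper name, not from the PR):

from typing import Optional

def resolve_aot_freezing(
    freezing: Optional[bool], use_runtime_constant_folding: bool
) -> bool:
    # Mirrors the config_patches.update(...) logic in the diff above.
    if freezing is not None:
        return freezing  # explicitly forced on/off by the user
    return not use_runtime_constant_folding

assert resolve_aot_freezing(None, use_runtime_constant_folding=True) is False
assert resolve_aot_freezing(None, use_runtime_constant_folding=False) is True
assert resolve_aot_freezing(False, use_runtime_constant_folding=False) is False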

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -859,7 +859,7 @@ def decide_compile_threads() -> int:
 # Freezing will attempt to inline weights as constants in optimization
 # and run constant folding and other optimizations on them. After freezing, weights
 # can no longer be updated.
-freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
+freezing: Optional[bool] = get_tristate_env("TORCHINDUCTOR_FREEZING")

 # Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
 # of potentially keeping multiple copies of weights.
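
This is the root of the tri-state change: freezing was previously a plain bool (env var unset meant False), while get_tristate_env lets an unset TORCHINDUCTOR_FREEZING surface as None so downstream code (compile_fx_aot above) can pick a context-dependent default. A rough sketch of the helper's semantics, assuming the usual "1"/"0" convention (the real helper lives in torch/_inductor/config.py and may differ in detail):

import os
from typing import Optional

def get_tristate_env(name: str) -> Optional[bool]:
    value = os.environ.get(name)
    if value == "1":
        return True
    if value == "0":
        return False
    return None  # unset (or unrecognized): leave the decision to the caller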

0 commit comments