@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
@@ -1553,35 +1553,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
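
For reference, a minimal standalone sketch of the new `config_patches` handling in `compile_fx_aot`: the caller-supplied dict is now normalized and mutated in place rather than rebuilt, `cpp_wrapper` is always forced on, `freezing` falls back to the opposite of `use_runtime_constant_folding` when unset, and the output path defaults to a hash of the module's code. The helper name `_resolve_aot_config_patches` and the `SimpleNamespace` stand-in for the Inductor config object are assumptions for illustration only, not part of the patch.

```python
from types import SimpleNamespace
from typing import Any, Optional

# Hypothetical stand-in for torch._inductor.config; only the fields used here.
config = SimpleNamespace(
    freezing=None,
    aot_inductor=SimpleNamespace(
        use_runtime_constant_folding=False,
        output_path="",
    ),
)


def _resolve_aot_config_patches(
    config_patches: Optional[dict[str, Any]], code_hash: str
) -> dict[str, Any]:
    """Sketch of the patch-resolution logic this hunk introduces."""
    if config_patches is None:
        config_patches = {}

    # cpp_wrapper is always forced on for AOT compilation; freezing defaults to
    # the opposite of runtime constant folding unless explicitly configured.
    config_patches.update(
        cpp_wrapper=True,
        freezing=config.freezing
        if config.freezing is not None
        else not config.aot_inductor.use_runtime_constant_folding,
    )

    # The walrus operator binds output_path while testing it; an empty string
    # falls through to the else branch, which derives a path from the code hash.
    if output_path := config_patches.get(
        "aot_inductor.output_path", config.aot_inductor.output_path
    ):
        assert not output_path.endswith(".pt2"), "expected a .so output path"
    else:
        config_patches["aot_inductor.output_path"] = code_hash

    return config_patches


print(_resolve_aot_config_patches(None, "abc123"))
# {'cpp_wrapper': True, 'freezing': True, 'aot_inductor.output_path': 'abc123'}
```
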
@@ -1687,7 +1686,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs += tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
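
This hunk pairs with the new `nullcontext` import above: `mock.patch.object` is only entered when a `fake_mode` is actually present, and `nullcontext()` stands in otherwise so the `with` statement stays unconditional. A minimal sketch of that pattern, using a toy `SimpleNamespace` object in place of a real fake-tensor mode (an assumption for illustration):

```python
from contextlib import nullcontext
from types import SimpleNamespace
from unittest import mock

# Toy stand-in for a fake-tensor mode; in the real code path it may be None.
fake_mode = SimpleNamespace(allow_non_fake_inputs=False)

with (
    mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
    if fake_mode
    else nullcontext()
):
    # Inside the block the flag is temporarily flipped; with fake_mode=None the
    # nullcontext() branch turns the guard into a no-op instead of raising.
    print(getattr(fake_mode, "allow_non_fake_inputs", None))  # True

print(fake_mode.allow_non_fake_inputs)  # restored to False on exit
```
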