 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
@@ -1530,35 +1530,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_

     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)

-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}

-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )

-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)

     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
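For reference, a minimal, self-contained sketch of the config-patch resolution the hunk above arrives at: `cpp_wrapper` is forced on, `freezing` defaults based on runtime constant folding, and a missing output path falls back to a hash of the generated code. The helper name `resolve_aot_config_patches`, the keyword defaults, and the use of `hashlib` are illustrative stand-ins; the real code reads `torch._inductor.config` and its own `code_hash` helper.

```python
import hashlib
from typing import Any, Optional


def resolve_aot_config_patches(
    config_patches: Optional[dict[str, Any]],
    model_code: str,
    *,
    freezing: Optional[bool] = None,
    use_runtime_constant_folding: bool = False,
    default_output_path: str = "",
) -> dict[str, Any]:
    if config_patches is None:
        config_patches = {}

    # cpp_wrapper is always forced on; freezing defaults to "on" unless
    # runtime constant folding is requested (mirroring the hunk above).
    config_patches.update(
        cpp_wrapper=True,
        freezing=freezing if freezing is not None else not use_runtime_constant_folding,
    )

    # An explicit output path must not end in ".pt2"; with no path at all,
    # fall back to a hash of the generated code (code_hash in the real code).
    if output_path := config_patches.get(
        "aot_inductor.output_path", default_output_path
    ):
        assert not output_path.endswith(".pt2"), "package .pt2 via aoti_compile_and_package"
    else:
        config_patches["aot_inductor.output_path"] = hashlib.sha256(
            model_code.encode()
        ).hexdigest()
    return config_patches


# Example: no patches supplied -> cpp_wrapper/freezing defaults plus a hashed path.
print(resolve_aot_config_patches(None, "def forward(x): return x"))
```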
@@ -1664,7 +1663,11 @@ def fw_compiler_freezing(
     if tracing_context.fw_metadata:
         static_input_idxs += tracing_context.fw_metadata.static_input_indices

-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
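The second hunk guards the `mock.patch.object` context manager with `nullcontext()` so the `with` block still runs when `fake_mode` is missing. A minimal sketch of that pattern, with a toy `FakeMode` class standing in for the real fake-tensor mode object:

```python
from contextlib import nullcontext
from unittest import mock


class FakeMode:
    """Toy stand-in for the real fake-tensor mode object."""

    def __init__(self) -> None:
        self.allow_non_fake_inputs = False


def run(fake_mode: "FakeMode | None") -> "bool | None":
    # Use the patch when fake_mode exists, otherwise a no-op context manager,
    # so the body never has to special-case fake_mode being None.
    with (
        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
        if fake_mode
        else nullcontext()
    ):
        return fake_mode.allow_non_fake_inputs if fake_mode else None


print(run(FakeMode()))  # True: the flag is flipped only inside the with block
print(run(None))        # None: nullcontext() keeps the with statement valid
```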