@@ -13,7 +13,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from inspect import currentframe
 from itertools import count
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
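For reference, `nullcontext` is the standard library's no-op context manager from `contextlib`; the `fw_compiler_freezing` hunk further down uses it as a fallback when there is no fake mode object to patch. A quick standalone illustration of its behavior, independent of this change:

```python
from contextlib import nullcontext

# nullcontext() does nothing on enter/exit; nullcontext(value) also yields `value`.
with nullcontext() as res:
    assert res is None

with nullcontext(42) as res:
    assert res == 42
```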
@@ -1579,35 +1579,34 @@ def compile_fx_aot(
     model_: GraphModule,
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
-    config_patches: Optional[dict[str, str]] = None,
+    config_patches: Optional[dict[str, Any]] = None,
 ) -> Union[list[str], str]:
     assert isinstance(model_, GraphModule), model_
 
     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)
 
-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    if config_patches is None:
+        config_patches = {}
 
-    output_path = config_patches.get(
-        "aot_inductor.output_path", config.aot_inductor.output_path
+    config_patches.update(
+        cpp_wrapper=True,
+        freezing=config.freezing
+        if config.freezing is not None
+        else not config.aot_inductor.use_runtime_constant_folding,
     )
 
-    if output_path:
+    if output_path := config_patches.get(
+        "aot_inductor.output_path", config.aot_inductor.output_path
+    ):
         assert not output_path.endswith(".pt2"), (
             "The output path for aot_compile should not have an extension with .pt2 "
             "this is for specifying the output path for the .so in AOTInductor. "
             "If you would like to package the AOTInductor generated files "
             "into a pt2, please call `torch._inductor.aoti_compile_and_package`."
         )
     else:
-        config_patches = {
-            **config_patches,
-            "aot_inductor.output_path": code_hash(model_.code),
-        }
+        config_patches["aot_inductor.output_path"] = code_hash(model_.code)
 
     extern_node_serializer = config_patches.pop("extern_node_serializer", None)
     saved_compile_id = model_.meta.get("dynamo_compile_id", None)
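Taken together, the rewritten prologue of `compile_fx_aot` normalizes `config_patches` to a dict, forces `cpp_wrapper=True`, defaults `freezing` to the opposite of `aot_inductor.use_runtime_constant_folding` when the global flag is unset, and falls back to a hash of the module code for `aot_inductor.output_path`. Below is a minimal, torch-free sketch of that resolution logic; the `Config` dataclasses and the `code_hash` stub are stand-ins invented purely for illustration:

```python
from dataclasses import dataclass, field
from hashlib import sha256
from typing import Any, Optional


@dataclass
class AotInductorConfig:
    # Hypothetical stand-in for torch._inductor.config.aot_inductor.
    output_path: str = ""
    use_runtime_constant_folding: bool = False


@dataclass
class Config:
    # None mirrors the tri-state freezing flag: "let the compiler decide".
    freezing: Optional[bool] = None
    aot_inductor: AotInductorConfig = field(default_factory=AotInductorConfig)


def code_hash(code: str) -> str:
    # Stand-in for the real code_hash helper: any stable digest works here.
    return sha256(code.encode()).hexdigest()[:16]


def resolve_aot_config_patches(
    config: Config,
    module_code: str,
    config_patches: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    if config_patches is None:
        config_patches = {}

    config_patches.update(
        cpp_wrapper=True,
        freezing=config.freezing
        if config.freezing is not None
        else not config.aot_inductor.use_runtime_constant_folding,
    )

    # Only derive an output path when neither the patches nor the global
    # config supply one (the real code also rejects ".pt2" suffixes here).
    if not config_patches.get(
        "aot_inductor.output_path", config.aot_inductor.output_path
    ):
        config_patches["aot_inductor.output_path"] = code_hash(module_code)

    return config_patches


if __name__ == "__main__":
    patches = resolve_aot_config_patches(Config(), "def f(x): return x + 1")
    # With the defaults above: cpp_wrapper=True, freezing=True, and a hashed path.
    print(patches)
```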
@@ -1713,7 +1712,11 @@ def fw_compiler_freezing(
         if tracing_context.fw_metadata:
             static_input_idxs += tracing_context.fw_metadata.static_input_indices
 
-    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
+    with (
+        mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+        if fake_mode
+        else nullcontext()
+    ):
         optimized_function = inner_compile(
             opt_model,
             aot_example_inputs,
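The `fw_compiler_freezing` change only patches `allow_non_fake_inputs` when a fake mode actually exists; otherwise `nullcontext()` keeps the `with` statement a no-op instead of letting `mock.patch.object(None, ...)` fail. A small standalone sketch of the same conditional-context-manager pattern, using a made-up `settings` object in place of the fake mode:

```python
from contextlib import nullcontext
from types import SimpleNamespace
from typing import Optional
from unittest import mock


def run_with_optional_patch(settings: Optional[SimpleNamespace]) -> bool:
    # Patch the attribute only when the object exists; otherwise fall back to
    # a no-op context manager so the `with` block still runs.
    with (
        mock.patch.object(settings, "allow_non_fake_inputs", True)
        if settings
        else nullcontext()
    ):
        return settings.allow_non_fake_inputs if settings else False


print(run_with_optional_patch(SimpleNamespace(allow_non_fake_inputs=False)))  # True
print(run_with_optional_patch(None))  # False
```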