codecache: Remove cpp_prefix.h duplication per build, then precompile it by benjaminglass1 · Pull Request #144293 · pytorch/pytorch · GitHub

codecache: Remove cpp_prefix.h duplication per build, then precompile it #144293

Closed
Changes from all commits · 71 commits
0702ad4  Update · benjaminglass1 · Jan 7, 2025
9953b07  Update · benjaminglass1 · Jan 7, 2025
2761aa5  Update · benjaminglass1 · Jan 7, 2025
65ccc39  Update · benjaminglass1 · Jan 7, 2025
42f0b07  Update · benjaminglass1 · Jan 7, 2025
613dca3  Update · benjaminglass1 · Jan 8, 2025
1bfb139  Update · benjaminglass1 · Jan 8, 2025
4c0df96  Update · benjaminglass1 · Jan 9, 2025
7023f4e  Update · benjaminglass1 · Jan 10, 2025
3a7d905  Update · benjaminglass1 · Jan 10, 2025
694fed5  Update · benjaminglass1 · Jan 11, 2025
6b01538  Update · benjaminglass1 · Jan 11, 2025
db4db83  Update · benjaminglass1 · Jan 14, 2025
2087b95  Update · benjaminglass1 · Jan 15, 2025
d08ba05  Update · benjaminglass1 · Jan 15, 2025
ab5df27  Update · benjaminglass1 · Jan 17, 2025
f4f262c  Update · benjaminglass1 · Jan 17, 2025
57dbff1  Update · benjaminglass1 · Jan 22, 2025
d8bb0d5  Update · benjaminglass1 · Jan 27, 2025
c4f1e27  Update · benjaminglass1 · Jan 30, 2025
55f6389  Update · benjaminglass1 · Jan 30, 2025
134313b  Update · benjaminglass1 · Feb 4, 2025
6710e5c  Update · benjaminglass1 · Feb 4, 2025
fc43707  Update · benjaminglass1 · Feb 5, 2025
e8d83c7  Update · benjaminglass1 · Feb 10, 2025
5538250  Update · benjaminglass1 · Feb 11, 2025
e865694  Update · benjaminglass1 · Feb 11, 2025
47e023e  Update · benjaminglass1 · Feb 11, 2025
1dd3ae4  Update · benjaminglass1 · Feb 12, 2025
0088b88  Update · benjaminglass1 · Feb 12, 2025
ce9d408  Update · benjaminglass1 · Feb 12, 2025
72c117f  Update · benjaminglass1 · Feb 13, 2025
2ce14aa  Update · benjaminglass1 · Feb 25, 2025
e97bbea  Update · benjaminglass1 · Feb 25, 2025
2900554  Update · benjaminglass1 · Feb 27, 2025
fee1aec  Update · benjaminglass1 · Feb 27, 2025
dac87a7  Update · benjaminglass1 · Feb 28, 2025
834cd1b  Update · benjaminglass1 · Feb 28, 2025
bb966c7  Update · benjaminglass1 · Mar 6, 2025
9b23e25  Update · benjaminglass1 · Mar 7, 2025
3a38db7  Update · benjaminglass1 · Mar 10, 2025
9cdb33e  Update · benjaminglass1 · Mar 10, 2025
b43890e  Update · benjaminglass1 · Mar 10, 2025
5c10244  Update · benjaminglass1 · Mar 11, 2025
03b7817  Update · benjaminglass1 · Mar 11, 2025
ef92d9c  Update · benjaminglass1 · Mar 11, 2025
c90fdbe  Update · benjaminglass1 · Mar 11, 2025
95c7e4e  Update · benjaminglass1 · Mar 12, 2025
c5c6b58  Update · benjaminglass1 · Mar 12, 2025
5dcac83  Update · benjaminglass1 · Mar 13, 2025
0ac30b0  Update · benjaminglass1 · Mar 13, 2025
c735503  Update · benjaminglass1 · Mar 24, 2025
1d01ac6  Update · benjaminglass1 · Mar 27, 2025
f09b973  Update · benjaminglass1 · Mar 28, 2025
b9a674b  Update · benjaminglass1 · Apr 4, 2025
4128127  Update · benjaminglass1 · Apr 8, 2025
dac9518  Update · benjaminglass1 · Apr 10, 2025
d7f7258  Update · benjaminglass1 · Apr 15, 2025
b130c23  Update · benjaminglass1 · Apr 15, 2025
ccdde59  Update · benjaminglass1 · Apr 16, 2025
5db0de3  Update · benjaminglass1 · Apr 16, 2025
929a715  Update · benjaminglass1 · Apr 17, 2025
ec375e8  Update · benjaminglass1 · Apr 21, 2025
91bddf8  Update · benjaminglass1 · Apr 30, 2025
a26ded9  Update · benjaminglass1 · May 1, 2025
edd1bb0  Update · benjaminglass1 · May 2, 2025
16eb3db  Update · benjaminglass1 · May 7, 2025
98cb264  Update · benjaminglass1 · May 10, 2025
a4ac732  Update · benjaminglass1 · May 13, 2025
a5e7f27  Update · benjaminglass1 · May 15, 2025
1fb5e56  Update · benjaminglass1 · May 15, 2025
5 changes: 3 additions & 2 deletions .ci/pytorch/test.sh

@@ -431,10 +431,11 @@ test_inductor_cpp_wrapper_shard() {
   if [[ "$1" -eq "2" ]]; then
     # For now, manually put the opinfo tests in shard 2, and all other tests in
-    # shard 1. Test specific things triggering past bugs, for now.
+    # shard 1. Run all CPU tests, as well as specific GPU tests triggering past
+    # bugs, for now.
     python test/run_test.py \
       --include inductor/test_torchinductor_opinfo \
-      -k 'linalg or to_sparse' \
+      -k 'linalg or to_sparse or TestInductorOpInfoCPU' \
       --verbose
     exit
   fi
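The `-k` argument is forwarded to pytest's keyword filter: a test is selected when its identifier matches the boolean expression. A minimal sketch of the matching semantics for the expression above (hypothetical test IDs, not the real opinfo test names):

```python
# Sketch of pytest-style "-k" substring matching as used by run_test.py above.
def k_matches(test_id: str) -> bool:
    # Equivalent in spirit to: -k 'linalg or to_sparse or TestInductorOpInfoCPU'
    return (
        "linalg" in test_id
        or "to_sparse" in test_id
        or "TestInductorOpInfoCPU" in test_id
    )

# Adding the class name selects the entire CPU opinfo suite for this shard,
# while CUDA opinfo tests still run only if they hit the other keywords.
assert k_matches("TestInductorOpInfoCPU::test_comprehensive_add_cpu")
assert not k_matches("TestInductorOpInfoCUDA::test_comprehensive_add_cuda")
```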
2 changes: 1 addition & 1 deletion .github/merge_rules.yaml
@@ -393,13 +393,13 @@
 - torch/_inductor/mkldnn_lowerings.py
 - torch/_inductor/fx_passes/mkldnn_fusion.py
 - torch/_inductor/fx_passes/quantization.py
-- torch/_inductor/codegen/cpp_prefix.h
 - torch/_inductor/codegen/cpp.py
 - torch/_inductor/codegen/cpp_utils.py
 - torch/_inductor/codegen/cpp_micro_gemm.py
 - torch/_inductor/codegen/cpp_template_kernel.py
 - torch/_inductor/codegen/cpp_template.py
 - torch/_inductor/codegen/cpp_gemm_template.py
+- torch/csrc/inductor/cpp_prefix.h
 - test/inductor/test_mkldnn_pattern_matcher.py
 - test/inductor/test_cpu_repro.py
 - test/inductor/test_cpu_cpp_wrapper.py
4 changes: 0 additions & 4 deletions test/inductor/test_torchinductor.py
@@ -50,7 +50,6 @@
     aoti_eager_cache_dir,
     load_aoti_eager_cache,
 )
-from torch._inductor.codecache import cpp_prefix_path
 from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
 from torch._inductor.fx_passes import pad_mm
 from torch._inductor.test_case import TestCase as InductorTestCase
@@ -6673,7 +6672,6 @@ def fn(x):
         )

     @xfail_if_mps_unimplemented
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
     @config.patch(force_disable_caches=True)
     @skip_if_cpp_wrapper("run_and_get_kernels issue")
     def test_deterministic_codegen(self):
@@ -6722,7 +6720,6 @@ def c(x):
         self.assertEqual(coda_b0, coda_b2)
         self.assertEqual(coda_c0, coda_c2)

-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
     @config.patch(force_disable_caches=True)
     @skip_if_cpp_wrapper("run_and_get_kernels issue")
     @xfail_if_mps
@@ -6744,7 +6741,6 @@ def b(x):
         _, (code0, code1) = _run_and_get_stripped_kernels(b, x)
         self.assertEqual(code0, code1)

-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
     @config.patch(force_disable_caches=True)
     @skip_if_cpp_wrapper("run_and_get_kernels issue")
     @xfail_if_mps
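For context on the deleted decorators: `cpp_prefix_path` was an `lru_cache`d function whose cached result was cleared on fresh-cache runs, so these determinism tests (which set `force_disable_caches=True`) pinned its `cache_clear` to a no-op to keep the generated `#include` path stable between compilations. A minimal sketch of that now-unnecessary pattern, using a hypothetical `header_path` stand-in:

```python
import functools
from unittest.mock import patch

@functools.lru_cache
def header_path() -> str:
    # Hypothetical stand-in for the old cpp_prefix_path(): it wrote the header
    # into the Inductor cache dir and returned a run-dependent filename.
    return "/tmp/torchinductor/ab/cabc123.h"

# Pinning cache_clear keeps repeated compilations seeing one stable path,
# even when cache-clearing machinery runs in between.
with patch.object(header_path, "cache_clear", lambda: None):
    first = header_path()
    header_path.cache_clear()  # no-op under the patch
    assert header_path() == first
# With a fixed #include <torch/csrc/inductor/cpp_prefix.h>, no pinning is needed.
```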
58 changes: 21 additions & 37 deletions torch/_inductor/codecache.py
@@ -706,7 +706,6 @@ def get_code_hash(root: str) -> bytes:
     # a hash representing the state of the source code.
     extra_files = (
         "codegen/aoti_runtime/interface.cpp",
-        "codegen/cpp_prefix.h",
         "script.ld",
     )
     inductor_root = os.path.dirname(__file__)
@@ -1857,6 +1856,12 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
            min_optimize=not config.aot_inductor.package_cpp_only,
            **compile_command,
        )
+        if cpp_prefix := _get_cpp_prefix_header(device_type):
+            kernel_build_options.precompiled_header = _precompile_header(
+                cpp_prefix,
+                cpp_command,
+                **compile_command,
+            )

        wrapper_builder = CppBuilder(
            name=str(wrapper_path_operator.stem),
@@ -2053,37 +2058,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
     return output_so


-# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
-# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock.
-# Cycle goes:
-#  - CppCodeCache.load()
-#  - pick_vec_isa()
-#  - valid_vec_isa_list()
-#  - VecISA.__bool__() <-- takes out a lock
-#  - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock.
-@clear_on_fresh_inductor_cache
-@functools.lru_cache
-def cpp_prefix_path() -> str:
-    path = Path(__file__).parent / "codegen/cpp_prefix.h"
-    with path.open() as f:
-        content = f.read()
-    _, filename = write(
-        content,
-        "h",
-    )
-    return normalize_path_separator(filename)
-
-
-def cpp_prefix() -> str:
-    filename = cpp_prefix_path()
-    if config.is_fbcode():
-        # We need relative paths, since we bundle up
-        # everything that we compile into a folder for remote compilation.
-        return f'#include "{os.path.basename(filename)}"'
-    else:
-        return f'#include "{filename}"'
-
-
 _libgomp: Optional[CDLL] = None
@@ -2206,6 +2180,12 @@ def _get_file_checksum(filename: str) -> str:
     return header_full_path


+def _get_cpp_prefix_header(device: str) -> Optional[str]:
+    if device.startswith("cpu"):
+        return "torch/csrc/inductor/cpp_prefix.h"
+    return None
+
+
 def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str:
     """Given a device type (and optionally whether we're in AOT Inductor mode), returns
     the path to the cpp_wrapper header file to be precompiled."""
@@ -2255,7 +2235,6 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
     def _get_uncompiled_header(cls, device: str) -> str | None:
         """
         Given a device type, returns the path to a CPP header file to be precompiled.
-        Currently, this is only utilized by the cpp_wrapper classes.
         """
         return None
@@ -2475,6 +2454,10 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType:
             spec.loader.exec_module(module)  # type: ignore[union-attr]
         return module

+    @classmethod
+    def _get_uncompiled_header(cls, device: str) -> str | None:
+        return _get_cpp_prefix_header(device)
+
     @classmethod
     def load_pybinding_async(
         cls,
@@ -2593,10 +2576,6 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):

     @classmethod
     def _get_uncompiled_header(cls, device: str) -> str | None:
-        """
-        Given a device type, returns the path to a CPP header file to be precompiled.
-        Currently, this is only utilized by the cpp_wrapper classes.
-        """
         return _get_cpp_wrapper_header(device)
@@ -2951,6 +2930,11 @@ def build_standalone_runtime(cls) -> str:
         cls._standalone_runtime_path = sofile
         return sofile

+    @classmethod
+    def _get_uncompiled_header(cls, device: str) -> str | None:
+        """Header precompiling is currently disabled for halide."""
+        return None
+

 def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
     from torch.utils._filelock import FileLock
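The "then precompile it" half of this PR reuses the existing `_precompile_header` machinery: look up a device-specific prefix header, compile it once, and hand the result to every subsequent kernel build. The underlying mechanics are ordinary compiler precompiled-header support; below is a self-contained sketch of the idea, assuming a GCC-style toolchain (clang is analogous, with `.pch` files). `precompile` and `compile_kernel` are hypothetical helpers for illustration, not PyTorch APIs:

```python
import subprocess

def precompile(header: str, *flags: str) -> str:
    """Parse the header once and cache the compiler's internal state."""
    gch = header + ".gch"
    # GCC treats -x c++-header input specially and emits a .gch file.
    subprocess.check_call(["g++", "-x", "c++-header", *flags, header, "-o", gch])
    return gch

def compile_kernel(src: str, header: str, *flags: str) -> None:
    # -include prepends the header to the translation unit; GCC automatically
    # substitutes header.gch when it exists and the flags match those used to
    # build it (which is why the real code keys the PCH on the compile command).
    subprocess.check_call(["g++", "-include", header, *flags, "-c", src])
```

This is also why `_get_cpp_prefix_header` returns a header only for CPU devices: the generated C++ kernels that actually include `cpp_prefix.h` are CPU kernels, so precompiling it elsewhere would be wasted work.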
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
@@ -23,7 +23,7 @@
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT

 from ..._dynamo.utils import counters
-from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
+from .. import config, cpp_builder, cpu_vec_isa, ir, metrics
 from ..loop_body import LoopBody
 from ..scheduler import (
     BaseSchedulerNode,
@@ -5256,7 +5256,7 @@ def codegen_group(self, name=None) -> str:
         ]
         if enable_kernel_profile:
             code.writelines(["#include <ATen/record_function.h>"])
-        code.writeline(codecache.cpp_prefix())
+        code.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")

         # 2. Function definition
         kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name
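After this change, a generated CPU kernel begins with a stable, machine-independent include instead of a quoted absolute path into the per-run cache directory. Roughly (illustrative only; the profiling include appears only when `enable_kernel_profile` is set):

```python
# What codegen_group() now emits at the top of a CPU kernel source file.
prologue = [
    "#include <ATen/record_function.h>",  # only with kernel profiling enabled
    "#include <torch/csrc/inductor/cpp_prefix.h>",
]
# Before this PR the second line looked like:
#   #include "/tmp/torchinductor_user/xy/cxyabc...h"
# i.e. a per-cache-dir copy of cpp_prefix.h, which varied across runs.
print("\n".join(prologue))
```

The stable include is what makes identical graphs produce byte-identical source, which the `test_deterministic_codegen` tests above now verify without any patching.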
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp_template.py
@@ -10,7 +10,7 @@

 import sympy

-from .. import codecache, config, ir
+from .. import config, ir
 from ..autotune_process import CppBenchmarkRequest, TensorMeta
 from ..utils import IndentedBuffer, Placeholder, unique
 from ..virtualized import V
@@ -122,7 +122,7 @@ def make_kernel_render(

     def header(self) -> IndentedBuffer:
         res = IndentedBuffer()
-        res.writeline(codecache.cpp_prefix())
+        res.writeline("#include <torch/csrc/inductor/cpp_prefix.h>")
         # TODO: add c10::ForcedUnroll test to test_aoti_abi_check
         res.splice("""#include <c10/util/Unroll.h>""")
         res.splice("""#include <torch/csrc/inductor/aoti_torch/c/shim.h>""")
4 changes: 2 additions & 2 deletions torch/_inductor/config.py
@@ -164,7 +164,7 @@ def prologue_fusion_enabled() -> bool:
 # Controls automatic precompiling of common include files for codecache.CppCodeCache
 # (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is
 # controlled by a separate flag.
-cpp_cache_precompile_headers: bool = True
+cpp_cache_precompile_headers: bool = not is_fbcode()

 online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1"

@@ -1293,7 +1293,7 @@ class aot_inductor:
     package_constants_in_so: bool = True

     # Experimental. Controls automatic precompiling of common AOTI include files.
-    precompile_headers: bool = False
+    precompile_headers: bool = not is_fbcode()


 class cuda:
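Both knobs now default to on everywhere except fbcode builds, presumably because those builds compile remotely (see `build_fbcode_re` below) and a locally generated precompiled header would not transfer. They remain ordinary Inductor config fields, so they can be overridden per run; a usage sketch, assuming the flag names from this PR and the standard `config.patch` helper:

```python
import torch
from torch._inductor import config

# Temporarily disable JIT-Inductor header precompilation, e.g. to bisect a
# toolchain issue:
with config.patch(cpp_cache_precompile_headers=False):
    fn = torch.compile(lambda x: x.relu() + 1)
    fn(torch.randn(8))

# The AOTI knob lives under the aot_inductor namespace:
with config.patch({"aot_inductor.precompile_headers": False}):
    ...  # export/compile an AOTInductor model here
```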
41 changes: 11 additions & 30 deletions torch/_inductor/cpp_builder.py
@@ -586,28 +586,24 @@ def _get_optimization_cflags(
     return cflags


-def _get_shared_cflag(compile_only: bool) -> list[str]:
+def _get_shared_cflag(do_link: bool) -> list[str]:
     if _IS_WINDOWS:
         """
         MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
         https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170
         """
-        SHARED_FLAG = ["DLL", "MD"]
-    else:
-        if compile_only:
-            return ["fPIC"]
-        if platform.system() == "Darwin" and "clang" in get_cpp_compiler():
-            # This causes undefined symbols to behave the same as linux
-            return ["shared", "fPIC", "undefined dynamic_lookup"]
-        else:
-            return ["shared", "fPIC"]
-
-    return SHARED_FLAG
+        return ["DLL", "MD"]
+    if not do_link:
+        return ["fPIC"]
+    if platform.system() == "Darwin" and "clang" in get_cpp_compiler():
+        # This causes undefined symbols to behave the same as linux
+        return ["shared", "fPIC", "undefined dynamic_lookup"]
+    return ["shared", "fPIC"]


 def get_cpp_options(
     cpp_compiler: str,
-    compile_only: bool,
+    do_link: bool,
     warning_all: bool = True,
     extra_flags: Sequence[str] = (),
     min_optimize: bool = False,
@@ -621,7 +617,7 @@ def get_cpp_options(
     passthrough_args: list[str] = []

     cflags = (
-        _get_shared_cflag(compile_only)
+        _get_shared_cflag(do_link)
         + _get_optimization_cflags(cpp_compiler, min_optimize)
         + _get_warning_all_cflag(warning_all)
         + _get_cpp_std_cflag()
@@ -681,7 +677,7 @@ def __init__(
             passthrough_args,
         ) = get_cpp_options(
             cpp_compiler=self._compiler,
-            compile_only=compile_only,
+            do_link=not (compile_only or precompiling or preprocessing),
             extra_flags=extra_flags,
             warning_all=warning_all,
             min_optimize=min_optimize,
@@ -1041,7 +1037,6 @@ def get_cpp_torch_options(
     vec_isa: VecISA,
     include_pytorch: bool,
     aot_mode: bool,
-    compile_only: bool,
     use_relative_path: bool,
     use_mmap_weights: bool,
 ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
@@ -1175,7 +1170,6 @@ def __init__(
             vec_isa=vec_isa,
             include_pytorch=include_pytorch,
             aot_mode=aot_mode,
-            compile_only=compile_only,
             use_relative_path=use_relative_path,
             use_mmap_weights=use_mmap_weights,
         )
@@ -1276,13 +1270,6 @@ def get_cpp_torch_device_options(
             "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support."
         )

-    if aot_mode:
-        if config.is_fbcode():
-            from torch._inductor.codecache import cpp_prefix_path
-
-            cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"]
-            include_dirs += cpp_prefix_include_dir
-
     if config.is_fbcode():
         include_dirs.append(build_paths.sdk_include)

@@ -1633,22 +1620,16 @@ def get_target_file_path(self) -> str:
     def build_fbcode_re(
         self,
     ) -> None:
-        from torch._inductor.codecache import cpp_prefix_path
-
         with dynamo_timed("compile_file"):
             command = self.get_command_line().split()
             try:
-                # Need to copy our header into the same folder as the sourcecode.
-                header_path = cpp_prefix_path()
-                header_name = os.path.basename(header_path)
                 output_path = self._target_file
                 # When we build remotely, we need to make sure to carefully copy any files
                 # that are required during the compilation process into our build directly.
                 # This is where all of the ATen/c10/Torch includes come from.
                 torch_includes_path = os.path.join(_TORCH_PATH, "include")
                 with tempfile.TemporaryDirectory() as tmp_dir:
                     # Copy everything to tmp compilation folder
-                    shutil.copy(header_path, os.path.join(tmp_dir, header_name))
                     shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld"))
                     for src in self._orig_source_paths:
                         shutil.copy(src, os.path.join(tmp_dir, os.path.basename(src)))
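The `compile_only` to `do_link` rename makes the flag describe the actual decision: shared-library flags only make sense when the build links, and of the builder's modes (full build, compile-only, header precompiling, preprocessing) only a full build does. A sketch of the derivation used at the `CppBuilder.__init__` call site above (`shared_cflags` is a simplified stand-in mirroring `_get_shared_cflag`, Windows branch omitted):

```python
import platform

def shared_cflags(do_link: bool) -> list[str]:
    # Stand-in for _get_shared_cflag() after this PR; the builder adds the
    # leading "-" (or "/" on MSVC) when assembling the command line.
    if not do_link:
        return ["fPIC"]  # objects, PCHs, and preprocessed output: no linking
    if platform.system() == "Darwin":
        # Undefined symbols resolve at load time, matching Linux behavior.
        return ["shared", "fPIC", "undefined dynamic_lookup"]
    return ["shared", "fPIC"]

# Only a plain build links; the three non-linking modes all skip shared flags.
for compile_only, precompiling, preprocessing, expect in [
    (False, False, False, True),   # full build -> link
    (True, False, False, False),   # .o only
    (False, True, False, False),   # precompiled header
    (False, False, True, False),   # preprocess only
]:
    do_link = not (compile_only or precompiling or preprocessing)
    assert do_link is expect
```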
16 changes: 1 addition & 15 deletions torch/_inductor/output_code.py
@@ -25,9 +25,7 @@
 import dataclasses
 import logging
 import os
-import re
 from functools import partial
-from pathlib import Path
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias

@@ -672,26 +670,14 @@ def prepare_for_serialization(self) -> None:

     def write_to_disk(self) -> str:
         from torch._dynamo.utils import counters
-        from torch._inductor.codecache import cpp_prefix_path, get_path, write_atomic
+        from torch._inductor.codecache import get_path, write_atomic

         # See _save_graph(); we don't store the callable in the cache entry so
         # recreate it here from the PyCodeCache disk cache.
         artifact_path = get_path(self.cache_key, "py")[2]
         code = self.source_code
         if not os.path.exists(artifact_path):
             counters["inductor"]["fxgraph_lookup_write_file"] += 1
-            Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True)
-            cpp_pp = cpp_prefix_path()
-            if os.path.basename(cpp_pp) in code:
-                if cpp_pp in code:
-                    # Great the name is correct
-                    pass
-                else:
-                    # Old dir name is included, replace it
-                    pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"'
-                    code = re.sub(pattern, f'#include "{cpp_pp}"', code)
-            self.source_code = code
-
             write_atomic(artifact_path, code, make_dirs=True)
         return artifact_path
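This deletion is the direct payoff of the stable include path: cached generated source used to embed the absolute path of that run's copy of `cpp_prefix.h`, so a cache hit written by a different run had to regex-rewrite the `#include` before reuse. With a fixed `#include <torch/csrc/inductor/cpp_prefix.h>`, cached source is valid as-is. For contrast, a sketch of the removed fix-up (hypothetical paths):

```python
import re

# Old behavior (deleted above): a cache hit could reference another run's
# copy of the prefix header, so the include path was rewritten on load.
code = '#include "/tmp/old-run/ab/cabc123.h"\nvoid kernel() {}'
current_header = "/tmp/new-run/cd/cabc123.h"
basename = "cabc123.h"
if current_header not in code:
    pattern = rf'#include\s*"[^"]+{re.escape(basename)}"'
    code = re.sub(pattern, f'#include "{current_header}"', code)
assert current_header in code  # now points at this run's copy
```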