cpp_wrapper: build non-performance-sensitive code at O1 · pytorch/pytorch@0bbbd2c · GitHub
[go: up one dir, main page]

Skip to content

Commit 0bbbd2c

Browse files
cpp_wrapper: build non-performance-sensitive code at O1
Builds on #148212, applying the same improvements to `cpp_wrapper` mode. ghstack-source-id: bdba464 Pull Request resolved: #148773
1 parent 0df77fa commit 0bbbd2c

File tree

3 files changed

+134
-51
lines changed

3 files changed

+134
-51
lines changed

torch/_inductor/codecache.py

Lines changed: 114 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
125125
from concurrent.futures import Future
126126

127127
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
128+
from .cpp_builder import BuildOptionsBase
128129
from .graph import GraphLowering
129130
from .ir import ChoiceCaller
130131
from .output_code import CompiledFxGraphConstants, OutputCode
@@ -2018,7 +2019,7 @@ def _get_file_checksum(filename: str) -> str:
20182019
os.makedirs(_HEADER_LOCK_DIR, exist_ok=True)
20192020
_worker_compile_cpp(
20202021
os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"),
2021-
cpp_builder,
2022+
(cpp_builder,),
20222023
)
20232024

20242025
return header_full_path
@@ -2085,10 +2086,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
20852086
@classmethod
20862087
def load_async(
20872088
cls,
2088-
source_code: str,
2089+
main_code: str,
20892090
device_type: str = "cpu",
20902091
submit_fn: Any = None,
20912092
extra_flags: Sequence[str] = (),
2093+
optimized_code: Optional[str] = None,
20922094
) -> Any:
20932095
compile_command = {
20942096
**cls.cpp_compile_command_flags,
@@ -2100,48 +2102,112 @@ def load_async(
21002102

21012103
_set_gpu_runtime_env() # cpp_extension consults the env
21022104

2103-
cpp_build_option = CppTorchDeviceOptions(**compile_command)
2104-
command_gen = CppBuilder(name="o", sources="i", BuildOption=cpp_build_option)
2105-
# write function will calc source_code hash, the same source code with different
2106-
# ISA level should be generate different hash.
2107-
# So we need get a command_line which contains isa related parameter as a part of hash key.
2108-
# And then pass the command_line to below write function as extra parameter to
2109-
# guarantee the source code hash contains ISA difference.
2110-
vec_isa_cmd = repr(command_gen.get_command_line())
2111-
key, input_path = write(source_code, "cpp", extra=vec_isa_cmd)
2105+
# Note the distinction between the two booleans. We do minimal optimization if
2106+
# the optimized_code argument is present at all, since that's how the user of
2107+
# this function opts in, but we do compilation and linking in one step if the
2108+
# optimized_code argument is empty (as a micro-optimization).
2109+
main_build_option = CppTorchDeviceOptions(
2110+
compile_only=bool(optimized_code),
2111+
min_optimize=optimized_code is not None,
2112+
**compile_command,
2113+
)
2114+
optimized_build_option = CppTorchDeviceOptions(
2115+
compile_only=True, **compile_command
2116+
)
2117+
2118+
def get_hashable_command_line(build_option: BuildOptionsBase) -> str:
2119+
"""Writing the code to file will calculate a hash, which we need to vary if
2120+
the command line flags change. This implements a mostly-generic way of
2121+
validating that."""
2122+
return CppBuilder(
2123+
name="o", sources="i", BuildOption=build_option
2124+
).get_command_line()
2125+
2126+
main_cmd_line = get_hashable_command_line(main_build_option)
2127+
optimized_cmd_line = get_hashable_command_line(optimized_build_option)
2128+
2129+
key, main_path = write(
2130+
main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}"
2131+
)
2132+
2133+
# Don't bother writing if the argument is empty.
2134+
if optimized_code:
2135+
_, optimized_path = write(
2136+
optimized_code, "optimized.cpp", extra=optimized_cmd_line
2137+
)
2138+
else:
2139+
# Unused, but makes type checkers happy.
2140+
optimized_path = os.devnull
21122141

21132142
if key not in cls.cache:
21142143
from torch.utils._filelock import FileLock
21152144

21162145
lock_path = os.path.join(get_lock_dir(), key + ".lock")
2117-
output_name, output_dir = get_name_and_dir_from_output_file_path(input_path)
21182146
future: Optional[Future[Any]] = None
21192147
lib = None
21202148

21212149
# if requested, pre-compile any headers
2122-
if (
2123-
config.cpp_cache_precompile_headers
2124-
and not _IS_WINDOWS
2125-
and (header_file := cls._get_uncompiled_header(device_type))
2126-
):
2127-
cpp_build_option.precompiled_header = _precompile_header(
2128-
header_file,
2129-
vec_isa_cmd,
2130-
**compile_command,
2131-
)
2150+
if config.cpp_cache_precompile_headers and not _IS_WINDOWS:
2151+
if header := cls._get_uncompiled_header(device_type):
2152+
main_build_option.precompiled_header = _precompile_header(
2153+
header,
2154+
main_cmd_line,
2155+
min_optimize=optimized_code is not None,
2156+
**compile_command,
2157+
)
21322158

2133-
cpp_builder = CppBuilder(
2134-
name=output_name,
2135-
sources=input_path,
2159+
# Currently, the optimized_code field is only used for cpp kernel code,
2160+
# so go ahead and precompile the relevant header here. Revisit this
2161+
# decision if that ever changes.
2162+
if optimized_code and (header := _get_cpp_prefix_header(device_type)):
2163+
optimized_build_option.precompiled_header = _precompile_header(
2164+
header,
2165+
optimized_cmd_line,
2166+
**compile_command,
2167+
)
2168+
2169+
main_name, output_dir = get_name_and_dir_from_output_file_path(main_path)
2170+
main_builder = CppBuilder(
2171+
name=main_name,
2172+
sources=main_path,
2173+
BuildOption=main_build_option,
21362174
output_dir=output_dir,
2137-
BuildOption=cpp_build_option,
21382175
)
2139-
worker_fn = functools.partial(
2140-
_worker_compile_cpp,
2141-
lock_path,
2142-
cpp_builder,
2143-
)
2144-
binary_path = normalize_path_separator(cpp_builder.get_target_file_path())
2176+
2177+
if optimized_code:
2178+
optimized_name, _ = get_name_and_dir_from_output_file_path(
2179+
optimized_path
2180+
)
2181+
optimized_builder = CppBuilder(
2182+
name=optimized_name,
2183+
sources=optimized_path,
2184+
BuildOption=optimized_build_option,
2185+
output_dir=output_dir,
2186+
)
2187+
2188+
linker = CppBuilder(
2189+
name=main_name,
2190+
sources=[
2191+
main_builder.get_target_file_path(),
2192+
optimized_builder.get_target_file_path(),
2193+
],
2194+
BuildOption=CppTorchDeviceOptions(**compile_command),
2195+
output_dir=output_dir,
2196+
)
2197+
2198+
worker_fn = functools.partial(
2199+
_worker_compile_cpp,
2200+
lock_path,
2201+
(main_builder, optimized_builder, linker),
2202+
)
2203+
binary_path = normalize_path_separator(linker.get_target_file_path())
2204+
else:
2205+
worker_fn = functools.partial(
2206+
_worker_compile_cpp, lock_path, (main_builder,)
2207+
)
2208+
binary_path = normalize_path_separator(
2209+
main_builder.get_target_file_path()
2210+
)
21452211

21462212
def load_fn() -> Any:
21472213
nonlocal lib
@@ -2164,19 +2230,20 @@ def load_fn() -> Any:
21642230
return cls.cache[key]
21652231

21662232
@classmethod
2167-
def load(cls, source_code: str, device_type: str = "cpu") -> Any:
2168-
return cls.load_async(source_code, device_type)()
2233+
def load(cls, *args: Any, **kwargs: Any) -> Any:
2234+
return cls.load_async(*args, **kwargs)()
21692235

21702236

21712237
def _worker_compile_cpp(
21722238
lock_path: str,
2173-
cpp_builder: CppBuilder,
2239+
cpp_builders: Sequence[CppBuilder],
21742240
) -> None:
21752241
from torch.utils._filelock import FileLock
21762242

21772243
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
2178-
if not os.path.exists(cpp_builder.get_target_file_path()):
2179-
cpp_builder.build()
2244+
for builder in cpp_builders:
2245+
if not os.path.exists(builder.get_target_file_path()):
2246+
builder.build()
21802247

21812248

21822249
# Customized Python binding for cpp kernels
@@ -2305,19 +2372,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
23052372
@classmethod
23062373
def load_pybinding_async(
23072374
cls,
2308-
argtypes: list[str],
2309-
source_code: str,
2375+
argtypes: Sequence[str],
2376+
main_code: str,
23102377
device_type: str = "cpu",
23112378
num_outputs: int = -1,
23122379
submit_fn: Any = None,
23132380
extra_flags: Sequence[str] = (),
2381+
kernel_code: Optional[str] = None,
23142382
) -> Any:
23152383
"""
23162384
Wrap a C++ function in fast Python bindings.
23172385
23182386
Args:
23192387
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2320-
source_code: C++ source code containing a ENTRY_FUNCTION() function
2388+
main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2389+
-O3 if kernel_code is None (to maximize performance in any kernels that
2390+
are present), or -O1 otherwise (to minimize compile time).
2391+
kernel_code: If present, C++ source code that will be built at -O3 and
2392+
linked to main_code.
23212393
23222394
Returns:
23232395
A python version of ENTRY_FUNCTION()
@@ -2333,10 +2405,11 @@ def load_pybinding_async(
23332405
extra_parse_arg=cls.extra_parse_arg.format(array_len=num_outputs),
23342406
)
23352407
get_result = cls.load_async(
2336-
source_code + suffix,
2408+
main_code + suffix,
23372409
device_type,
23382410
submit_fn=submit_fn,
23392411
extra_flags=extra_flags,
2412+
optimized_code=kernel_code,
23402413
)
23412414
result = None
23422415

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,19 +1013,29 @@ def generate_end(self, result):
10131013
result.writeline("} // namespace torch::aot_inductor\n\n\n")
10141014
return
10151015

1016-
# Add any kernel definitions into the wrapped code. We currently only build
1017-
# them in separate files in AOT mode.
1018-
result.splice(self.kernel_declarations.getvalue())
1019-
self.kernel_declarations.clear()
1016+
# Close the wrapper code block, then write any kernel definitions.
1017+
result.splice("'''\n)")
1018+
if self.kernel_declarations:
1019+
result.splice("\nkernel_src = (\nr'''")
1020+
result.splice(self.kernel_declarations.getvalue())
1021+
result.splice("'''\n)")
1022+
else:
1023+
result.splice(
1024+
"""
1025+
kernel_src = ''
1026+
"""
1027+
)
10201028

10211029
# cpp entry function for JIT with cpp wrapper
10221030
result.splice(
10231031
f"""
1024-
'''
1025-
)
1026-
10271032
inductor_entry = CppWrapperCodeCache.load_pybinding(
1028-
["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)})
1033+
argtypes=["std::vector<AtenTensorHandle>"],
1034+
main_code=cpp_wrapper_src,
1035+
device_type="{self.device}",
1036+
num_outputs={len(V.graph.graph_outputs)},
1037+
kernel_code=kernel_src,
1038+
)
10291039
"""
10301040
)
10311041

torch/_inductor/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,8 +2203,8 @@ def compile_to_module(self) -> ModuleType:
22032203
def _compile_to_module(self) -> ModuleType:
22042204
from .codecache import PyCodeCache
22052205

2206-
# Currently, if we're here, we don't have to worry about the kernel code, which
2207-
# is only available in AOTInductor mode.
2206+
# If we're here, we don't have to worry about the kernel code, which is only
2207+
# returned separately in AOTInductor mode.
22082208
wrapper_code, _ = (
22092209
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
22102210
)

0 commit comments

Comments (0)