cpp_wrapper: build non-performance-sensitive code at O1 · pytorch/pytorch@f44c163 · GitHub
[go: up one dir, main page]

Skip to content

Commit f44c163

Browse files
cpp_wrapper: build non-performance-sensitive code at O1
Builds on #148212, applying the same improvements to `cpp_wrapper` mode. ghstack-source-id: df16e3d Pull Request resolved: #148773
1 parent 53c1d92 commit f44c163

File tree

3 files changed

+134
-51
lines changed

3 files changed

+134
-51
lines changed

torch/_inductor/codecache.py

Lines changed: 114 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
125125
from concurrent.futures import Future
126126

127127
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
128+
from .cpp_builder import BuildOptionsBase
128129
from .graph import GraphLowering
129130
from .ir import ChoiceCaller
130131
from .output_code import CompiledFxGraphConstants, OutputCode
@@ -2014,7 +2015,7 @@ def _get_file_checksum(filename: str) -> str:
20142015
os.makedirs(_HEADER_LOCK_DIR, exist_ok=True)
20152016
_worker_compile_cpp(
20162017
os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"),
2017-
cpp_builder,
2018+
(cpp_builder,),
20182019
)
20192020

20202021
return header_full_path
@@ -2081,10 +2082,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
20812082
@classmethod
20822083
def load_async(
20832084
cls,
2084-
source_code: str,
2085+
main_code: str,
20852086
device_type: str = "cpu",
20862087
submit_fn: Any = None,
20872088
extra_flags: Sequence[str] = (),
2089+
optimized_code: Optional[str] = None,
20882090
) -> Any:
20892091
compile_command = {
20902092
**cls.cpp_compile_command_flags,
@@ -2096,48 +2098,112 @@ def load_async(
20962098

20972099
_set_gpu_runtime_env() # cpp_extension consults the env
20982100

2099-
cpp_build_option = CppTorchDeviceOptions(**compile_command)
2100-
command_gen = CppBuilder(name="o", sources="i", BuildOption=cpp_build_option)
2101-
# write function will calc source_code hash, the same source code with different
2102-
# ISA level should be generate different hash.
2103-
# So we need get a command_line which contains isa related parameter as a part of hash key.
2104-
# And then pass the command_line to below write function as extra parameter to
2105-
# guarantee the source code hash contains ISA difference.
2106-
vec_isa_cmd = repr(command_gen.get_command_line())
2107-
key, input_path = write(source_code, "cpp", extra=vec_isa_cmd)
2101+
# Note the distinction between the two booleans. We do minimal optimization if
2102+
# the optimized_code argument is present at all, since that's how the user of
2103+
# this function opts in, but we do compilation and linking in one step if the
2104+
# optimized_code argument is empty (as a micro-optimization).
2105+
main_build_option = CppTorchDeviceOptions(
2106+
compile_only=bool(optimized_code),
2107+
min_optimize=optimized_code is not None,
2108+
**compile_command,
2109+
)
2110+
optimized_build_option = CppTorchDeviceOptions(
2111+
compile_only=True, **compile_command
2112+
)
2113+
2114+
def get_hashable_command_line(build_option: BuildOptionsBase) -> str:
2115+
"""Writing the code to file will calculate a hash, which we need to vary if
2116+
the command line flags change. This implements a mostly-generic way of
2117+
validating that."""
2118+
return CppBuilder(
2119+
name="o", sources="i", BuildOption=build_option
2120+
).get_command_line()
2121+
2122+
main_cmd_line = get_hashable_command_line(main_build_option)
2123+
optimized_cmd_line = get_hashable_command_line(optimized_build_option)
2124+
2125+
key, main_path = write(
2126+
main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}"
2127+
)
2128+
2129+
# Don't bother writing if the argument is empty.
2130+
if optimized_code:
2131+
_, optimized_path = write(
2132+
optimized_code, "optimized.cpp", extra=optimized_cmd_line
2133+
)
2134+
else:
2135+
# Unused, but makes type checkers happy.
2136+
optimized_path = os.devnull
21082137

21092138
if key not in cls.cache:
21102139
from torch.utils._filelock import FileLock
21112140

21122141
lock_path = os.path.join(get_lock_dir(), key + ".lock")
2113-
output_name, output_dir = get_name_and_dir_from_output_file_path(input_path)
21142142
future: Optional[Future[Any]] = None
21152143
lib = None
21162144

21172145
# if requested, pre-compile any headers
2118-
if (
2119-
config.cpp_cache_precompile_headers
2120-
and not _IS_WINDOWS
2121-
and (header_file := cls._get_uncompiled_header(device_type))
2122-
):
2123-
cpp_build_option.precompiled_header = _precompile_header(
2124-
header_file,
2125-
vec_isa_cmd,
2126-
**compile_command,
2127-
)
2146+
if config.cpp_cache_precompile_headers and not _IS_WINDOWS:
2147+
if header := cls._get_uncompiled_header(device_type):
2148+
main_build_option.precompiled_header = _precompile_header(
2149+
header,
2150+
main_cmd_line,
2151+
min_optimize=optimized_code is not None,
2152+
**compile_command,
2153+
)
21282154

2129-
cpp_builder = CppBuilder(
2130-
name=output_name,
2131-
sources=input_path,
2155+
# Currently, the optimized_code field is only used for cpp kernel code,
2156+
# so go ahead and precompile the relevant header here. Revisit this
2157+
# decision if that ever changes.
2158+
if optimized_code and (header := _get_cpp_prefix_header(device_type)):
2159+
optimized_build_option.precompiled_header = _precompile_header(
2160+
header,
2161+
optimized_cmd_line,
2162+
**compile_command,
2163+
)
2164+
2165+
main_name, output_dir = get_name_and_dir_from_output_file_path(main_path)
2166+
main_builder = CppBuilder(
2167+
name=main_name,
2168+
sources=main_path,
2169+
BuildOption=main_build_option,
21322170
output_dir=output_dir,
2133-
BuildOption=cpp_build_option,
21342171
)
2135-
worker_fn = functools.partial(
2136-
_worker_compile_cpp,
2137-
lock_path,
2138-
cpp_builder,
2139-
)
2140-
binary_path = normalize_path_separator(cpp_builder.get_target_file_path())
2172+
2173+
if optimized_code:
2174+
optimized_name, _ = get_name_and_dir_from_output_file_path(
2175+
optimized_path
2176+
)
2177+
optimized_builder = CppBuilder(
2178+
name=optimized_name,
2179+
sources=optimized_path,
2180+
BuildOption=optimized_build_option,
2181+
output_dir=output_dir,
2182+
)
2183+
2184+
linker = CppBuilder(
2185+
name=main_name,
2186+
sources=[
2187+
main_builder.get_target_file_path(),
2188+
optimized_builder.get_target_file_path(),
2189+
],
2190+
BuildOption=CppTorchDeviceOptions(**compile_command),
2191+
output_dir=output_dir,
2192+
)
2193+
2194+
worker_fn = functools.partial(
2195+
_worker_compile_cpp,
2196+
lock_path,
2197+
(main_builder, optimized_builder, linker),
2198+
)
2199+
binary_path = normalize_path_separator(linker.get_target_file_path())
2200+
else:
2201+
worker_fn = functools.partial(
2202+
_worker_compile_cpp, lock_path, (main_builder,)
2203+
)
2204+
binary_path = normalize_path_separator(
2205+
main_builder.get_target_file_path()
2206+
)
21412207

21422208
def load_fn() -> Any:
21432209
nonlocal lib
@@ -2160,19 +2226,20 @@ def load_fn() -> Any:
21602226
return cls.cache[key]
21612227

21622228
@classmethod
2163-
def load(cls, source_code: str, device_type: str = "cpu") -> Any:
2164-
return cls.load_async(source_code, device_type)()
2229+
def load(cls, *args: Any, **kwargs: Any) -> Any:
2230+
return cls.load_async(*args, **kwargs)()
21652231

21662232

21672233
def _worker_compile_cpp(
21682234
lock_path: str,
2169-
cpp_builder: CppBuilder,
2235+
cpp_builders: Sequence[CppBuilder],
21702236
) -> None:
21712237
from torch.utils._filelock import FileLock
21722238

21732239
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
2174-
if not os.path.exists(cpp_builder.get_target_file_path()):
2175-
cpp_builder.build()
2240+
for builder in cpp_builders:
2241+
if not os.path.exists(builder.get_target_file_path()):
2242+
builder.build()
21762243

21772244

21782245
# Customized Python binding for cpp kernels
@@ -2301,19 +2368,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
23012368
@classmethod
23022369
def load_pybinding_async(
23032370
cls,
2304-
argtypes: list[str],
2305-
source_code: str,
2371+
argtypes: Sequence[str],
2372+
main_code: str,
23062373
device_type: str = "cpu",
23072374
num_outputs: int = -1,
23082375
submit_fn: Any = None,
23092376
extra_flags: Sequence[str] = (),
2377+
kernel_code: Optional[str] = None,
23102378
) -> Any:
23112379
"""
23122380
Wrap a C++ function in fast Python bindings.
23132381
23142382
Args:
23152383
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2316-
source_code: C++ source code containing a ENTRY_FUNCTION() function
2384+
main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2385+
-O3 if kernel_code is None (to maximize performance in any kernels that
2386+
are present), or -O1 otherwise (to minimize compile time).
2387+
kernel_code: If present, C++ source code that will be built at -O3 and
2388+
linked to main_code.
23172389
23182390
Returns:
23192391
A python version of ENTRY_FUNCTION()
@@ -2329,10 +2401,11 @@ def load_pybinding_async(
23292401
extra_parse_arg=cls.extra_parse_arg.format(array_len=num_outputs),
23302402
)
23312403
get_result = cls.load_async(
2332-
source_code + suffix,
2404+
main_code + suffix,
23332405
device_type,
23342406
submit_fn=submit_fn,
23352407
extra_flags=extra_flags,
2408+
optimized_code=kernel_code,
23362409
)
23372410
result = None
23382411

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,19 +1003,29 @@ def generate_end(self, result):
10031003
result.writeline("} // namespace torch::aot_inductor\n\n\n")
10041004
return
10051005

1006-
# Add any kernel definitions into the wrapped code. We currently only build
1007-
# them in separate files in AOT mode.
1008-
result.splice(self.kernel_declarations.getvalue())
1009-
self.kernel_declarations.clear()
1006+
# Close the wrapper code block, then write any kernel definitions.
1007+
result.splice("'''\n)")
1008+
if self.kernel_declarations:
1009+
result.splice("\nkernel_src = (\nr'''")
1010+
result.splice(self.kernel_declarations.getvalue())
1011+
result.splice("'''\n)")
1012+
else:
1013+
result.splice(
1014+
"""
1015+
kernel_src = ''
1016+
"""
1017+
)
10101018

10111019
# cpp entry function for JIT with cpp wrapper
10121020
result.splice(
10131021
f"""
1014-
'''
1015-
)
1016-
10171022
inductor_entry = CppWrapperCodeCache.load_pybinding(
1018-
["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)})
1023+
argtypes=["std::vector<AtenTensorHandle>"],
1024+
main_code=cpp_wrapper_src,
1025+
device_type="{self.device}",
1026+
num_outputs={len(V.graph.graph_outputs)},
1027+
kernel_code=kernel_src,
1028+
)
10191029
"""
10201030
)
10211031

torch/_inductor/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,8 +2201,8 @@ def compile_to_module(self) -> ModuleType:
22012201
def _compile_to_module(self) -> ModuleType:
22022202
from .codecache import PyCodeCache
22032203

2204-
# Currently, if we're here, we don't have to worry about the kernel code, which
2205-
# is only available in AOTInductor mode.
2204+
# If we're here, we don't have to worry about the kernel code, which is only
2205+
# returned separately in AOTInductor mode.
22062206
wrapper_code, _ = (
22072207
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
22082208
)

0 commit comments

Comments (0)