8000 [not for merge] cpp_wrapper: test building portions at O1 · pytorch/pytorch@16f2c74 · GitHub
[go: up one dir, main page]

Skip to content

Commit 16f2c74

Browse files
[not for merge] cpp_wrapper: test building portions at O1
ghstack-source-id: 1aba1f4 Pull Request resolved: #148773
1 parent fb4ad06 commit 16f2c74

File tree

3 files changed

+143
-58
lines changed

3 files changed

+143
-58
lines changed

torch/_inductor/codecache.py

Lines changed: 114 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
126126
from concurrent.futures import Future
127127

128128
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
129+
from .cpp_builder import BuildOptionsBase
129130
from .graph import GraphLowering
130131
from .ir import ChoiceCaller
131132
from .output_code import CompiledFxGraphConstants, OutputCode
@@ -1977,7 +1978,7 @@ def _get_file_checksum(filename: str) -> str:
19771978
os.makedirs(_HEADER_LOCK_DIR, exist_ok=True)
19781979
_worker_compile_cpp(
19791980
os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"),
1980-
cpp_builder,
1981+
(cpp_builder,),
19811982
)
19821983

19831984
return header_full_path
@@ -2044,10 +2045,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
20442045
@classmethod
20452046
def load_async(
20462047
cls,
2047-
source_code: str,
2048+
main_code: str,
20482049
device_type: str = "cpu",
20492050
submit_fn: Any = None,
20502051
extra_flags: Sequence[str] = (),
2052+
optimized_code: Optional[str] = None,
20512053
) -> Any:
20522054
compile_command = {
20532055
**cls.cpp_compile_command_flags,
@@ -2059,46 +2061,112 @@ def load_async(
20592061

20602062
_set_gpu_runtime_env() # cpp_extension consults the env
20612063

2062-
cpp_build_option = CppTorchDeviceOptions(**compile_command)
2063-
command_gen = CppBuilder(name="o", sources="i", BuildOption=cpp_build_option)
2064-
# The write function will calculate the source_code hash; the same source code with different
2065-
# ISA level should be generate different hash.
2066-
# So we need get a command_line which contains isa related parameter as a part of hash key.
2067-
# And then pass the command_line to below write function as extra parameter to
2068-
# guarantee the source code hash contains ISA difference.
2069-
vec_isa_cmd = repr(command_gen.get_command_line())
2070-
key, input_path = write(source_code, "cpp", extra=vec_isa_cmd)
2064+
# Note the distinction between the two booleans. We do minimal optimization if
2065+
# the optimized_code argument is present at all, since that's how the user of
2066+
# this function opts in, but we do compilation and linking in one step if the
2067+
# optimized_code argument is empty (as a micro-optimization).
2068+
main_build_option = CppTorchDeviceOptions(
2069+
compile_only=bool(optimized_code),
2070+
min_optimize=optimized_code is not None,
2071+
**compile_command,
2072+
)
2073+
optimized_build_option = CppTorchDeviceOptions(
2074+
compile_only=True, **compile_command
2075+
)
2076+
2077+
def get_hashable_command_line(build_option: BuildOptionsBase) -> str:
2078+
"""Writing the code to file will calculate a hash, which we need to vary if
2079+
the command line flags change. This implements a mostly-generic way of
2080+
validating that."""
2081+
return CppBuilder(
2082+
name="o", sources="i", BuildOption=build_option
2083+
).get_command_line()
2084+
2085+
main_cmd_line = get_hashable_command_line(main_build_option)
2086+
optimized_cmd_line = get_hashable_command_line(optimized_build_option)
2087+
2088+
key, main_path = write(
2089+
main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}"
2090+
)
2091+
2092+
# Don't bother writing if the argument is empty.
2093+
if optimized_code:
2094+
_, optimized_path = write(
2095+
optimized_code, "optimized.cpp", extra=optimized_cmd_line
2096+
)
2097+
else:
2098+
# Unused, but makes type checkers happy.
2099+
optimized_path = os.devnull
20712100

20722101
if key not in cls.cache:
20732102
from torch.utils._filelock import FileLock
20742103

20752104
lock_path = os.path.join(get_lock_dir(), key + ".lock")
2076-
output_name, output_dir = get_name_and_dir_from_output_file_path(input_path)
20772105
future: Optional[Future[Any]] = None
20782106
lib = None
20792107

20802108
# if requested, pre-compile any headers
2081-
if config.cpp_cache_precompile_headers and (
2082-
header_file := cls._get_uncompiled_header(device_type)
2083-
):
2084-
cpp_build_option.precompiled_header = _precompile_header(
2085-
header_file,
2086-
vec_isa_cmd,
2087-
**compile_command,
2088-
)
2109+
if config.cpp_cache_precompile_headers:
2110+
if header := cls._get_uncompiled_header(device_type):
2111+
main_build_option.precompiled_header = _precompile_header(
2112+
header,
2113+
main_cmd_line,
2114+
min_optimize=optimized_code is not None,
2115+
**compile_command,
2116+
)
20892117

2090-
cpp_builder = CppBuilder(
2091-
name=output_name,
2092-
sources=input_path,
2118+
# Currently, the optimized_code field is only used for cpp kernel code,
2119+
# so go ahead and precompile the relevant header here. Revisit this
2120+
# decision if that ever changes.
2121+
if optimized_code and (header := _get_cpp_prefix_header(device_type)):
2122+
optimized_build_option.precompiled_header = _precompile_header(
2123+
header,
2124+
optimized_cmd_line,
2125+
**compile_command,
2126+
)
2127+
2128+
main_name, output_dir = get_name_and_dir_from_output_file_path(main_path)
2129+
main_builder = CppBuilder(
2130+
name=main_name,
2131+
sources=main_path,
2132+
BuildOption=main_build_option,
20932133
output_dir=output_dir,
2094-
BuildOption=cpp_build_option,
2095-
)
2096-
worker_fn = functools.partial(
2097-
_worker_compile_cpp,
2098-
lock_path,
2099-
cpp_builder,
21002134
)
2101-
binary_path = normalize_path_separator(cpp_builder.get_target_file_path())
2135+
2136+
if optimized_code:
2137+
optimized_name, _ = get_name_and_dir_from_output_file_path(
2138+
optimized_path
2139+
)
2140+
optimized_builder = CppBuilder(
2141+
name=optimized_name,
2142+
sources=optimized_path,
2143+
BuildOption=optimized_build_option,
2144+
output_dir=output_dir,
2145+
)
2146+
2147+
linker = CppBuilder(
2148+
name=main_name,
2149+
sources=[
2150+
main_builder.get_target_file_path(),
2151+
optimized_builder.get_target_file_path(),
2152+
],
2153+
BuildOption=CppTorchDeviceOptions(**compile_command),
2154+
output_dir=output_dir,
2155+
)
2156+
2157+
worker_fn = functools.partial(
2158+
_worker_compile_cpp,
2159+
lock_path,
2160+
(main_builder, optimized_builder, linker),
2161+
)
2162+
binary_path = normalize_path_separator(linker.get_target_file_path())
2163+
else:
2164+
worker_fn = functools.partial(
2165+
_worker_compile_cpp, lock_path, (main_builder,)
2166+
)
2167+
binary_path = normalize_path_separator(
2168+
main_builder.get_target_file_path()
2169+
)
21022170

21032171
def load_fn() -> Any:
21042172
nonlocal lib
@@ -2121,19 +2189,20 @@ def load_fn() -> Any:
21212189
return cls.cache[key]
21222190

21232191
@classmethod
2124-
def load(cls, source_code: str, device_type: str = "cpu") -> Any:
2125-
return cls.load_async(source_code, device_type)()
2192+
def load(cls, *args: Any, **kwargs: Any) -> Any:
2193+
return cls.load_async(*args, **kwargs)()
21262194

21272195

21282196
def _worker_compile_cpp(
21292197
lock_path: str,
2130-
cpp_builder: CppBuilder,
2198+
cpp_builders: Sequence[CppBuilder],
21312199
) -> None:
21322200
from torch.utils._filelock import FileLock
21332201

21342202
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
2135-
if not os.path.exists(cpp_builder.get_target_file_path()):
2136-
cpp_builder.build()
2203+
for builder in cpp_builders:
2204+
if not os.path.exists(builder.get_target_file_path()):
2205+
builder.build()
21372206

21382207

21392208
# Customized Python binding for cpp kernels
@@ -2262,19 +2331,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
22622331
@classmethod
22632332
def load_pybinding_async(
22642333
cls,
2265-
argtypes: list[str],
2266-
source_code: str,
2334+
argtypes: Sequence[str],
2335+
main_code: str,
22672336
device_type: str = "cpu",
22682337
num_outputs: int = -1,
22692338
submit_fn: Any = None,
22702339
extra_flags: Sequence[str] = (),
2340+
kernel_code: Optional[str] = None,
22712341
) -> Any:
22722342
"""
22732343
Wrap a C++ function in fast Python bindings.
22742344
22752345
Args:
22762346
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2277-
source_code: C++ source code containing a ENTRY_FUNCTION() function
2347+
main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2348+
-O3 if kernel_code is None (to maximize performance in any kernels that
2349+
are present), or -O1 otherwise (to minimize compile time).
2350+
kernel_code: If present, C++ source code that will be built at -O3 and
2351+
linked to main_code.
22782352
22792353
Returns:
22802354
A python version of ENTRY_FUNCTION()
@@ -2296,10 +2370,11 @@ def load_pybinding_async(
22962370
cls.entry_function,
22972371
)
22982372
get_result = cls.load_async(
2299-
source_code + suffix,
2373+
main_code + suffix,
23002374
device_type,
23012375
submit_fn=submit_fn,
23022376
extra_flags=extra_flags,
2377+
optimized_code=kernel_code,
23032378
)
23042379
result = None
23052380

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,19 +1011,29 @@ def generate_end(self, result):
10111011
result.writeline("} // namespace torch::aot_inductor\n\n\n")
10121012
return
10131013

1014-
# Add any kernel definitions into the wrapped code. We currently only build
1015-
# them in separate files in AOT mode.
1016-
result.splice(self.kernel_declarations.getvalue())
1017-
self.kernel_declarations.clear()
1014+
# Close the wrapper code block, then write any kernel definitions.
1015+
result.splice("'''\n)")
1016+
if self.kernel_declarations:
1017+
result.splice("\nkernel_src = (\nr'''")
1018+
result.splice(self.kernel_declarations.getvalue())
1019+
result.splice("'''\n)")
1020+
else:
1021+
result.splice(
1022+
"""
1023+
kernel_src = ''
1024+
"""
1025+
)
10181026

10191027
# cpp entry function for JIT with cpp wrapper
10201028
result.splice(
10211029
f"""
1022-
'''
1023-
)
1024-
10251030
inductor_entry = CppWrapperCodeCache.load_pybinding(
1026-
["std::vector<AtenTensorHandle>"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)})
1031+
argtypes=["std::vector<AtenTensorHandle>"],
1032+
main_code=cpp_wrapper_src,
1033+
device_type="{self.device}",
1034+
num_outputs={len(V.graph.graph_outputs)},
1035+
kernel_code=kernel_src,
1036+
)
10271037
"""
10281038
)
10291039

torch/_inductor/graph.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,9 +2085,9 @@ def compile_to_module(self) -> ModuleType:
20852085
def _compile_to_module(self) -> ModuleType:
20862086
from .codecache import PyCodeCache
20872087

2088-
# Currently, if we're here, we don't have to worry about the kernel code, which
2089-
# is only available in AOTInductor mode.
2090-
wrapper_code, _ = (
2088+
# If we're here, we don't have to worry about the kernel code, which is only
2089+
# returned separately in AOTInductor mode.
2090+
src_code, _ = (
20912091
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
20922092
)
20932093
if config.triton.autotune_at_compile_time:
@@ -2098,33 +2098,33 @@ def _compile_to_module(self) -> ModuleType:
20982098
+ self.wrapper_code.kernel_autotune_calls.getvalue()
20992099
+ '"""\n'
21002100
)
2101-
wrapper_code.value = tuning_code + wrapper_code.value
2101+
src_code.value = tuning_code + src_code.value
21022102
if GraphLowering.save_output_code is not None:
2103-
GraphLowering.save_output_code(wrapper_code.value)
2104-
output_code_log.debug("Output code: \n%s", wrapper_code.value)
2103+
GraphLowering.save_output_code(src_code.value)
2104+
output_code_log.debug("Output code: \n%s", src_code.value)
21052105

21062106
inductor_meta = autotune_cache.inductor_meta_from_config()
2107-
AutotuneCacheBundler.begin_compile(inductor_meta, code=wrapper_code.value)
2107+
AutotuneCacheBundler.begin_compile(inductor_meta, code=src_code.value)
21082108

21092109
try:
21102110
linemap = [
21112111
(line_no, node.stack_trace) # type: ignore[attr-defined]
2112-
for line_no, node in wrapper_code.line_map
2112+
for line_no, node in src_code.line_map
21132113
]
2114-
key, path = PyCodeCache.write(wrapper_code.value)
2114+
key, path = PyCodeCache.write(src_code.value)
21152115
output_code_log.debug("Output code written to: %s", path)
21162116
except Exception:
21172117
trace_structured(
21182118
"inductor_output_code",
21192119
# Just omit the filename, I still want the code though!
2120-
payload_fn=lambda: wrapper_code.value,
2120+
payload_fn=lambda: src_code.value,
21212121
)
21222122
raise
21232123
else:
21242124
trace_structured(
21252125
"inductor_output_code",
21262126
lambda: {"filename": path},
2127-
payload_fn=lambda: wrapper_code.value,
2127+
payload_fn=lambda: src_code.value,
21282128
)
21292129
with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True):
21302130
mod = PyCodeCache.load_by_key_path(

0 commit comments

Comments
 (0)
0