@@ -125,6 +125,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
125
125
from concurrent .futures import Future
126
126
127
127
from .compile_fx import _CompileFxKwargs , CompiledFxGraph
128
+ from .cpp_builder import BuildOptionsBase
128
129
from .graph import GraphLowering
129
130
from .ir import ChoiceCaller
130
131
from .output_code import CompiledFxGraphConstants , OutputCode
@@ -2018,7 +2019,7 @@ def _get_file_checksum(filename: str) -> str:
2018
2019
os .makedirs (_HEADER_LOCK_DIR , exist_ok = True )
2019
2020
_worker_compile_cpp (
2020
2021
os .path .join (_HEADER_LOCK_DIR , f"{ header_hash } .lock" ),
2021
- cpp_builder ,
2022
+ ( cpp_builder ,) ,
2022
2023
)
2023
2024
2024
2025
return header_full_path
@@ -2085,10 +2086,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2085
2086
@classmethod
2086
2087
def load_async (
2087
2088
cls ,
2088
- source_code : str ,
2089
+ main_code : str ,
2089
2090
device_type : str = "cpu" ,
2090
2091
submit_fn : Any = None ,
2091
2092
extra_flags : Sequence [str ] = (),
2093
+ optimized_code : Optional [str ] = None ,
2092
2094
) -> Any :
2093
2095
compile_command = {
2094
2096
** cls .cpp_compile_command_flags ,
@@ -2100,48 +2102,112 @@ def load_async(
2100
2102
2101
2103
_set_gpu_runtime_env () # cpp_extension consults the env
2102
2104
2103
- cpp_build_option = CppTorchDeviceOptions (** compile_command )
2104
- command_gen = CppBuilder (name = "o" , sources = "i" , BuildOption = cpp_build_option )
2105
- # write function will calc source_code hash, the same source code with different
2106
- # ISA level should be generate different hash.
2107
- # So we need get a command_line which contains isa related parameter as a part of hash key.
2108
- # And then pass the command_line to below write function as extra parameter to
2109
- # guarantee the source code hash contains ISA difference.
2110
- vec_isa_cmd = repr (command_gen .get_command_line ())
2111
- key , input_path = write (source_code , "cpp" , extra = vec_isa_cmd )
2105
+ # Note the distinction between the two booleans. We do minimal optimization if
2106
+ # the optimized_code argument is present at all, since that's how the user of
2107
+ # this function opts in, but we do compilation and linking in one step if the
2108
+ # optimized_code argument is empty (as a micro-optimization).
2109
+ main_build_option = CppTorchDeviceOptions (
2110
+ compile_only = bool (optimized_code ),
2111
+ min_optimize = optimized_code is not None ,
2112
+ ** compile_command ,
2113
+ )
2114
+ optimized_build_option = CppTorchDeviceOptions (
2115
+ compile_only = True , ** compile_command
2116
+ )
2117
+
2118
+ def get_hashable_command_line (build_option : BuildOptionsBase ) -> str :
2119
+ """Writing the code to file will calculate a hash, which we need to vary if
2120
+ the command line flags change. This implements a mostly-generic way of
2121
+ validating that."""
2122
+ return CppBuilder (
2123
+ name = "o" , sources = "i" , BuildOption = build_option
2124
+ ).get_command_line ()
2125
+
2126
+ main_cmd_line = get_hashable_command_line (main_build_option )
2127
+ optimized_cmd_line = get_hashable_command_line (optimized_build_option )
2128
+
2129
+ key , main_path = write (
2130
+ main_code , "main.cpp" , extra = f"{ optimized_code } { main_cmd_line } "
2131
+ )
2132
+
2133
+ # Don't bother writing if the argument is empty.
2134
+ if optimized_code :
2135
+ _ , optimized_path = write (
2136
+ optimized_code , "optimized.cpp" , extra = optimized_cmd_line
2137
+ )
2138
+ else :
2139
+ # Unused, but makes type checkers happy.
2140
+ optimized_path = os .devnull
2112
2141
2113
2142
if key not in cls .cache :
2114
2143
from torch .utils ._filelock import FileLock
2115
2144
2116
2145
lock_path = os .path .join (get_lock_dir (), key + ".lock" )
2117
- output_name , output_dir = get_name_and_dir_from_output_file_path (input_path )
2118
2146
future : Optional [Future [Any ]] = None
2119
2147
lib = None
2120
2148
2121
2149
# if requested, pre-compile any headers
2122
- if (
2123
- config .cpp_cache_precompile_headers
2124
- and not _IS_WINDOWS
2125
- and (header_file := cls ._get_uncompiled_header (device_type ))
2126
- ):
2127
- cpp_build_option .precompiled_header = _precompile_header (
2128
- header_file ,
2129
- vec_isa_cmd ,
2130
- ** compile_command ,
2131
- )
2150
+ if config .cpp_cache_precompile_headers and not _IS_WINDOWS :
2151
+ if header := cls ._get_uncompiled_header (device_type ):
2152
+ main_build_option .precompiled_header = _precompile_header (
2153
+ header ,
2154
+ main_cmd_line ,
2155
+ min_optimize = optimized_code is not None ,
2156
+ ** compile_command ,
2157
+ )
2132
2158
2133
- cpp_builder = CppBuilder (
2134
- name = output_name ,
2135
- sources = input_path ,
2159
+ # Currently, the optimized_code field is only used for cpp kernel code,
2160
+ # so go ahead and precompile the relevant header here. Revisit this
2161
+ # decision if that ever changes.
2162
+ if optimized_code and (header := _get_cpp_prefix_header (device_type )):
2163
+ optimized_build_option .precompiled_header = _precompile_header (
2164
+ header ,
2165
+ optimized_cmd_line ,
2166
+ ** compile_command ,
2167
+ )
2168
+
2169
+ main_name , output_dir = get_name_and_dir_from_output_file_path (main_path )
2170
+ main_builder = CppBuilder (
2171
+ name = main_name ,
2172
+ sources = main_path ,
2173
+ BuildOption = main_build_option ,
2136
2174
output_dir = output_dir ,
2137
- BuildOption = cpp_build_option ,
2138
2175
)
2139
- worker_fn = functools .partial (
2140
- _worker_compile_cpp ,
2141
- lock_path ,
2142
- cpp_builder ,
2143
- )
2144
- binary_path = normalize_path_separator (cpp_builder .get_target_file_path ())
2176
+
2177
+ if optimized_code :
2178
+ optimized_name , _ = get_name_and_dir_from_output_file_path (
2179
+ optimized_path
2180
+ )
2181
+ optimized_builder = CppBuilder (
2182
+ name = optimized_name ,
2183
+ sources = optimized_path ,
2184
+ BuildOption = optimized_build_option ,
2185
+ output_dir = output_dir ,
2186
+ )
2187
+
2188
+ linker = CppBuilder (
2189
+ name = main_name ,
2190
+ sources = [
2191
+ main_builder .get_target_file_path (),
2192
+ optimized_builder .get_target_file_path (),
2193
+ ],
2194
+ BuildOption = CppTorchDeviceOptions (** compile_command ),
2195
+ output_dir = output_dir ,
2196
+ )
2197
+
2198
+ worker_fn = functools .partial (
2199
+ _worker_compile_cpp ,
2200
+ lock_path ,
2201
+ (main_builder , optimized_builder , linker ),
2202
+ )
2203
+ binary_path = normalize_path_separator (linker .get_target_file_path ())
2204
+ else :
2205
+ worker_fn = functools .partial (
2206
+ _worker_compile_cpp , lock_path , (main_builder ,)
2207
+ )
2208
+ binary_path = normalize_path_separator (
2209
+ main_builder .get_target_file_path ()
2210
+ )
2145
2211
2146
2212
def load_fn () -> Any :
2147
2213
nonlocal lib
@@ -2164,19 +2230,20 @@ def load_fn() -> Any:
2164
2230
return cls .cache [key ]
2165
2231
2166
2232
@classmethod
def load(cls, *args: Any, **kwargs: Any) -> Any:
    """Synchronous counterpart of ``load_async``.

    Every positional and keyword argument is forwarded unchanged to
    ``cls.load_async``; see that method for the accepted parameters.
    ``load_async`` hands back a callable rather than the result itself, so
    that callable is invoked right away and its return value is returned.
    """
    deferred_load = cls.load_async(*args, **kwargs)
    return deferred_load()
2169
2235
2170
2236
2171
2237
def _worker_compile_cpp(
    lock_path: str,
    cpp_builders: Sequence[CppBuilder],
) -> None:
    """Run a sequence of C++ build steps while holding a file lock.

    Args:
        lock_path: Path of the lock file; it is held (with ``LOCK_TIMEOUT``)
            for the duration of all build steps.
        cpp_builders: Build steps, executed in the order given. A step is
            skipped when its target file already exists on disk.

    Order matters: ``load_async`` passes ``(main_builder, optimized_builder,
    linker)``, where the linker's sources are the target files of the two
    preceding builders, so those must be produced first.
    """
    # Imported locally, as in the original block.
    from torch.utils._filelock import FileLock

    with FileLock(lock_path, timeout=LOCK_TIMEOUT):
        for builder in cpp_builders:
            # An existing target means this step was already completed
            # (by this process earlier, or by another one that held the lock).
            if not os.path.exists(builder.get_target_file_path()):
                builder.build()
2180
2247
2181
2248
2182
2249
# Customized Python binding for cpp kernels
@@ -2305,19 +2372,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2305
2372
@classmethod
2306
2373
def load_pybinding_async (
2307
2374
cls ,
2308
- argtypes : list [str ],
2309
- source_code : str ,
2375
+ argtypes : Sequence [str ],
2376
+ main_code : str ,
2310
2377
device_type : str = "cpu" ,
2311
2378
num_outputs : int = - 1 ,
2312
2379
submit_fn : Any = None ,
2313
2380
extra_flags : Sequence [str ] = (),
2381
+ kernel_code : Optional [str ] = None ,
2314
2382
) -> Any :
2315
2383
"""
2316
2384
Wrap a C++ function in fast Python bindings.
2317
2385
2318
2386
Args:
2319
2387
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2320
- source_code: C++ source code containing a ENTRY_FUNCTION() function
2388
+ main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2389
+ -O3 if kernel_code is None (to maximize performance in any kernels that
2390
+ are present), or -O1 otherwise (to minimize compile time).
2391
+ kernel_code: If present, C++ source code that will be built at -O3 and
2392
+ linked to main_code.
2321
2393
2322
2394
Returns:
2323
2395
A python version of ENTRY_FUNCTION()
@@ -2333,10 +2405,11 @@ def load_pybinding_async(
2333
2405
extra_parse_arg = cls .extra_parse_arg .format (array_len = num_outputs ),
2334
2406
)
2335
2407
get_result = cls .load_async (
2336
- source_code + suffix ,
2408
+ main_code + suffix ,
2337
2409
device_type ,
2338
2410
submit_fn = submit_fn ,
2339
2411
extra_flags = extra_flags ,
2412
+ optimized_code = kernel_code ,
2340
2413
)
2341
2414
result = None
2342
2415
0 commit comments