@@ -125,6 +125,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
125
125
from concurrent .futures import Future
126
126
127
127
from .compile_fx import _CompileFxKwargs , CompiledFxGraph
128
+ from .cpp_builder import BuildOptionsBase
128
129
from .graph import GraphLowering
129
130
from .ir import ChoiceCaller
130
131
from .output_code import CompiledFxGraphConstants , OutputCode
@@ -2014,7 +2015,7 @@ def _get_file_checksum(filename: str) -> str:
2014
2015
os .makedirs (_HEADER_LOCK_DIR , exist_ok = True )
2015
2016
_worker_compile_cpp (
2016
2017
os .path .join (_HEADER_LOCK_DIR , f"{ header_hash } .lock" ),
2017
- cpp_builder ,
2018
+ ( cpp_builder ,) ,
2018
2019
)
2019
2020
2020
2021
return header_full_path
@@ -2081,10 +2082,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2081
2082
@classmethod
2082
2083
def load_async (
2083
2084
cls ,
2084
- source_code : str ,
2085
+ main_code : str ,
2085
2086
device_type : str = "cpu" ,
2086
2087
submit_fn : Any = None ,
2087
2088
extra_flags : Sequence [str ] = (),
2089
+ optimized_code : Optional [str ] = None ,
2088
2090
) -> Any :
2089
2091
compile_command = {
2090
2092
** cls .cpp_compile_command_flags ,
@@ -2096,48 +2098,112 @@ def load_async(
2096
2098
2097
2099
_set_gpu_runtime_env () # cpp_extension consults the env
2098
2100
2099
- cpp_build_option = CppTorchDeviceOptions (** compile_command )
2100
- command_gen = CppBuilder (name = "o" , sources = "i" , BuildOption = cpp_build_option )
2101
- # write function will calc source_code hash, the same source code with different
2102
- # ISA level should be generate different hash.
2103
- # So we need get a command_line which contains isa related parameter as a part of hash key.
2104
- # And then pass the command_line to below write function as extra parameter to
2105
- # guarantee the source code hash contains ISA difference.
2106
- vec_isa_cmd = repr (command_gen .get_command_line ())
2107
- key , input_path = write (source_code , "cpp" , extra = vec_isa_cmd )
2101
+ # Note the distinction between the two booleans. We do minimal optimization if
2102
+ # the optimized_code argument is present at all, since that's how the user of
2103
+ # this function opts in, but we do compilation and linking in one step if the
2104
+ # optimized_code argument is empty (as a micro-optimization).
2105
+ main_build_option = CppTorchDeviceOptions (
2106
+ compile_only = bool (optimized_code ),
2107
+ min_optimize = optimized_code is not None ,
2108
+ ** compile_command ,
2109
+ )
2110
+ optimized_build_option = CppTorchDeviceOptions (
2111
+ compile_only = True , ** compile_command
2112
+ )
2113
+
2114
+ def get_hashable_command_line (build_option : BuildOptionsBase ) -> str :
2115
+ """Writing the code to file will calculate a hash, which we need to vary if
2116
+ the command line flags change. This implements a mostly-generic way of
2117
+ validating that."""
2118
+ return CppBuilder (
2119
+ name = "o" , sources = "i" , BuildOption = build_option
2120
+ ).get_command_line ()
2121
+
2122
+ main_cmd_line = get_hashable_command_line (main_build_option )
2123
+ optimized_cmd_line = get_hashable_command_line (optimized_build_option )
2124
+
2125
+ key , main_path = write (
2126
+ main_code , "main.cpp" , extra = f"{ optimized_code } { main_cmd_line } "
2127
+ )
2128
+
2129
+ # Don't bother writing if the argument is empty.
2130
+ if optimized_code :
2131
+ _ , optimized_path = write (
2132
+ optimized_code , "optimized.cpp" , extra = optimized_cmd_line
2133
+ )
2134
+ else :
2135
+ # Unused, but makes type checkers happy.
2136
+ optimized_path = os .devnull
2108
2137
2109
2138
if key not in cls .cache :
2110
2139
from torch .utils ._filelock import FileLock
2111
2140
2112
2141
lock_path = os .path .join (get_lock_dir (), key + ".lock" )
2113
- output_name , output_dir = get_name_and_dir_from_output_file_path (input_path )
2114
2142
future : Optional [Future [Any ]] = None
2115
2143
lib = None
2116
2144
2117
2145
# if requested, pre-compile any headers
2118
- if (
2119
- config .cpp_cache_precompile_headers
2120
- and not _IS_WINDOWS
2121
- and (header_file := cls ._get_uncompiled_header (device_type ))
2122
- ):
2123
- cpp_build_option .precompiled_header = _precompile_header (
2124
- header_file ,
2125
- vec_isa_cmd ,
2126
- ** compile_command ,
2127
- )
2146
+ if config .cpp_cache_precompile_headers and not _IS_WINDOWS :
2147
+ if header := cls ._get_uncompiled_header (device_type ):
2148
+ main_build_option .precompiled_header = _precompile_header (
2149
+ header ,
2150
+ main_cmd_line ,
2151
+ min_optimize = optimized_code is not None ,
2152
+ ** compile_command ,
2153
+ )
2128
2154
2129
- cpp_builder = CppBuilder (
2130
- name = output_name ,
2131
- sources = input_path ,
2155
+ # Currently, the optimized_code field is only used for cpp kernel code,
2156
+ # so go ahead and precompile the relevant header here. Revisit this
2157
+ # decision if that ever changes.
2158
+ if optimized_code and (header := _get_cpp_prefix_header (device_type )):
2159
+ optimized_build_option .precompiled_header = _precompile_header (
2160
+ header ,
2161
+ optimized_cmd_line ,
2162
+ ** compile_command ,
2163
+ )
2164
+
2165
+ main_name , output_dir = get_name_and_dir_from_output_file_path (main_path )
2166
+ main_builder = CppBuilder (
2167
+ name = main_name ,
2168
+ sources = main_path ,
2169
+ BuildOption = main_build_option ,
2132
2170
output_dir = output_dir ,
2133
- BuildOption = cpp_build_option ,
2134
2171
)
2135
- worker_fn = functools .partial (
2136
- _worker_compile_cpp ,
2137
- lock_path ,
2138
- cpp_builder ,
2139
- )
2140
- binary_path = normalize_path_separator (cpp_builder .get_target_file_path ())
2172
+
2173
+ if optimized_code :
2174
+ optimized_name , _ = get_name_and_dir_from_output_file_path (
2175
+ optimized_path
2176
+ )
2177
+ optimized_builder = CppBuilder (
2178
+ name = optimized_name ,
2179
+ sources = optimized_path ,
2180
+ BuildOption = optimized_build_option ,
2181
+ output_dir = output_dir ,
2182
+ )
2183
+
2184
+ linker = CppBuilder (
2185
+ name = main_name ,
2186
+ sources = [
2187
+ main_builder .get_target_file_path (),
2188
+ optimized_builder .get_target_file_path (),
2189
+ ],
2190
+ BuildOption = CppTorchDeviceOptions (** compile_command ),
2191
+ output_dir = output_dir ,
2192
+ )
2193
+
2194
+ worker_fn = functools .partial (
2195
+ _worker_compile_cpp ,
2196
+ lock_path ,
2197
+ (main_builder , optimized_builder , linker ),
2198
+ )
2199
+ binary_path = normalize_path_separator (linker .get_target_file_path ())
2200
+ else :
2201
+ worker_fn = functools .partial (
2202
+ _worker_compile_cpp , lock_path , (main_builder ,)
2203
+ )
2204
+ binary_path = normalize_path_separator (
2205
+ main_builder .get_target_file_path ()
2206
+ )
2141
2207
2142
2208
def load_fn () -> Any :
2143
2209
nonlocal lib
@@ -2160,19 +2226,20 @@ def load_fn() -> Any:
2160
2226
return cls .cache [key ]
2161
2227
2162
2228
@classmethod
2163
- def load (cls , source_code : str , device_type : str = "cpu" ) -> Any :
2164
- return cls .load_async (source_code , device_type )()
2229
+ def load (cls , * args : Any , ** kwargs : Any ) -> Any :
2230
+ return cls .load_async (* args , ** kwargs )()
2165
2231
2166
2232
2167
2233
def _worker_compile_cpp (
2168
2234
lock_path : str ,
2169
- cpp_builder : CppBuilder ,
2235
+ cpp_builders : Sequence [ CppBuilder ] ,
2170
2236
) -> None :
2171
2237
from torch .utils ._filelock import FileLock
2172
2238
2173
2239
with FileLock (lock_path , timeout = LOCK_TIMEOUT ):
2174
- if not os .path .exists (cpp_builder .get_target_file_path ()):
2175
- cpp_builder .build ()
2240
+ for builder in cpp_builders :
2241
+ if not os .path .exists (builder .get_target_file_path ()):
2242
+ builder .build ()
2176
2243
2177
2244
2178
2245
# Customized Python binding for cpp kernels
@@ -2301,19 +2368,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2301
2368
@classmethod
2302
2369
def load_pybinding_async (
2303
2370
cls ,
2304
- argtypes : list [str ],
2305
- source_code : str ,
2371
+ argtypes : Sequence [str ],
2372
+ main_code : str ,
2306
2373
device_type : str = "cpu" ,
2307
2374
num_outputs : int = - 1 ,
2308
2375
submit_fn : Any = None ,
2309
2376
extra_flags : Sequence [str ] = (),
2377
+ kernel_code : Optional [str ] = None ,
2310
2378
) -> Any :
2311
2379
"""
2312
2380
Wrap a C++ function in fast Python bindings.
2313
2381
2314
2382
Args:
2315
2383
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2316
- source_code: C++ source code containing a ENTRY_FUNCTION() function
2384
+ main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2385
+ -O3 if kernel_code is None (to maximize performance in any kernels that
2386
+ are present), or -O1 otherwise (to minimize compile time).
2387
+ kernel_code: If present, C++ source code that will be built at -O3 and
2388
+ linked to main_code.
2317
2389
2318
2390
Returns:
2319
2391
A python version of ENTRY_FUNCTION()
@@ -2329,10 +2401,11 @@ def load_pybinding_async(
2329
2401
extra_parse_arg = cls .extra_parse_arg .format (array_len = num_outputs ),
2330
2402
)
2331
2403
get_result = cls .load_async (
2332
- source_code + suffix ,
2404
+ main_code + suffix ,
2333
2405
device_type ,
2334
2406
submit_fn = submit_fn ,
2335
2407
extra_flags = extra_flags ,
2408
+ optimized_code = kernel_code ,
2336
2409
)
2337
2410
result = None
2338
2411
0 commit comments