@@ -126,6 +126,7 @@ def use_global_cache() -> bool: # type: ignore[misc]
126
126
from concurrent .futures import Future
127
127
128
128
from .compile_fx import _CompileFxKwargs , CompiledFxGraph
129
+ from .cpp_builder import BuildOptionsBase
129
130
from .graph import GraphLowering
130 131
from .ir import ChoiceCaller
131
132
from .output_code import CompiledFxGraphConstants , OutputCode
@@ -1977,7 +1978,7 @@ def _get_file_checksum(filename: str) -> str:
1977
1978
os .makedirs (_HEADER_LOCK_DIR , exist_ok = True )
1978
1979
_worker_compile_cpp (
1979
1980
os .path .join (_HEADER_LOCK_DIR , f"{ header_hash } .lock" ),
1980
- cpp_builder ,
1981
+ ( cpp_builder ,) ,
1981
1982
)
1982
1983
1983
1984
return header_full_path
@@ -2044,10 +2045,11 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2044
2045
@classmethod
2045
2046
def load_async (
2046
2047
cls ,
2047
- source_code : str ,
2048
+ main_code : str ,
2048
2049
device_type : str = "cpu" ,
2049
2050
submit_fn : Any = None ,
2050
2051
extra_flags : Sequence [str ] = (),
2052
+ optimized_code : Optional [str ] = None ,
2051
2053
) -> Any :
2052
2054
compile_command = {
2053
2055
** cls .cpp_compile_command_flags ,
@@ -2059,46 +2061,112 @@ def load_async(
2059
2061
2060
2062
_set_gpu_runtime_env () # cpp_extension consults the env
2061
2063
2062
- cpp_build_option = CppTorchDeviceOptions (** compile_command )
2063
- command_gen = CppBuilder (name = "o" , sources = "i" , BuildOption = cpp_build_option )
2064
- # write function will calc source_code hash, the same source code with different
2065
- # ISA level should be generate different hash.
2066
- # So we need get a command_line which contains isa related parameter as a part of hash key.
2067
- # And then pass the command_line to below write function as extra parameter to
2068
- # guarantee the source code hash contains ISA difference.
2069
- vec_isa_cmd = repr (command_gen .get_command_line ())
2070
- key , input_path = write (source_code , "cpp" , extra = vec_isa_cmd )
2064
+ # Note the distinction between the two booleans. We do minimal optimization if
2065
+ # the optimized_code argument is present at all, since that's how the user of
2066
+ # this function opts in, but we do compilation and linking in one step if the
2067
+ # optimized_code argument is empty (as a micro-optimization).
2068
+ main_build_option = CppTorchDeviceOptions (
2069
+ compile_only = bool (optimized_code ),
2070
+ min_optimize = optimized_code is not None ,
2071
+ ** compile_command ,
2072
+ )
2073
+ optimized_build_option = CppTorchDeviceOptions (
2074
+ compile_only = True , ** compile_command
2075
+ )
2076
+
2077
+ def get_hashable_command_line (build_option : BuildOptionsBase ) -> str :
2078
+ """Writing the code to file will calculate a hash, which we need to vary if
2079
+ the command line flags change. This implements a mostly-generic way of
2080
+ validating that."""
2081
+ return CppBuilder (
2082
+ name = "o" , sources = "i" , BuildOption = build_option
2083
+ ).get_command_line ()
2084
+
2085
+ main_cmd_line = get_hashable_command_line (main_build_option )
2086
+ optimized_cmd_line = get_hashable_command_line (optimized_build_option )
2087
+
2088
+ key , main_path = write (
2089
+ main_code , "main.cpp" , extra = f"{ optimized_code } { main_cmd_line } "
2090
+ )
2091
+
2092
+ # Don't bother writing if the argument is empty.
2093
+ if optimized_code :
2094
+ _ , optimized_path = write (
2095
+ optimized_code , "optimized.cpp" , extra = optimized_cmd_line
2096
+ )
2097
+ else :
2098
+ # Unused, but makes type checkers happy.
2099
+ optimized_path = os .devnull
2071
2100
2072
2101
if key not in cls .cache :
2073
2102
from torch .utils ._filelock import FileLock
2074
2103
2075
2104
lock_path = os .path .join (get_lock_dir (), key + ".lock" )
2076
- output_name , output_dir = get_name_and_dir_from_output_file_path (input_path )
2077
2105
future : Optional [Future [Any ]] = None
2078
2106
lib = None
2079
2107
2080
2108
# if requested, pre-compile any headers
2081
- if config .cpp_cache_precompile_headers and (
2082
- header_file := cls ._get_uncompiled_header (device_type )
2083
- ):
2084
- cpp_build_option . precompiled_header = _precompile_header (
2085
- header_file ,
2086
- vec_isa_cmd ,
2087
- ** compile_command ,
2088
- )
2109
+ if config .cpp_cache_precompile_headers :
2110
+ if header := cls ._get_uncompiled_header (device_type ):
2111
+ main_build_option . precompiled_header = _precompile_header (
2112
+ header ,
2113
+ main_cmd_line ,
2114
+ min_optimize = optimized_code is not None ,
2115
+ ** compile_command ,
2116
+ )
2089
2117
2090
- cpp_builder = CppBuilder (
2091
- name = output_name ,
2092
- sources = input_path ,
2118
+ # Currently, the optimized_code field is only used for cpp kernel code,
2119
+ # so go ahead and precompile the relevant header here. Revisit this
2120
+ # decision if that ever changes.
2121
+ if optimized_code and (header := _get_cpp_prefix_header (device_type )):
2122
+ optimized_build_option .precompiled_header = _precompile_header (
2123
+ header ,
2124
+ optimized_cmd_line ,
2125
+ ** compile_command ,
2126
+ )
2127
+
2128
+ main_name , output_dir = get_name_and_dir_from_output_file_path (main_path )
2129
+ main_builder = CppBuilder (
2130
+ name = main_name ,
2131
+ sources = main_path ,
2132
+ BuildOption = main_build_option ,
2093
2133
output_dir = output_dir ,
2094
- BuildOption = cpp_build_option ,
2095
- )
2096
- worker_fn = functools .partial (
2097
- _worker_compile_cpp ,
2098
- lock_path ,
2099
- cpp_builder ,
2100
2134
)
2101
- binary_path = normalize_path_separator (cpp_builder .get_target_file_path ())
2135
+
2136
+ if optimized_code :
2137
+ optimized_name , _ = get_name_and_dir_from_output_file_path (
2138
+ optimized_path
2139
+ )
2140
+ optimized_builder = CppBuilder (
2141
+ name = optimized_name ,
2142
+ sources = optimized_path ,
2143
+ BuildOption = optimized_build_option ,
2144
+ output_dir = output_dir ,
2145
+ )
2146
+
2147
+ linker = CppBuilder (
2148
+ name = main_name ,
2149
+ sources = [
2150
+ main_builder .get_target_file_path (),
2151
+ optimized_builder .get_target_file_path (),
2152
+ ],
2153
+ BuildOption = CppTorchDeviceOptions (** compile_command ),
2154
+ output_dir = output_dir ,
2155
+ )
2156
+
2157
+ worker_fn = functools .partial (
2158
+ _worker_compile_cpp ,
2159
+ lock_path ,
2160
+ (main_builder , optimized_builder , linker ),
2161
+ )
2162
+ binary_path = normalize_path_separator (linker .get_target_file_path ())
2163
+ else :
2164
+ worker_fn = functools .partial (
2165
+ _worker_compile_cpp , lock_path , (main_builder ,)
2166
+ )
2167
+ binary_path = normalize_path_separator (
2168
+ main_builder .get_target_file_path ()
2169
+ )
2102
2170
2103
2171
def load_fn () -> Any :
2104
2172
nonlocal lib
@@ -2121,19 +2189,20 @@ def load_fn() -> Any:
2121
2189
return cls .cache [key ]
2122
2190
2123
2191
@classmethod
2124
- def load (cls , source_code : str , device_type : str = "cpu" ) -> Any :
2125
- return cls .load_async (source_code , device_type )()
2192
+ def load (cls , * args : Any , ** kwargs : Any ) -> Any :
2193
+ return cls .load_async (* args , ** kwargs )()
2126
2194
2127
2195
2128
2196
def _worker_compile_cpp (
2129
2197
lock_path : str ,
2130
- cpp_builder : CppBuilder ,
2198
+ cpp_builders : Sequence [ CppBuilder ] ,
2131
2199
) -> None :
2132
2200
from torch .utils ._filelock import FileLock
2133
2201
2134
2202
with FileLock (lock_path , timeout = LOCK_TIMEOUT ):
2135
- if not os .path .exists (cpp_builder .get_target_file_path ()):
2136
- cpp_builder .build ()
2203
+ for builder in cpp_builders :
2204
+ if not os .path .exists (builder .get_target_file_path ()):
2205
+ builder .build ()
2137
2206
2138
2207
2139
2208
# Customized Python binding for cpp kernels
@@ -2262,19 +2331,24 @@ def _get_uncompiled_header(cls, device: str) -> str | None:
2262
2331
@classmethod
2263
2332
def load_pybinding_async (
2264
2333
cls ,
2265
- argtypes : list [str ],
2266
- source_code : str ,
2334
+ argtypes : Sequence [str ],
2335
+ main_code : str ,
2267
2336
device_type : str = "cpu" ,
2268
2337
num_outputs : int = - 1 ,
2269
2338
submit_fn : Any = None ,
2270
2339
extra_flags : Sequence [str ] = (),
2340
+ kernel_code : Optional [str ] = None ,
2271
2341
) -> Any :
2272
2342
"""
2273
2343
Wrap a C++ function in fast Python bindings.
2274
2344
2275
2345
Args:
2276
2346
argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2277
- source_code: C++ source code containing a ENTRY_FUNCTION() function
2347
+ main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at
2348
+ -O3 if kernel_code is None (to maximize performance in any kernels that
2349
+ are present), or -O1 otherwise (to minimize compile time).
2350
+ kernel_code: If present, C++ source code that will be built at -O3 and
2351
+ linked to main_code.
2278
2352
2279
2353
Returns:
2280
2354
A python version of ENTRY_FUNCTION()
@@ -2296,10 +2370,11 @@ def load_pybinding_async(
2296
2370
cls .entry_function ,
2297
2371
)
2298
2372
get_result = cls .load_async (
2299
- source_code + suffix ,
2373
+ main_code + suffix ,
2300
2374
device_type ,
2301
2375
submit_fn = submit_fn ,
2302
2376
extra_flags = extra_flags ,
2377
+ optimized_code = kernel_code ,
2303
2378
)
2304
2379
result = None
2305
2380
0 commit comments