@@ -1532,6 +1532,12 @@ def get_keys(cls) -> KeysView[str]:
1532
1532
1533
1533
1534
1534
class AotCodeCompiler :
1535
+ """
1536
+ AOTCodeCompiler is a class that handles the compilation of AOTInductor
1537
+ kernels. It is responsible for generating the kernel and wrapper code,
1538
+ compiling and packaging them.
1539
+ """
1540
+
1535
1541
@classmethod
1536
1542
def compile (
1537
1543
cls ,
@@ -1744,6 +1750,9 @@ def _compile_consts(consts: bytes, platform: str) -> str:
1744
1750
1745
1751
metadata = config .aot_inductor .metadata
1746
1752
metadata ["AOTI_DEVICE_KEY" ] = device_type
1753
+ metadata ["STANDALONE" ] = (
1754
+ "1" if config .aot_inductor .codegen_standalone else "0"
1755
+ )
1747
1756
1748
1757
# Save user provided metadata
1749
1758
meta_json = str (
@@ -1878,6 +1887,27 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
1878
1887
1879
1888
log .debug ("aot wrapper compilation command: %s" , wrapper_compile_cmd )
1880
1889
log .debug ("aot kernel compilation command: %s" , kernel_compile_cmd )
1890
+
1891
+ cuda_utils_o : list [str ] = []
1892
+ if config .aot_inductor .codegen_standalone and device_type == "cuda" :
1893
+ # TODO: seletively add additional cuda files
1894
+ cuda_util_files : list [str ] = []
1895
+ cuda_build_options = CppTorchDeviceOptions (
1896
+ compiler = "nvcc" ,
1897
+ compile_only = True ,
1898
+ ** compile_command ,
1899
+ )
1900
+ for file in cuda_util_files :
1901
+ cuda_builder = CppBuilder (
1902
+ name = file ,
1903
+ sources = file ,
10000
code>
1904
+ output_dir = str (wrapper_path_operator .parent ),
1905
+ BuildOption = cuda_build_options ,
1906
+ )
1907
+ if not config .aot_inductor .package_cpp_only :
1908
+ cuda_builder .build ()
1909
+ cuda_utils_o .append (cuda_builder .get_target_file_path ())
1910
+
1881
1911
if config .aot_inductor .package_cpp_only :
1882
1912
# Not doing the actual compilation here
1883
1913
compile_flags = str (
@@ -2001,7 +2031,14 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
2001
2031
use_relative_path = use_relative_path ,
2002
2032
)
2003
2033
2004
- obj_srcs = [wrapper_o , kernel_o , consts_o , * gpu_kernels_o , * cubins_o ]
2034
+ obj_srcs = [
2035
+ wrapper_o ,
2036
+ kernel_o ,
2037
+ consts_o ,
2038
+ * gpu_kernels_o ,
2039
+ * cubins_o ,
2040
+ * cuda_utils_o ,
2041
+ ]
2005
2042
so_builder = CppBuilder (
2006
2043
name = output_name ,
2007
2044
sources = obj_srcs ,
@@ -2096,7 +2133,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
2096
2133
@clear_on_fresh_inductor_cache
2097
2134
@functools .lru_cache
2098
2135
def cpp_prefix_path () -> str :
2099
- path = Path (__file__ ).parent / "codegen/ cpp_prefix.h"
2136
+ path = Path (__file__ ).parent / "codegen" / " cpp_prefix.h"
2100
2137
with path .open () as f :
2101
2138
content = f .read ()
2102
2139
_ , filename = write (
@@ -2571,7 +2608,11 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
2571
2608
call_entry_function = "return inductor_entry_cpp({});"
2572
2609
extra_parse_arg = textwrap .dedent (
2573
2610
"""
2611
+ #ifdef AOTI_STANDALONE
2612
+ #include <torch/csrc/inductor/aoti_standalone/c/shim.h>
2613
+ #else
2574
2614
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
2615
+ #endif // AOTI_STANDALONE
2575
2616
2576
2617
static inline std::vector<AtenTensorHandle> unpack_tensor_handle_list(PyObject* pyvec) {{
2577
2618
std::vector<AtenTensorHandle> result;
@@ -3215,7 +3256,7 @@ def _nvcc_host_compiler_options() -> list[str]:
3215
3256
]
3216
3257
3217
3258
3218
- def _nvcc_compiler_options () -> list [ str ] :
3259
+ def _nvcc_get_arch_option () -> str :
3219
3260
arch = cuda_env .get_cuda_arch ()
3220
3261
if arch == "90" :
3221
3262
# Required by cutlass compilation.
@@ -3225,13 +3266,17 @@ def _nvcc_compiler_options() -> list[str]:
3225
3266
code = [f"sm_{ arch } " , f"compute_{ arch } " ]
3226
3267
if config .cuda .enable_cuda_lto :
3227
3268
code += [f"lto_{ arch } " ]
3269
+ return f"gencode=arch=compute_{ arch } ,code=[{ ',' .join (code )} ]"
3270
+
3271
+
3272
+ def _nvcc_compiler_options () -> list [str ]:
3228
3273
options = [
3229
3274
"-t=0" ,
3230
3275
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" ,
3231
3276
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1" ,
3232
3277
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED" ,
3233
3278
"-w" ,
3234
- f"-gencode=arch=compute_ { arch } ,code=[ { ',' . join ( code ) } ] " ,
3279
+ f"-{ _nvcc_get_arch_option () } " ,
3235
3280
config .cuda .compile_opt_level ,
3236
3281
"-std=c++17" ,
3237
3282
"--expt-relaxed-constexpr" ,
0 commit comments