8000 [RelEng] Define `BUILD_BUNDLE_PTXAS` (#119750) (#119988) · pytorch/pytorch@6c8c5ad · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c8c5ad

Browse files
atalmanmalfet
andauthored
[RelEng] Define BUILD_BUNDLE_PTXAS (#119750) (#119988)
Co-authored-by: Nikita Shulga <nshulga@meta.com> Fixes #119054 resolved: #119750
1 parent f00f0ab commit 6c8c5ad

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ cmake_dependent_option(
349349
"NOT INTERN_BUILD_MOBILE" OFF)
350350
cmake_dependent_option(
351351
BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
352+
cmake_dependent_option(
353+
BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" OFF "USE_CUDA" OFF)
352354

353355
option(USE_MIMALLOC "Use mimalloc" OFF)
354356
# Enable third party mimalloc library to improve memory allocation performance on Windows.
@@ -1230,3 +1232,12 @@ if(DEFINED USE_CUSTOM_DEBINFO)
12301232
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -g")
12311233
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -g")
12321234
endif()
1235+
1236+
# Bundle PTXAS if needed
1237+
if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
1238+
if(NOT EXISTS "${PROJECT_SOURCE_DIR}/build/bin/ptxas")
1239+
message(STATUS "Copying PTXAS into the bin folder")
1240+
file(COPY "${CUDAToolkit_BIN_DIR}/ptxas" DESTINATION "${PROJECT_BINARY_DIR}")
1241+
endif()
1242+
install(PROGRAMS "${PROJECT_BINARY_DIR}/ptxas" DESTINATION "${CMAKE_INSTALL_BINDIR}")
1243+
endif()

torch/_inductor/codecache.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,6 +2277,20 @@ def caching_device_properties():
22772277
device_interface.Worker.get_device_properties()
22782278

22792279

2280+
def _set_triton_ptxas_path() -> None:
2281+
if os.environ.get("TRITON_PTXAS_PATH") is not None:
2282+
return
2283+
ptxas_path = os.path.abspath(
2284+
os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
2285+
)
2286+
if not os.path.exists(ptxas_path):
2287+
return
2288+
if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
2289+
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
2290+
else:
2291+
warnings.warn(f"{ptxas_path} exists but is not an executable")
2292+
2293+
22802294
def _worker_compile(
22812295
kernel_name: str, source_code: str, cc: int, device: torch.device
22822296
) -> None:
@@ -2287,6 +2301,7 @@ def _worker_compile(
22872301

22882302

22892303
def _load_kernel(kernel_name: str, source_code: str) -> ModuleType:
2304+
_set_triton_ptxas_path()
22902305
kernel = TritonCodeCache.load(kernel_name, source_code)
22912306
kernel.precompile()
22922307
return kernel

0 commit comments

Comments
 (0)
0