[ROCm] Added unit test to test the cuda_pluggable allocator (#154135) · pytorch/pytorch@941732c · GitHub

Commit 941732c

pytorchbot, amd-sriram, and jithunnair-amd authored
[ROCm] Added unit test to test the cuda_pluggable allocator (#154135)
[ROCm] Added unit test to test the cuda_pluggable allocator (#154041)

Added a unit test that includes the CUDAPluggableAllocator and replicates the apex setup.py in order to build the nccl_allocator extension. The test checks whether commit #152179 allows the CUDA pluggable allocator to be built in ROCm/apex.

Pull Request resolved: #154041
Approved by: https://github.com/atalman, https://github.com/jeffdaily

(cherry picked from commit c2660d2)

Co-authored-by: skishore <sriramkumar.kishorekumar@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
1 parent: 769d5da · commit: 941732c
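For context, the apex-style build that this test replicates looks roughly like the sketch below: a setup.py that compiles an extension whose C++ source includes CUDAPluggableAllocator.h. This is a minimal illustration with assumed names (nccl_allocator, NCCLAllocator.cpp), not apex's actual setup.py; the new unit test performs the same build through torch.utils.cpp_extension.load() instead of setuptools.

    # Minimal sketch of an apex-style setup.py (assumed names, not apex's real file).
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name="nccl_allocator",
        ext_modules=[
            CUDAExtension(
                name="nccl_allocator",
                # C++ source that includes torch/csrc/cuda/CUDAPluggableAllocator.h
                sources=["NCCLAllocator.cpp"],
            )
        ],
        cmdclass={"build_ext": BuildExtension},
    )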

File tree

1 file changed: +42 -0 lines changed

test/test_cpp_extensions_jit.py

Lines changed: 42 additions & 0 deletions
@@ -1163,6 +1163,48 @@ def test_aoti_torch_call_dispatcher(self):
         self.assertEqual(abs_t, torch.abs(t))
         self.assertEqual(floor_t, torch.floor(t))
 
+    @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found")
+    def test_cuda_pluggable_allocator_include(self):
+        """
+        Minimal example that replicates the apex setup.py build of the nccl_allocator extension.
+        """
+
+        # the cpp source includes CUDAPluggableAllocator and exports an empty function
+        cpp_source = """
+        #include <torch/csrc/cuda/CUDAPluggableAllocator.h>
+        #include <torch/extension.h>
+        int get_nccl_allocator() {
+          return 0;
+        }
+        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+          m.def("get_nccl_allocator", []() { return get_nccl_allocator(); });
+        }
+        """
+
+        build_dir = tempfile.mkdtemp()
+        src_path = os.path.join(build_dir, "NCCLAllocator.cpp")
+
+        with open(src_path, mode="w") as f:
+            f.write(cpp_source)
+
+        # initially success is False
+        success = False
+        try:
+            # try to build the module
+            torch.utils.cpp_extension.load(
+                name="nccl_allocator",
+                sources=src_path,
+                verbose=True,
+                with_cuda=True,
+            )
+            # mark success if the module built
+            success = True
+        except Exception as e:
+            print(f"Failed to load the module: {e}")
+
+        # the build must have succeeded
+        self.assertEqual(success, True)
+
 
 if __name__ == "__main__":
     common.run_tests()
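Once a similar extension is compiled into a shared library that exports a C-style malloc/free pair, it can be swapped in as the active CUDA allocator at runtime. The following is a minimal usage sketch, not part of the commit: the library name alloc.so and the exported symbols my_malloc/my_free are hypothetical, and the exact C signatures they must implement are described in the CUDAPluggableAllocator documentation.

    import torch

    # Load a custom allocator from a shared library (assumed to export my_malloc/my_free).
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        "alloc.so", "my_malloc", "my_free"
    )

    # Swap it in before any CUDA memory is allocated through the caching allocator.
    torch.cuda.memory.change_current_allocator(new_alloc)

    x = torch.zeros(10, device="cuda")  # allocations now go through the pluggable allocator

The new test itself should be runnable on its own with something like: python test/test_cpp_extensions_jit.py -k test_cuda_pluggable_allocator_include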
