8000 Add optional device index to AOTIModelPackageLoader (#152093) · pytorch/pytorch@8f54e56 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8f54e56

Browse files
juliusghpytorchmergebot
authored andcommitted
Add optional device index to AOTIModelPackageLoader (#152093)
This is my suggestion for resolving #152087 This PR extends the constructor of `AOTIModelPackageLoader` with an (optional) device index. The device type is still determined by `metadata_["AOTI_DEVICE_KEY"]`, but the `device_index` argument can be used to move an AOTI model package to different devices like `cuda:0`, `cuda:1`, ... in a convenient way. AFAIK, this is not possible so far using `AOTIModelPackageLoader` alone. The default case (no device index specified) with `metadata_["AOTI_DEVICE_KEY"] == "cuda"` would lead to the current behavior, i.e., the model is loaded to device `cuda`. Pull Request resolved: #152093 Approved by: https://github.com/desertfire
1 parent fd8fd01 commit 8f54e56

File tree

7 files changed

+113
-15
lines changed

7 files changed

+113
-15
lines changed

test/cpp/aoti_inference/test.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
1515
#if defined(USE_CUDA)
1616
#include <c10/cuda/CUDACachingAllocator.h>
17+
#include <c10/cuda/CUDAGuard.h>
1718
#include <cuda_runtime.h>
1819
#endif
1920
#if defined(USE_CUDA) || defined(USE_ROCM)
@@ -139,6 +140,45 @@ void test_aoti_package_loader(
139140
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
140141
}
141142

143+
void test_aoti_package_loader_multi_gpu(
144+
const std::string& device,
145+
bool use_runtime_constant_folding) {
146+
torch::NoGradGuard no_grad;
147+
148+
std::string data_path =
149+
(std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
150+
.string();
151+
torch::jit::script::Module data_loader = torch::jit::load(data_path);
152+
std::string suffix = use_runtime_constant_folding
153+
? device + "_use_runtime_constant_folding"
154+
: device;
155+
std::string path_attr = "pt2_package_path_" + suffix;
156+
std::string inputs_attr = "inputs_" + suffix;
157+
std::string outputs_attr = "outputs_" + suffix;
158+
const auto& pt2_package_path =
159+
data_loader.attr(path_attr.c_str()).toStringRef();
160+
const auto& ref_output_tensors =
161+
data_loader.attr(outputs_attr.c_str()).toTensorList().vec();
162+
163+
// For all available CUDA devices: Load PT2 package on this device, run
164+
// inference, and validate results
165+
auto input_tensors =
166+
data_loader.attr(inputs_attr.c_str()).toTensorList().vec();
167+
for (int i = 0; i < torch::cuda::device_count(); i++) {
168+
auto options = torch::TensorOptions().device(torch::kCUDA, i);
169+
torch::inductor::AOTIModelPackageLoader runner(
170+
pt2_package_path, "model", false, 1, i);
171+
std::vector<torch::Tensor> input_tensors_on_device;
172+
for (auto input_tensor : input_tensors) {
173+
input_tensors_on_device.push_back(input_tensor.clone().to(options));
174+
}
175+
// Run loaded PT2 package on device
176+
auto actual_output_tensors = runner.run(input_tensors_on_device);
177+
ASSERT_TRUE(torch::allclose(
178+
ref_output_tensors[0].cpu(), actual_output_tensors[0].cpu()));
179+
}
180+
}
181+
142182
void test_aoti_constants_update(
143183
const std::string& device,
144184
bool use_runtime_constant_folding) {
@@ -988,6 +1028,10 @@ TEST(AotInductorTest, BasicPackageLoaderTestCuda) {
9881028
test_aoti_package_loader("cuda", false);
9891029
}
9901030

1031+
TEST(AotInductorTest, BasicPackageLoaderTestMultiGpuCuda) {
1032+
test_aoti_package_loader_multi_gpu("cuda", false);
1033+
}
1034+
9911035
TEST(AotInductorTest, UpdateUserManagedConstantsCuda) {
9921036
test_aoti_user_managed_buffer();
9931037
}

test/inductor/test_aot_inductor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2155,6 +2155,41 @@ def forward(self, x, y):
21552155
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
21562156
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
21572157

2158+
@requires_multigpu()
2159+
def test_load_package_multiple_gpus(s 9E88 elf):
2160+
if self.device != GPU_TYPE:
2161+
raise unittest.SkipTest("requires GPU")
2162+
2163+
class Model(torch.nn.Module):
2164+
def __init__(self, weight):
2165+
super().__init__()
2166+
self.weight = weight
2167+
2168+
def forward(self, x, y):
2169+
return x + torch.nn.functional.linear(y, self.weight)
2170+
2171+
weight = torch.randn(10, 10, device=self.device)
2172+
inputs = (
2173+
torch.randn(10, 10, device=self.device),
2174+
torch.randn(10, 10, device=self.device),
2175+
)
2176+
model = Model(weight).to(device=self.device)
2177+
result_ref = model(*inputs)
2178+
2179+
package_path = AOTIRunnerUtil.compile(model, inputs)
2180+
2181+
# Load AOT package on gpu:N
2182+
device_interface = get_interface_for_device(GPU_TYPE)
2183+
for i in range(device_interface.device_count()):
2184+
device = torch.device(GPU_TYPE, i)
2185+
with device_interface.device(i), torch.no_grad():
2186+
model_package = torch._inductor.aoti_load_package(
2187+
package_path, device_index=i
2188+
)
2189+
inputs_on_device = [input.to(device=device) for input in inputs]
2190+
result_package = model_package(*inputs_on_device)
2191+
self.assertTrue(same(result_ref.cpu(), result_package.cpu()))
2192+
21582193
def test_reuse_kernel(self):
21592194
class Model(torch.nn.Module):
21602195
def __init__(self) -> None:

torch/_inductor/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def _aoti_compile_and_package_inner(
235235
return package_path
236236

237237

238-
def aoti_load_package(path: FileLike, run_single_threaded: bool = False) -> Any: # type: ignore[type-arg]
238+
def aoti_load_package(
239+
path: FileLike, run_single_threaded: bool = False, device_index: int = -1
240+
) -> Any: # type: ignore[type-arg]
239241
"""
240242
Loads the model from the PT2 package.
241243
@@ -254,10 +256,16 @@ def aoti_load_package(path: FileLike, run_single_threaded: bool = False) -> Any:
254256
run_single_threaded (bool): Whether the model should be run without
255257
thread synchronization logic. This is useful to avoid conflicts with
256258
CUDAGraphs.
259+
device_index (int): The index of the device to which the PT2 package is
260+
to be loaded. By default, `device_index=-1` is used, which corresponds
261+
to the device `cuda` when using CUDA. Passing `device_index=1` would
262+
load the package to `cuda:1`, for example.
257263
"""
258264
from torch._inductor.package import load_package
259265

260-
return load_package(path, run_single_threaded=run_single_threaded)
266+
return load_package(
267+
path, run_single_threaded=run_single_threaded, device_index=device_index
268+
)
261269

262270

263271
def aot_compile(

torch/_inductor/package/package.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def load_package(
290290
model_name: str = "model",
291291
run_single_threaded: bool = False,
292292
num_runners: int = 1,
293+
device_index: int = -1,
293294
) -> AOTICompiledModel: # type: ignore[type-arg]
294295
assert (
295296
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
@@ -305,12 +306,12 @@ def load_package(
305306
path.seek(0)
306307
log.debug("Writing buffer to tmp file located at %s.", f.name)
307308
loader = torch._C._aoti.AOTIModelPackageLoader(
308-
f.name, model_name, run_single_threaded, num_runners
309+
f.name, model_name, run_single_threaded, num_runners, device_index
309310
) # type: ignore[call-arg]
310311
return AOTICompiledModel(loader)
311312

312313
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
313314
loader = torch._C._aoti.AOTIModelPackageLoader(
314-
path, model_name, run_single_threaded, num_runners
315+
path, model_name, run_single_threaded, num_runners, device_index
315316
) # type: ignore[call-arg]
316317
return AOTICompiledModel(loader)

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
342342
const std::string& model_package_path,
343343
const std::string& model_name,
344344
const bool run_single_threaded,
345-
const size_t num_runners) {
345+
const size_t num_runners,
346+
const c10::DeviceIndex device_index) {
346347
if (run_single_threaded) {
347348
if (num_runners != 1) {
348349
throw std::runtime_error(
@@ -470,22 +471,25 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
470471
load_metadata(cpp_filename);
471472

472473
// Construct the runner depending on the device information
473-
std::string device = metadata_["AOTI_DEVICE_KEY"];
474+
std::string device_key = metadata_["AOTI_DEVICE_KEY"];
474475

475-
if (device.empty()) {
476+
if (device_key.empty()) {
476477
throw std::runtime_error("No device information found.");
477478
}
478479

479480
std::unordered_map<std::string, CreateAOTIModelRunnerFunc>
480481
registered_aoti_runner = getAOTIModelRunnerRegistry();
481482

482-
if (registered_aoti_runner.find(device) == registered_aoti_runner.end()) {
483-
throw std::runtime_error("Unsupported device found: " + device);
483+
if (registered_aoti_runner.find(device_key) == registered_aoti_runner.end()) {
484+
throw std::runtime_error("Unsupported device key found: " + device_key);
484485
}
485486

487+
c10::Device device = c10::Device(device_key);
488+
device.set_index(device_index);
489+
486490
std::string cubin_dir = temp_dir_ + k_separator + model_directory;
487-
runner_ = registered_aoti_runner[device](
488-
so_path, num_runners, device, cubin_dir, run_single_threaded);
491+
runner_ = registered_aoti_runner[device_key](
492+
so_path, num_runners, device.str(), cubin_dir, run_single_threaded);
489493
}
490494

491495
AOTIModelPackageLoader::~AOTIModelPackageLoader() {

torch/csrc/inductor/aoti_package/model_package_loader.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#pragma once
33

44
#include <ATen/Tensor.h>
5+
#include <c10/core/Device.h>
56
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
67

78
namespace torch::inductor {
@@ -11,7 +12,8 @@ class TORCH_API AOTIModelPackageLoader {
1112
const std::string& model_package_path,
1213
const std::string& model_name = "model",
1314
const bool run_single_threaded = false,
14-
const size_t num_runners = 1);
15+
const size_t num_runners = 1,
16+
const c10::DeviceIndex device_index = -1);
1517
~AOTIModelPackageLoader();
1618

1719
AOTIModelContainerRunner* get_runner();

torch/csrc/inductor/aoti_package/pybind.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
77
#endif
88

9+
#include <c10/core/Device.h>
910
#include <torch/csrc/autograd/python_variable.h>
1011
#include <torch/csrc/inductor/aoti_runner/pybind.h>
1112
#include <torch/csrc/utils/pybind.h>
@@ -18,12 +19,14 @@ class AOTIModelPackageLoaderPybind : public AOTIModelPackageLoader {
1819
const std::string& model_package_path,
1920
const std::string& model_name,
2021
const bool run_single_threaded,
21-
const size_t num_runners)
22+
const size_t num_runners,
23+
const c10::DeviceIndex device_index)
2224
: AOTIModelPackageLoader(
2325
model_package_path,
2426
model_name,
2527
run_single_threaded,
26-
num_runners) {}
28+
num_runners,
29+
device_index) {}
2730

2831
py::list boxed_run(py::list& inputs, void* stream_handle = nullptr) {
2932
std::vector<at::Tensor> input_tensors;
@@ -54,7 +57,8 @@ void initAOTIPackageBindings(PyObject* module) {
5457
const std::string&,
5558
const std::string&,
5659
const bool,
57-
const size_t>())
60+
const size_t,
61+
const c10::DeviceIndex>())
5862
.def("get_metadata", &AOTIModelPackageLoaderPybind::get_metadata)
5963
.def(
6064
"run",

0 commit comments

Comments
 (0)
0