8000 [AOTI] Fix #140546 and support AOTI package load for Intel GPU. (#140… · pytorch/pytorch@91d3054 · GitHub
[go: up one dir, main page]

Skip to content

Commit 91d3054

Browse files
etaf authored and pytorchmergebot committed
[AOTI] Fix #140546 and support AOTI package load for Intel GPU. (#140664)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #140686 * __->__ #140664 * #140269 * #140268 * #135320 * #135318 * #139026 Fix #140546 Pull Request resolved: #140664 Approved by: https://github.com/desertfire, https://github.com/EikanWang ghstack dependencies: #140268, #140269
1 parent 854d831 commit 91d3054

15 files changed

+104
-39
lines changed

test/cpp/aoti_inference/aoti_custom_class.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ MyAOTIClass::MyAOTIClass(
3636
} else if (device_ == "cuda") {
3737
runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
3838
model_path.c_str());
39+
#endif
40+
#if defined(USE_XPU)
41+
} else if (device_ == "xpu") {
42+
runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerXpu>(
43+
model_path.c_str());
3944
#endif
4045
} else {
4146
throw std::runtime_error("invalid device: " + device);

test/inductor/test_aot_inductor_package.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch._inductor.utils import fresh_inductor_cache
1616
from torch.export import Dim
1717
from torch.testing._internal.common_utils import IS_FBCODE, TEST_CUDA
18-
from torch.testing._internal.triton_utils import HAS_CUDA
18+
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
1919

2020

2121
def skipif(predicate: Callable[[str, bool], bool], reason: str):
@@ -69,8 +69,8 @@ def compile(
6969
)
7070
+ (
7171
[
72-
{"device": "cuda", "package_cpp_only": False},
73-
{"device": "cuda", "package_cpp_only": True},
72+
{"device": GPU_TYPE, "package_cpp_only": False},
73+
{"device": GPU_TYPE, "package_cpp_only": True},
7474
]
7575
if sys.platform != "darwin"
7676
else []
@@ -445,5 +445,5 @@ def forward(self, a):
445445
from torch._inductor.test_case import run_tests
446446

447447
# cpp_extension N/A in fbcode
448-
if HAS_CUDA or sys.platform == "darwin":
448+
if HAS_GPU or sys.platform == "darwin":
449449
run_tests(needs="filelock")

torch/csrc/inductor/aoti_eager/kernel_holder.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#ifdef USE_CUDA
1414
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
1515
#endif
16+
#ifdef USE_XPU
17+
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
18+
#endif
1619
#include <torch/csrc/jit/frontend/function_schema_parser.h>
1720

1821
#include <ATen/core/jit_type.h>
@@ -177,6 +180,7 @@ AOTIPythonKernelHolder::AOTIPythonKernelHolder(
177180
auto registered_aoti_runner = getAOTIModelRunnerRegistry();
178181
TORCH_CHECK(
179182
device_.type() == c10::DeviceType::CUDA ||
183+
device_.type() == c10::DeviceType::XPU ||
180184
device_.type() == c10::DeviceType::CPU ||
181185
registered_aoti_runner.find(device_name) !=
182186
registered_aoti_runner.end(),
@@ -417,6 +421,7 @@ std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
417421
auto registered_aoti_runner = getAOTIModelRunnerRegistry();
418422
TORCH_CHECK(
419423
device_.type() == c10::DeviceType::CUDA ||
424+
device_.type() == c10::DeviceType::XPU ||
420425
device_.type() == c10::DeviceType::CPU ||
421426
registered_aoti_runner.find(device_name) !=
422427
registered_aoti_runner.end(),
@@ -428,6 +433,12 @@ std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
428433
return std::make_shared<AOTIModelContainerRunnerCuda>(so_path);
429434
#else
430435
return nullptr;
436+
#endif
437+
} else if (device_.type() == c10::DeviceType::XPU) {
438+
#ifdef USE_XPU
439+
return std::make_shared<AOTIModelContainerRunnerXpu>(so_path);
440+
#else
441+
return nullptr;
431442
#endif
432443
} else if (device_.type() == c10::DeviceType::CPU) {
433444
return std::make_shared<AOTIModelContainerRunnerCpu>(so_path);

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,9 @@ AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() {
459459
}
460460

461461
std::vector<at::Tensor> AOTIModelPackageLoader::run(
462-
const std::vector<at::Tensor>& inputs) {
463-
return runner_->run(inputs);
462+
const std::vector<at::Tensor>& inputs,
463+
void* stream_handle) {
464+
return runner_->run(inputs, stream_handle);
464465
}
465466

466467
std::unordered_map<std::string, std::string> AOTIModelPackageLoader::

torch/csrc/inductor/aoti_package/model_package_loader.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ class TORCH_API AOTIModelPackageLoader {
1515

1616
AOTIModelContainerRunner* get_runner();
1717
std::unordered_map<std::string, std::string> get_metadata();
18-
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inputs);
18+
std::vector<at::Tensor> run(
19+
const std::vector<at::Tensor>& inputs,
20+
void* stream_handle = nullptr);
1921
std::vector<std::string> get_call_spec();
2022
void load_constants(
2123
std::unordered_map<std::string, at::Tensor>& constants_map,

torch/csrc/inductor/aoti_package/pybind.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ void initAOTIPackageBindings(PyObject* module) {
1818
.def(py::init<const std::string&, const std::string&>())
1919
.def(py::init<const std::string&>())
2020
.def("get_metadata", &AOTIModelPackageLoader::get_metadata)
21-
.def("run", &AOTIModelPackageLoader::run)
21+
.def(
22+
"run",
23+
&AOTIModelPackageLoader::run,
24+
py::arg("inputs"),
25+
py::arg("stream_handle") = nullptr)
2226
.def("get_call_spec", &AOTIModelPackageLoader::get_call_spec)
2327
.def("load_constants", &AOTIModelPackageLoader::load_constants)
2428
.def("get_constant_fqns", &AOTIModelPackageLoader::get_constant_fqns);

torch/csrc/inductor/aoti_runner/model_container_runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ AOTIModelContainerRunner::~AOTIModelContainerRunner() {
9393

9494
std::vector<at::Tensor> AOTIModelContainerRunner::run(
9595
const std::vector<at::Tensor>& inputs,
96-
AOTInductorStreamHandle cuda_stream_handle) {
96+
void* stream_handle) {
9797
auto input_handles =
9898
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
9999

@@ -110,7 +110,7 @@ std::vector<at::Tensor> AOTIModelContainerRunner::run(
110110
input_handles.size(),
111111
output_handles.data(),
112112
output_handles.size(),
113-
cuda_stream_handle,
113+
reinterpret_cast<AOTInductorStreamHandle>(stream_handle),
114114
proxy_executor_handle_));
115115

116116
return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(

torch/csrc/inductor/aoti_runner/model_container_runner.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ class TORCH_API AOTIModelContainerRunner {
2222
delete;
2323
AOTIModelContainerRunner& operator=(AOTIModelContainerRunner&& other) =
2424
delete;
25-
~AOTIModelContainerRunner();
25+
virtual ~AOTIModelContainerRunner();
2626

27-
std::vector<at::Tensor> run(
27+
virtual std::vector<at::Tensor> run(
2828
const std::vector<at::Tensor>& inputs,
29-
AOTInductorStreamHandle cuda_stream_handle = nullptr);
29+
void* stream_handle = nullptr);
3030

3131
std::unordered_map<std::string, std::string> getConstantNamesToOriginalFQNs()
3232
const;

torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu(
1313
AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default;
1414

1515
std::vector<at::Tensor> AOTIModelContainerRunnerCpu::run(
16-
const std::vector<at::Tensor>& inputs) {
17-
return AOTIModelContainerRunner::run(inputs);
16+
const std::vector<at::Tensor>& inputs,
17+
void* stream_handle) {
18+
return AOTIModelContainerRunner::run(inputs, stream_handle);
1819
}
1920

2021
namespace {

torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
1010
const std::string& model_so_path,
1111
size_t num_models = 1);
1212

13-
~AOTIModelContainerRunnerCpu();
13+
~AOTIModelContainerRunnerCpu() override;
1414

15-
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inputs);
15+
std::vector<at::Tensor> run(
16+
const std::vector<at::Tensor>& inputs,
17+
void* stream_handle = nullptr) override;
1618
};
1719

1820
} // namespace torch::inductor

0 commit comments

Comments
 (0)
0