8000 More fixes · pytorch/pytorch@80db20b · GitHub
[go: up one dir, main page]

Skip to content

Commit 80db20b

Browse files
committed
More fixes
1 parent 0fc08f1 commit 80db20b

File tree

4 files changed

+18
-51
lines changed

4 files changed

+18
-51
lines changed

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#if !defined(C10_MOBILE) && !defined(ANDROID)
22

33
#include <c10/util/error.h>
4-
#include <c10/util/string_view.h>
54
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
65
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
76
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
@@ -27,12 +26,6 @@ std::string create_temp_dir() {
2726
return temp_dir;
2827
#endif
2928
}
30-
31-
#ifdef _WIN32
32-
const std::string k_separator = "\\";
33-
#else
34-
const std::string k_separator = "/";
35-
#endif
3629
} // namespace
3730

3831
namespace torch::inductor {
@@ -242,9 +235,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
242235
std::string cpp_filename;
243236
std::string consts_filename;
244237
std::string found_filenames; // Saving for bookkeeping
245-
std::string model_directory =
246-
"data" + k_separator + "aotinductor" + k_separator + model_name;
247-
std::string const_directory = "data" + k_separator + "constants";
238+
auto model_directory =
239+
std::filesystem::path("data") / "aotinductor" / model_name;
240+
auto const_directory = std::filesystem::path("data") / "constants";
248241

249242
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
250243
uint32_t filename_len =
@@ -259,27 +252,26 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
259252
&zip_archive, i, filename_str.data(), filename_len)) {
260253
throw std::runtime_error("Failed to read filename");
261254
}
255+
std::filesystem::path filename(filename_str.c_str());
262256

263257
found_filenames += filename_str;
264258
found_filenames += " ";
265259

266260
// Only compile files in the specified model directory
267-
if (c10::starts_with(filename_str, model_directory) ||
268-
c10::starts_with(filename_str, const_directory)) {
261+
bool in_model_directory =
262+
(!std::filesystem::relative(filename, model_directory).empty());
263+
bool in_const_directory =
264+
(!std::filesystem::relative(filename, const_directory).empty());
265+
if (in_model_directory || in_const_directory) {
269266
std::filesystem::path output_path = temp_dir_;
270267

271-
if (c10::starts_with(filename_str, model_directory)) {
272-
output_path /= filename_str;
273-
} else { // startsWith(filename_str, const_directory)
268+
if (in_model_directory) {
269+
output_path /= filename;
270+
} else {
274271
// Extract constants to the same directory as the rest of the files
275272
// to be consistent with internal implementation
276-
size_t lastSlash = filename_str.find_last_of(k_separator);
277-
std::string filename = filename_str;
278-
if (lastSlash != std::string::npos) {
279-
filename = filename_str.substr(lastSlash + 1);
280-
}
281273
output_path /= model_directory;
282-
output_path /= filename;
274+
output_path /= filename.filename();
283275
}
284276
auto output_path_str = output_path.string();
285277

@@ -354,7 +346,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
354346
throw std::runtime_error("Unsupported device found: " + device);
355347
}
356348

357-
std::string cubin_dir = temp_dir_ + k_separator + model_directory;
349+
std::string cubin_dir = (temp_dir_ / model_directory).string();
358350
runner_ = registered_aoti_runner[device](
359351
so_path, num_runners, device, cubin_dir, run_single_threaded);
360352
}

torch/csrc/inductor/aoti_runner/model_container_runner.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,7 @@
55
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
66
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
77

8-
#ifndef _WIN32
9-
#include <sys/stat.h>
10-
#else
118
#include <filesystem>
12-
namespace fs = std::filesystem;
13-
#endif
14-
15-
namespace {
16-
bool file_exists(std::string& path) {
17-
#ifdef _WIN32
18-
return fs::exists(path);
19-
#else
20-
struct stat rc {};
21-
return lstat(path.c_str(), &rc) == 0;
22-
#endif
23-
}
24-
} // namespace
259

2610
namespace torch::inductor {
2711

@@ -110,7 +94,7 @@ consider rebuild your model with the latest AOTInductor.");
11094
size_t lastindex = model_so_path.find_last_of('.');
11195
std::string json_filename = model_so_path.substr(0, lastindex) + ".json";
11296

113-
if (file_exists(json_filename)) {
97+
if (std::filesystem::exists(json_filename)) {
11498
proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
11599
json_filename, device_str == "cpu");
116100
proxy_executor_handle_ =

torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static ze_module_handle_t loadModule(std::string& spv_path) {
7878
auto l0_context =
7979
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
8080

81-
std::ifstream IFS(spv_path.c_str(), std::ios::binary);
81+
std::ifstream IFS(spv_path, std::ios::binary);
8282
std::ostringstream OSS;
8383
OSS << IFS.rdbuf();
8484
std::string data(OSS.str());

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,6 @@
5656

5757
#endif
5858

59-
#ifndef _WIN32
60-
#include <sys/stat.h>
61-
#include <sys/types.h>
62-
#include <unistd.h>
63-
#include <climits>
64-
65-
#else
66-
namespace fs = std::filesystem;
67-
#endif
68-
6959
// HACK for failed builds in ARVR, where it cannot find these symbols within
7060

7161
using namespace torch::aot_inductor;
@@ -1131,7 +1121,8 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) {
11311121
if (msg) {
11321122
std::cout << " " << msg;
11331123
}
1134-
std::cout << " " << "]:" << '\n';
1124+
std::cout << " "
1125+
<< "]:" << '\n';
11351126

11361127
// Print exact tensor values for small size tensors
11371128
const int64_t numel = t->numel();

0 commit comments

Comments
 (0)
0