1
1
#if !defined(C10_MOBILE) && !defined(ANDROID)
2
2
3
3
#include < c10/util/error.h>
4
- #include < c10/util/string_view.h>
5
4
#include < torch/csrc/inductor/aoti_package/model_package_loader.h>
6
5
#include < torch/csrc/inductor/aoti_runner/model_container_runner.h>
7
6
#include < torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
@@ -27,12 +26,6 @@ std::string create_temp_dir() {
27
26
return temp_dir;
28
27
#endif
29
28
}
30
-
31
- #ifdef _WIN32
32
- const std::string k_separator = " \\ " ;
33
- #else
34
- const std::string k_separator = " /" ;
35
- #endif
36
29
} // namespace
37
30
38
31
namespace torch ::inductor {
@@ -242,9 +235,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
242
235
std::string cpp_filename;
243
236
std::string consts_filename;
244
237
std::string found_filenames; // Saving for bookkeeping
245
- std::string model_directory =
246
- " data" + k_separator + " aotinductor" + k_separator + model_name;
247
- std::string const_directory = " data" + k_separator + " constants" ;
238
+ auto model_directory =
239
+ std::filesystem::path ( " data" ) / " aotinductor" / model_name;
240
+ auto const_directory = std::filesystem::path ( " data" ) / " constants" ;
248
241
249
242
for (uint32_t i = 0 ; i < zip_archive.m_total_files ; i++) {
250
243
uint32_t filename_len =
@@ -259,27 +252,26 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
259
252
&zip_archive, i, filename_str.data (), filename_len)) {
260
253
throw std::runtime_error (" Failed to read filename" );
261
254
}
255
+ std::filesystem::path filename (filename_str.c_str ());
262
256
263
257
found_filenames += filename_str;
264
258
found_filenames += " " ;
265
259
266
260
// Only compile files in the specified model directory
267
- if (c10::starts_with (filename_str, model_directory) ||
268
- c10::starts_with (filename_str, const_directory)) {
261
+ bool in_model_directory =
262
+ (!std::filesystem::relative (filename, model_directory).empty ());
263
+ bool in_const_directory =
264
+ (!std::filesystem::relative (filename, const_directory).empty ());
265
+ if (in_model_directory || in_const_directory) {
269
266
std::filesystem::path output_path = temp_dir_;
270
267
271
- if (c10::starts_with (filename_str, model_directory) ) {
272
- output_path /= filename_str ;
273
- } else { // startsWith(filename_str, const_directory)
268
+ if (in_model_directory ) {
269
+ output_path /= filename ;
270
+ } else {
274
271
// Extract constants to the same directory as the rest of the files
275
272
// to be consistent with internal implementation
276
- size_t lastSlash = filename_str.find_last_of (k_separator);
277
- std::string filename = filename_str;
278
- if (lastSlash != std::string::npos) {
279
- filename = filename_str.substr (lastSlash + 1 );
280
- }
281
273
output_path /= model_directory;
282
- output_path /= filename;
274
+ output_path /= filename. filename () ;
283
275
}
284
276
auto output_path_str = output_path.string ();
285
277
@@ -354,7 +346,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
354
346
throw std::runtime_error (" Unsupported device found: " + device);
355
347
}
356
348
357
- std::string cubin_dir = temp_dir_ + k_separator + model_directory;
349
+ std::string cubin_dir = ( temp_dir_ / model_directory). string () ;
358
350
runner_ = registered_aoti_runner[device](
359
351
so_path, num_runners, device, cubin_dir, run_single_threaded);
360
352
}
0 commit comments